M+: Extending MemoryLLM with Scalable Long-Term Memory
Y. Wang, D. Krotov, et al.
The paper addresses the challenge of long-term information retention in large language models. Memory modules for LLMs fall into two families: token-level memory (structured text, summaries, knowledge graphs, databases) and latent-space memory (compressed hidden-state vectors or parameters). Token-level memory is modular and interpretable, but it can be redundant, makes conflicting entries hard to resolve, and may not capture integrated representations akin to human memory. Latent-space memory supports efficient compression and end-to-end training, and aligns with integrated neural representations. MemoryLLM compresses information into hidden states held in a large memory pool across layers, but it struggles to recall information injected more than ~20k tokens earlier. The study proposes M+, which augments MemoryLLM with a scalable long-term memory stored on CPU and a co-trained retriever that selectively fetches relevant memory tokens per layer during generation. The goal is to substantially extend long-context understanding and knowledge retention without increasing GPU memory use.
Related work is categorized into token-level memory and latent-space memory. Token-level memory approaches include MemoryBank, RecurrentGPT, MemGPT, ChatDB, and MemLLM, which represent memory as text, summaries, knowledge graphs, or databases and retrieve it via embeddings or model-generated queries. These methods are modular and interpretable, but they can be redundant and have difficulty resolving conflicts and representing complex conversational structures. Latent-space memory methods store information in hidden states, parameters, or external matrices, using memory slots or key-value caches, sometimes with forgetting mechanisms (e.g., CAMELoT, Memoria). MemoryLLM compresses knowledge into hidden states and uses random dropping to control memory growth. Memory³ stores a large pretraining dataset in hidden-state space; Larimar introduces a read/write memory matrix for knowledge editing; SELF-PARAM embeds knowledge directly into model parameters. Despite this progress, latent-space memory methods often fall short on extremely long inputs. M+ is positioned to improve scalability and retention beyond these limitations.
M+ builds on MemoryLLM's architecture, which comprises a transformer decoder φ and a per-layer memory pool θ holding N tokens of dimension d. In MemoryLLM, an update extracts the last K tokens from θ_l, combines them with a new chunk to produce K new tokens, randomly drops K tokens from θ_l, and appends the K new tokens; generation uses cross-attention to θ_l. M+ introduces a long-term memory Θ with flexible per-layer pools and a maximum capacity M (set to 150k tokens). During updates, the K tokens dropped from the short-term memory θ are stored in Θ and assigned an age for chronological ordering; when Θ reaches M, the oldest tokens are dropped. During generation, at each layer, K_Θ tokens are retrieved from Θ_l by a trained retriever, sorted by age, and concatenated with θ_l for cross-attention, so the model can access both short- and long-term memory. A Multi-LoRA design uses separate LoRA weights for the update (write/compress) and generation (read/load) phases.
The retriever consists of two MLP projectors, f_q for queries and f_k for keys, each with output dimension d_proj = d/20. When tokens are moved into Θ, f_k produces compact key vectors that are stored alongside the memory tokens. At generation, f_q maps query hidden states to query vectors, and retrieval is based on dot products between queries and keys. For retriever training, a document x is split into chunks x_1, ..., x_n, and x_1, ..., x_{n-1} are injected into memory; the memory tokens related to these chunks are denoted θ_+ and all others θ_-. A forward pass on x_n yields hidden states h_n. The per-layer objective minimizes -log(p_+) - log(1 - p_-), where p_+ = ⟨f_q(h_n), f_k(θ_+)⟩ and p_- = ⟨f_q(h_n), f_k(θ_-)⟩, which increases similarity to relevant tokens and decreases similarity to irrelevant ones.
Training details: M+ is built on Llama-3.1-8B and trained on eight A100 GPUs using deepspeed-stage-2. Key hyperparameters: K = 256; in Stages 1 and 2 the short-term memory holds N = 12,800 tokens per layer; in Stage 3 the short-term pool θ_l is reduced to 10,240 tokens and K_Θ = 2,560 tokens are retrieved from Θ, keeping the total at 12,800 memory tokens.
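The update and retrieval steps described above can be sketched for a single layer as follows. This is a minimal NumPy illustration, not the authors' implementation: the class `LayerMemory`, its attribute names, and the random linear maps standing in for the trained MLP projectors f_k and f_q are all hypothetical. It also assumes that the K new tokens have already been produced by the model, that "age" is an insertion timestamp recorded when a token moves into long-term memory, that a single query vector is used per read, and that retrieved tokens are placed before the short-term pool.

```python
import numpy as np


class LayerMemory:
    """Toy short-term pool (theta_l) plus long-term pool (Theta_l) for one layer."""

    def __init__(self, n_short, d, d_proj, max_long, seed=0):
        self.rng = np.random.default_rng(seed)
        self.short = self.rng.standard_normal((n_short, d))  # theta_l, N x d
        self.long = np.empty((0, d))                         # Theta_l tokens
        self.keys = np.empty((0, d_proj))                    # stored f_k outputs
        self.ages = np.empty(0, dtype=np.int64)              # insertion timestamps
        self.max_long = max_long                             # capacity M
        self.clock = 0
        # Assumption: random linear maps stand in for the trained projectors.
        self.w_k = self.rng.standard_normal((d, d_proj)) / np.sqrt(d)
        self.w_q = self.rng.standard_normal((d, d_proj)) / np.sqrt(d)

    def update(self, new_tokens):
        """Randomly drop K short-term tokens, archive them, append K new tokens."""
        k = new_tokens.shape[0]
        drop = self.rng.choice(self.short.shape[0], size=k, replace=False)
        keep = np.ones(self.short.shape[0], dtype=bool)
        keep[drop] = False
        dropped = self.short[drop]

        # Move the dropped tokens into long-term memory with keys and ages.
        self.long = np.concatenate([self.long, dropped])
        self.keys = np.concatenate([self.keys, dropped @ self.w_k])
        self.ages = np.concatenate(
            [self.ages, np.full(k, self.clock, dtype=np.int64)]
        )
        self.clock += 1

        # Evict the oldest entries once capacity M is exceeded.
        excess = self.long.shape[0] - self.max_long
        if excess > 0:
            self.long = self.long[excess:]
            self.keys = self.keys[excess:]
            self.ages = self.ages[excess:]

        self.short = np.concatenate([self.short[keep], new_tokens])

    def read(self, hidden, k_retrieve):
        """Return retrieved long-term tokens (oldest first) followed by theta_l."""
        k = min(k_retrieve, self.long.shape[0])
        if k == 0:
            return self.short
        scores = self.keys @ (hidden @ self.w_q)       # dot-product relevance
        top = np.argpartition(-scores, k - 1)[:k]      # indices of the k best keys
        top = top[np.argsort(self.ages[top], kind="stable")]  # chronological order
        return np.concatenate([self.long[top], self.short])
```

With the Stage 3 settings from the text, such an object would be created with `n_short=10240`, `max_long=150_000`, and read with `k_retrieve=2560`, so the cross-attention context stays at 12,800 memory tokens per layer.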
The generation window is 2,048 tokens, giving an attention matrix of up to (12,800 + 2,048) × 2,048 per layer.
Data curriculum: Stage 1 is continual training on fineweb-edu for 1.2M steps (~4 weeks) with three subtasks (two-chunk, multi-chunk, and revisiting cached chunks). Stage 2 is long-context modeling on SlimPajama documents of 4k–64k tokens, sampling 200k examples per length range and mixing them in equal proportions with fineweb snapshots, trained for one epoch (~1 week). Stage 3 introduces the long-term memory and continues training on newly sampled long documents with the adjusted memory composition (10,240 short-term + 2,560 retrieved long-term tokens). The training subtasks maintain average revisit distances of ~60 in Stages 1–2 and ~200 in Stage 3.
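The retriever objective from the methodology above, -log(p_+) - log(1 - p_-), can be illustrated with a small NumPy sketch. The summary defines p_+ and p_- as inner products; for the logarithms to be well defined, this sketch assumes a sigmoid is applied to each dot product and that the loss is averaged over memory tokens. Both choices, and the function name `retriever_loss`, are assumptions made for illustration and are not confirmed by the source.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def retriever_loss(query, keys_pos, keys_neg, eps=1e-9):
    """Contrastive retriever loss for one layer.

    query:    (d_proj,)   projected query f_q(h_n)
    keys_pos: (P, d_proj) projected keys f_k(theta_+) of relevant tokens
    keys_neg: (Q, d_proj) projected keys f_k(theta_-) of irrelevant tokens
    """
    # Assumption: squash dot products into (0, 1) so they act as probabilities.
    p_pos = sigmoid(keys_pos @ query)
    p_neg = sigmoid(keys_neg @ query)
    # Assumption: average over tokens; eps guards against log(0).
    return -np.log(p_pos + eps).mean() - np.log(1.0 - p_neg + eps).mean()
```

Under these assumptions the loss falls when the query aligns with the relevant keys and points away from the irrelevant ones, which matches the stated aim of raising similarity to θ_+ and lowering it to θ_-.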
• M+ significantly improves long-context understanding and knowledge retention compared to MemoryLLM and strong baselines. It extends retention from under ~20k tokens to over 160k tokens while keeping GPU memory overhead similar.
• LongBook-QA and LongBook Event QA: M+ achieves superior QA-F1 and accuracy across both benchmarks, outperforming Llama-3.1-8B-16k, Llama-3.2-3B-128k, and Llama-3.1-8B-SnapKV, despite processing fewer tokens (12,800 memory + 2,048 generation window).
• GPU memory efficiency: Table 1 reports inference GPU memory (MB): Llama-3.1-8B-SnapKV 32,574.49; Llama-3.2-3B-128k 30,422.70; M+ 21,177.76; Llama-3.1-8B-16k 19,239.21; M+ (offload) 17,973.34. Offloading memory tokens to CPU yields the lowest GPU usage with minimal latency overhead.
• Knowledge retention on SQuAD and NaturalQA: M+ outperforms MemoryLLM-7B and Llama-3.1-8B-SnapKV. SnapKV struggles to recall information injected >30k tokens earlier, even with a 48k prompt. M+ retrieves ~30% of ground-truth long-term tokens versus the ~3% expected from random retrieval (2,560/81,276).
• On relatively short-document tasks (LongBench at 8k/16k), M+ matches Llama-3.1-8B on 4 of 6 datasets but underperforms on hotpotqa and musique. The paper attributes this to random dropping (information loss) and limited cross-chunk attention during chunk compression.
• Latency analysis (single H100, up to 128k inputs): M+ incurs additional retrieval-induced latency compared to MemoryLLM-8B; offloading introduces slight extra latency that becomes negligible for very long sequences (≈1 s extra at 128k, ~3% overhead).
• Ablations: performance improves progressively across training stages (M+ has the lowest validation loss on the 32k–64k subset). Long-term memory substantially enhances retention (from ~50k to >160k tokens). A trained retriever (M+) outperforms attention-based retrieval (M+-Attn) on SQuAD and NaturalQA.
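The ratios behind the findings above can be checked with a few lines of arithmetic. The sketch below uses only the numbers quoted in the bullets (Table 1 memory figures, the 2,560/81,276 random-retrieval baseline, and the ~30% hit rate); the variable names are illustrative.

```python
# Inference GPU memory in MB, as quoted from Table 1.
gpu_mb = {
    "Llama-3.1-8B-SnapKV": 32574.49,
    "Llama-3.2-3B-128k": 30422.70,
    "M+": 21177.76,
    "Llama-3.1-8B-16k": 19239.21,
    "M+ (offload)": 17973.34,
}

# Fraction of GPU memory saved by M+ with offloading, relative to each baseline.
offload = gpu_mb["M+ (offload)"]
savings = {
    name: 1.0 - offload / mb
    for name, mb in gpu_mb.items()
    if name != "M+ (offload)"
}

# Expected hit rate when 2,560 of 81,276 long-term tokens are drawn at random.
random_hit_rate = 2560 / 81276

# How many times better the reported ~30% hit rate is than random retrieval.
retrieval_lift = 0.30 / random_hit_rate

for name, frac in savings.items():
    print(f"M+ (offload) uses {frac:.1%} less GPU memory than {name}")
print(f"random baseline: {random_hit_rate:.2%}, lift: {retrieval_lift:.1f}x")
```

On these figures the trained retriever is roughly an order of magnitude above chance, and offloading brings M+ below even the 16k-context Llama-3.1-8B baseline in GPU memory.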
M+ directly tackles the limitation of MemoryLLM's long-term recall by preserving dropped short-term memory tokens in a scalable long-term memory per layer and learning a compact retriever for efficient, layer-wise retrieval. Storing long-term memory on CPU and retrieving a small subset per layer reduces GPU memory costs while maintaining access to distant information through cross-attention. The co-trained retriever improves retrieval quality over attention-based heuristics, enabling stronger reasoning over long dependencies in tasks like LongBook-QA and event-based QA. Although some performance trade-offs appear on shorter inputs due to random dropping and constrained cross-chunk attention, the approach achieves linear scaling in computation and offers substantial gains in extreme long-context scenarios, demonstrating a practical path to extending LLM context capabilities under fixed GPU budgets.
The paper introduces M+, a memory-augmented LLM that extends MemoryLLM with a scalable long-term memory and a co-trained retriever. M+ effectively retains and retrieves information across very long inputs, achieving superior performance on long-context understanding and knowledge retention benchmarks while remaining GPU-memory efficient. Key contributions include the long-term memory mechanism, efficient per-layer retrieval, a Multi-LoRA design, and a staged training curriculum emphasizing long-context modeling. Future work will focus on reducing CPU–GPU communication overhead to further improve generation efficiency.
• Performance on some short-document benchmarks (e.g., hotpotqa, musique) is slightly lower than the baseline, attributed to random dropping in memory compression and limited cross-chunk attention during chunk processing.
• Retrieval introduces additional latency compared to MemoryLLM; offloading memory to CPU adds I/O overhead, though its relative cost diminishes for longer sequences.
• Training was constrained by resources (deepspeed-stage-2, eight A100s), limiting context scaling during training compared to fully 128k-trained models.
• The long-term memory capacity is capped (M=150k), requiring age-based eviction; the choice of hyperparameters (K, N, K_Θ) and eviction policy may affect performance trade-offs.
• Construction of the LongBook Event QA dataset involves GPT-4o-generated distractors, which may introduce biases or artifacts in evaluation.
• As with other memory-augmented systems, risks include potential retention of sensitive or biased information over extended spans; safeguards are necessary for safe deployment.

