M+: Extending MemoryLLM with Scalable Long-Term Memory

Computer Science

Y. Wang, D. Krotov, et al.

Large language models often lose information from the distant past. MemoryLLM compresses past context into a 1B-parameter latent memory but struggles beyond ~20k tokens. This paper presents M+, which augments MemoryLLM with a long-term memory and a co-trained retriever to dynamically fetch relevant information during generation, extending retention from under 20k to over 160k tokens with similar GPU overhead. Research conducted by Yu Wang, Dmitry Krotov, Yuanzhe Hu, Yifan Gao, Wangchunshu Zhou, Julian McAuley, Dan Gutfreund, Rogerio Feris, and Zexue He.
Introduction

The paper addresses long-term information retention in large language models. Memory modules for LLMs fall into two families: token-level memory (structured text, summaries, knowledge graphs, databases) and latent-space memory (compressed hidden-state vectors or parameters). Token-level memory offers modularity and interpretability, but it can be redundant, makes conflicts hard to resolve, and may not capture integrated representations akin to human memory. Latent-space memory supports efficient compression and end-to-end training, and it aligns with integrated neural representations. MemoryLLM compresses information into hidden states using a large memory pool across layers, but it struggles to recall information injected more than ~20k tokens earlier. The study proposes M+, which augments MemoryLLM with a scalable long-term memory stored on CPU and a co-trained retriever that selectively fetches relevant memory tokens per layer during generation. The goal is to substantially extend long-context understanding and knowledge retention without increasing GPU memory use.

Literature Review

Related work is categorized into token-level memory and latent-space memory. Token-level memory approaches include MemoryBank, RecurrentGPT, MemGPT, ChatDB, and MemLLM, which represent memory as text, summaries, knowledge graphs, or databases and retrieve it via embeddings or model-generated queries. These methods are modular and interpretable but can be redundant, and they struggle with conflict resolution and with representing complex conversational structures. Latent-space memory methods store information in hidden states, parameters, or external matrices, using memory slots or key-value caches, sometimes with forgetting mechanisms (e.g., CAMELoT, Memoria). MemoryLLM compresses knowledge into hidden states and uses random dropping to control memory growth. Memory³ stores a large pretraining dataset in hidden-state space; Larimar introduces a read/write memory matrix for knowledge editing; SELF-PARAM embeds knowledge directly into model parameters. Despite this progress, latent-space memory methods often fall short on extremely long inputs. M+ is positioned to improve scalability and retention beyond these limitations.

Methodology

M+ builds on MemoryLLM's architecture, which comprises a transformer decoder φ and a per-layer memory pool θ of N tokens, each of dimension d. In MemoryLLM, an update extracts the last K tokens from θ_l, combines them with a new chunk to produce K new memory tokens, randomly drops K tokens from θ_l, and appends the new K tokens; during generation, the model cross-attends to θ_l.

M+ adds a long-term memory Θ with flexible per-layer pools and a maximum capacity M (set to 150k tokens). During an update, the K tokens dropped from the short-term memory θ are stored in Θ and assigned an age for chronological ordering; when Θ reaches M, the oldest tokens are evicted. During generation, a trained retriever fetches K_Θ tokens from Θ_l at each layer; these are sorted by age and concatenated with θ_l for cross-attention, so the model can access both short- and long-term memory. A Multi-LoRA design uses separate LoRA weights for the update (write/compress) and generation (read/load) phases.

The retriever consists of two MLP projectors, f_q for queries and f_k for keys, each with output dimension d_proj = d/20. When tokens are moved into Θ, f_k produces compact key vectors that are stored alongside the memory tokens. At generation time, f_q maps query hidden states to query vectors, and retrieval ranks memory tokens by the dot product between query and key vectors.

To train the retriever, a document x is split into chunks x_1, ..., x_n, and chunks x_1, ..., x_{n-1} are injected into memory. Memory tokens related to these chunks form θ_+; all others form θ_-. A forward pass on x_n yields hidden states h_n. The per-layer objective minimizes -log(p_+) - log(1 - p_-), where p_+ = ⟨f_q(h_n), f_k(θ_+)⟩ and p_- = ⟨f_q(h_n), f_k(θ_-)⟩; this increases similarity to relevant tokens and decreases similarity to irrelevant ones.

Training details: the model is built on Llama-3.1-8B and trained on eight A100 GPUs with DeepSpeed ZeRO stage 2. Key hyperparameters: K = 256; in Stages 1 and 2, the short-term memory holds N = 12,800 tokens per layer; Stage 3 shrinks θ_l to 10,240 tokens and retrieves K_Θ = 2,560 tokens from Θ, keeping the total at 12,800 memory tokens.
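
The write/read cycle and retriever objective described above can be sketched for a single layer as follows. This is a minimal, hedged illustration, not the authors' code. Several assumptions are built in. Vectors are plain Python lists. The K new memory tokens that φ would produce from the last K tokens of θ_l and the new chunk are passed in directly. The identity function stands in for the trained MLP projectors f_q and f_k. A token's age is taken to be the update step at which it was written. Retrieved long-term tokens are placed before θ_l. Finally, the loss passes the dot products through a sigmoid so that p_+ and p_- lie in (0, 1). The names `LayerMemory`, `softplus`, and `retriever_loss` are hypothetical.

```python
import math
import random
from dataclasses import dataclass, field
from typing import Callable

Vec = list[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


@dataclass
class LayerMemory:
    """Hypothetical per-layer memory: short-term pool θ_l and long-term pool Θ_l."""
    n_short: int   # N: size of θ_l
    capacity: int  # M: maximum size of Θ_l
    short: list[tuple[int, Vec]] = field(default_factory=list)      # (age, token)
    long: list[tuple[int, Vec, Vec]] = field(default_factory=list)  # (age, token, key)

    def update(self, new_tokens: list[Vec], step: int, f_k: Callable[[Vec], Vec]) -> None:
        """Write phase: move randomly dropped θ_l tokens into Θ_l, then append new tokens."""
        n_drop = max(0, len(self.short) + len(new_tokens) - self.n_short)
        dropped = set(random.sample(range(len(self.short)), n_drop))
        for i in dropped:
            age, token = self.short[i]
            # The key is computed once, when the token enters long-term memory.
            self.long.append((age, token, f_k(token)))
        self.short = [m for i, m in enumerate(self.short) if i not in dropped]
        self.short.extend((step, t) for t in new_tokens)
        if len(self.long) > self.capacity:
            # Age-based eviction: sort oldest first and drop the overflow.
            self.long.sort(key=lambda m: m[0])
            del self.long[: len(self.long) - self.capacity]

    def read(self, query: Vec, k_long: int) -> list[Vec]:
        """Read phase: `query` stands for f_q(h). Take the top-k_long long-term tokens
        by <query, key>, re-sort them by age, and prepend them to θ_l."""
        ranked = sorted(self.long, key=lambda m: dot(query, m[2]), reverse=True)
        chosen = sorted(ranked[:k_long], key=lambda m: m[0])
        return [tok for _, tok, _ in chosen] + [tok for _, tok in self.short]


def softplus(x: float) -> float:
    """Numerically stable log(1 + e^x)."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def retriever_loss(q: Vec, pos_keys: list[Vec], neg_keys: list[Vec]) -> float:
    """Mean -log p_+ over positives plus mean -log(1 - p_-) over negatives,
    with p = sigmoid(<q, k>) (assumed), using -log σ(s) = softplus(-s)
    and -log(1 - σ(s)) = softplus(s)."""
    pos = sum(softplus(-dot(q, k)) for k in pos_keys) / len(pos_keys)
    neg = sum(softplus(dot(q, k)) for k in neg_keys) / len(neg_keys)
    return pos + neg


# Toy usage: N = 4, M = 3, two new memory tokens per update.
random.seed(0)
identity = lambda v: v  # placeholder for f_q / f_k
mem = LayerMemory(n_short=4, capacity=3)
for step in range(1, 5):
    mem.update([[float(step), 0.0], [0.0, float(step)]], step, identity)
context = mem.read(identity([1.0, 0.0]), k_long=2)
```

Because the number of retrieved tokens is fixed, the cross-attention input per layer stays at k_long + N tokens no matter how large Θ_l grows. Only the key comparison scales with the long-term pool, and the capacity M bounds that pool.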

The generation window is 2,048 tokens, so each layer's cross-attention matrix is at most (12,800 + 2,048) × 2,048. Data curriculum: Stage 1 is continual training on fineweb-edu for 1.2M steps (~4 weeks) with three subtasks (two-chunk, multi-chunk, and revisiting cached chunks). Stage 2 targets long-context modeling with SlimPajama documents of 4k–64k tokens, sampling 200k examples per length range and mixing them in equal proportions with fineweb snapshots; it runs for one epoch (~1 week). Stage 3 introduces the long-term memory and continues training on newly sampled long documents with the adjusted memory composition (10,240 short-term plus 2,560 retrieved long-term tokens). The training subtasks keep average revisit distances around 60 in Stages 1–2 and around 200 in Stage 3.
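
The memory-budget arithmetic above can be checked with a few lines of Python. The cost model in the hypothetical `attention_entries` function is our own simplification, not a figure from the paper. It counts only attention-score entries for one layer and assumes the input is processed in 2,048-token windows that each attend to a fixed 12,800-token memory. It ignores chunk compression, retrieval scoring, and causal masking.

```python
import math

# Numbers from the text.
N_STAGE_1_2 = 12_800    # short-term memory tokens per layer (Stages 1-2)
SHORT_STAGE_3 = 10_240  # tokens kept in θ_l (Stage 3)
K_THETA = 2_560         # tokens retrieved from Θ_l per layer (Stage 3)
WINDOW = 2_048          # generation window

# Stage 3 keeps the total memory budget unchanged.
assert SHORT_STAGE_3 + K_THETA == N_STAGE_1_2

# Upper bound on one window's attention matrix: memory + window keys, 2,048 queries.
PER_WINDOW = (N_STAGE_1_2 + WINDOW) * WINDOW


def attention_entries(num_tokens: int) -> int:
    """Simplified per-layer cost: one fixed-size attention matrix per window."""
    return math.ceil(num_tokens / WINDOW) * PER_WINDOW
```

Because every window attends to the same fixed memory budget, this count grows linearly with input length. Doubling the input doubles the number of windows, whereas full self-attention over the raw input grows quadratically.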

Key Findings

• M+ significantly improves long-context understanding and knowledge retention compared with MemoryLLM and strong baselines, extending retention from under ~20k tokens to over 160k tokens while keeping GPU memory overhead similar.
• LongBook-QA and LongBook Event QA: M+ achieves higher QA-F1 and accuracy on both benchmarks, outperforming Llama-3.1-8B-16k, Llama-3.2-3B-128k, and Llama-3.1-8B-SnapKV, despite processing fewer tokens (12,800 memory tokens plus a 2,048-token generation window).
• GPU memory efficiency: Table 1 reports inference GPU memory (MB): Llama-3.1-8B-SnapKV 32,574.49; Llama-3.2-3B-128k 30,422.70; M+ 21,177.76; Llama-3.1-8B-16k 19,239.21; M+ (offload) 17,973.34. Offloading memory tokens to CPU yields the lowest GPU usage with minimal latency overhead.
• Knowledge retention on SQuAD and NaturalQA: M+ outperforms MemoryLLM-7B and Llama-3.1-8B-SnapKV. SnapKV struggles to recall information injected more than 30k tokens earlier, even with a 48k prompt. M+ retrieves ~30% of the ground-truth long-term tokens, versus ~3% expected from random retrieval (2,560/81,276).
• On relatively short-document tasks (LongBench at 8k/16k), M+ matches Llama-3.1-8B on 4 of 6 datasets but underperforms on hotpotqa and musique. The paper attributes this to random dropping (information loss) and limited cross-chunk attention during chunk compression.
• Latency (single H100, inputs up to 128k tokens): M+ incurs additional retrieval-induced latency compared with MemoryLLM-8B; offloading adds slight extra latency that becomes negligible for very long sequences (≈1 s extra at 128k, ~3% overhead).
• Ablations: validation loss improves progressively across training stages, with M+ lowest on the 32k–64k subset. Long-term memory substantially extends retention (from ~50k to >160k tokens). A trained retriever (M+) outperforms attention-based retrieval (M+-Attn) on SQuAD and NaturalQA.
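
The ratios behind two of these findings can be recomputed from the reported numbers. The snippet below is only arithmetic on figures quoted above. The variable names are ours, and the 0.30 hit rate is the rounded "~30%" from the text.

```python
# Inference GPU memory in MB, as reported in Table 1.
gpu_mb = {
    "Llama-3.1-8B-SnapKV": 32_574.49,
    "Llama-3.2-3B-128k": 30_422.70,
    "M+": 21_177.76,
    "Llama-3.1-8B-16k": 19_239.21,
    "M+ (offload)": 17_973.34,
}

# Fraction of GPU memory each M+ variant saves relative to SnapKV.
savings_vs_snapkv = {
    name: 1 - gpu_mb[name] / gpu_mb["Llama-3.1-8B-SnapKV"]
    for name in ("M+", "M+ (offload)")
}

# GPU memory freed by offloading M+'s memory tokens to CPU.
offload_saving_mb = gpu_mb["M+"] - gpu_mb["M+ (offload)"]

# Expected fraction of ground-truth tokens recovered when 2,560 of
# 81,276 long-term tokens are picked uniformly at random.
random_hit_rate = 2_560 / 81_276
observed_hit_rate = 0.30  # "~30%" as reported
lift = observed_hit_rate / random_hit_rate
```

Under these numbers, offloading frees about 3,204 MB of GPU memory relative to non-offloaded M+. M+ (offload) uses roughly 45% less GPU memory than the SnapKV baseline. The retriever's hit rate is about 9.5 times the random baseline.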

Discussion

M+ directly tackles the limitation of MemoryLLM's long-term recall by preserving dropped short-term memory tokens in a scalable long-term memory per layer and learning a compact retriever for efficient, layer-wise retrieval. Storing long-term memory on CPU and retrieving a small subset per layer reduces GPU memory costs while maintaining access to distant information through cross-attention. The co-trained retriever improves retrieval quality over attention-based heuristics, enabling stronger reasoning over long dependencies in tasks like LongBook-QA and event-based QA. Although some performance trade-offs appear on shorter inputs due to random dropping and constrained cross-chunk attention, the approach achieves linear scaling in computation and offers substantial gains in extreme long-context scenarios, demonstrating a practical path to extending LLM context capabilities under fixed GPU budgets.

Conclusion

The paper introduces M+, a memory-augmented LLM that extends MemoryLLM with a scalable long-term memory and a co-trained retriever. M+ effectively retains and retrieves information across very long inputs, achieving superior performance on long-context understanding and knowledge retention benchmarks while remaining GPU-memory efficient. Key contributions include the long-term memory mechanism, efficient per-layer retrieval, a Multi-LoRA design, and a staged training curriculum emphasizing long-context modeling. Future work will focus on reducing CPU–GPU communication overhead to further improve generation efficiency.

Limitations

• Performance on some short-document benchmarks (e.g., hotpotqa, musique) is slightly lower than the baseline, which the authors attribute to random dropping during memory compression and limited cross-chunk attention during chunk processing.
• Retrieval adds latency compared with MemoryLLM, and offloading memory to CPU adds I/O overhead, though this overhead diminishes for longer sequences.
• Training was resource-constrained (DeepSpeed ZeRO stage 2 on eight A100s), which limited context scaling during training compared with fully 128k-trained models.
• The long-term memory capacity is capped (M = 150k), which requires age-based eviction; the choice of hyperparameters (K, N, K_Θ) and eviction policy may shift performance trade-offs.
• The LongBook Event QA dataset uses GPT-4o-generated distractors, which may introduce biases or artifacts into the evaluation.
• As with other memory-augmented systems, the model may retain sensitive or biased information over long spans; safeguards are needed for safe deployment.
