HMT: Hierarchical Memory Transformer for Efficient Long Context Language Processing
Z. He, Y. Cao, et al.
The paper addresses the challenge of efficient long-context processing in decoder-only language models whose self-attention scales quadratically with sequence length, forcing restricted context windows in practice. Real-world tasks such as long-document understanding, book summarization, and lifelong question answering require handling extremely long or streaming inputs with frequent topic switches. Prior approaches (sparse attention, retrieval-augmented generation, recurrent/state-space models) either modify core architectures, struggle with relevance filtering, or do not scale memory efficiently. The research question is how to enable model-independent, plug-and-play long-context processing with adaptive selection of relevant past information, minimal memory overhead, and superior generation quality. The proposed HMT framework imitates human memory hierarchy (sensory, short-term, long-term) and introduces a memory retrieval mechanism to recall and strengthen relevant information while filtering distractions, aiming to improve perplexity, accuracy, and scalability across diverse backbones.
Long Context Transformers: Sparse attention approaches (e.g., sliding window, Longformer, Poolingformer) reduce computation but can miss long-range dependencies and still increase memory use with input length. Retrieval-augmented models (e.g., Unlimiformer) select the top-K most relevant tokens to prune attention, but may accumulate losses from omitted tokens and require storing extensive encodings. Segment-level recurrent models (e.g., Compressive Transformer, RMT) summarize and propagate segment states to handle unbounded sequences, but repeated summarization dilutes information and can let irrelevant content dominate memory.
Recurrent Sequence Models: RNN variants (LSTM/GRU) are memory-efficient but underperform self-attention at modeling contextual relationships. Modern recurrent/state-space models (RWKV, Mamba) compress and pass states efficiently with linear-time operations and memorize well, but struggle with modeling context relationships and filtering out irrelevant information. Hybrid Transformer-Mamba designs reintroduce the scaling challenges of transformers.
Problem Formulation: Adaptive long-context processing is crucial under hardware memory constraints and in streaming settings. Models must select and recall relevant past context while filtering irrelevant information, remain effective despite frequent context switches, and avoid storing all past tokens in full. HMT is proposed to meet these requirements without altering backbone architectures.
HMT is a model-independent, plug-and-play framework that augments decoder-only backbones with hierarchical memory and a memory retrieval mechanism operating at the segment level. Inputs are chunked into L-token segments, and HMT runs four steps per segment n on the token embeddings H_n produced by the backbone's embedding layer (a code sketch follows the list):
- Representation encoding: Extract the ongoing topic by summarizing the first j embeddings of the current segment H_n, augmented with a learnable segment summarization prompt embedding T (soft prompt tuning). The backbone model (BBM) processes [T || H_n[0:j) || T] and outputs a single summary embedding S_n that represents the segment’s topic.
- Memory search: Maintain a sliding window of N cached memory embeddings M_{(n−N+1:n)} produced from previous segments. Use S_n as a query and perform cross-attention-like retrieval (projections W_q, W_k) to compute a memorization prompt embedding P_n = softmax((S_n W_q) (M W_k)^T / sqrt(d_h)) M, directly weighting past memory embeddings by normalized similarity without value/output projections, producing a single fused prompt embedding relevant to the current segment.
- Prepending sensory memory: Augment the current segment embeddings by prepending the last k embeddings of the previous segment (sensory memory H_{n−1}[L−k:L)) and wrapping the result with the memorization prompt embedding P_n (prepended and appended) to guide compression with positional awareness.
- Decoding and summarization: The backbone processes the augmented sequence (P_n || H_{n−1}[L−k:L) || H_n || P_n) to produce hidden embeddings H_n^{out} used for logits and a memory embedding M_n that summarizes the augmented segment; M_n is cached into long-term memory for future retrieval.
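The per-segment flow and the retrieval step can be made concrete with a minimal sketch. It assumes a PyTorch wrapper bbm(x) around the backbone that returns (per-token hidden states, one summary embedding) for input embeddings x, a memory_cache that is a collections.deque(maxlen=N) already holding past memory embeddings (early segments would skip retrieval while the cache is empty), and illustrative names and shapes; it sketches the mechanism described above, not the authors' implementation.

```python
# Minimal sketch of one HMT segment step (illustrative, not the reference code).
# Assumes `bbm(x)` returns (per-token hidden states, a single summary embedding)
# for input embeddings x of shape (T, d), T is the learnable summarization
# prompt of shape (1, d), and `memory_cache` is a deque(maxlen=N) of past
# memory embeddings.
import torch


def hmt_segment_step(bbm, H_n, H_prev, memory_cache, T, W_q, W_k, j=512, k=32):
    """Process one L-token segment H_n (shape (L, d)) and update the memory cache."""
    # 1) Representation encoding: summarize the first j embeddings, wrapped by
    #    the summarization prompt T, into a topic embedding S_n.
    _, S_n = bbm(torch.cat([T, H_n[:j], T], dim=0))             # S_n: (d,)

    # 2) Memory search: cross-attention-like retrieval over cached memory
    #    embeddings (no value/output projections) to form the prompt P_n.
    M = torch.stack(list(memory_cache), dim=0)                  # (N, d)
    scores = (S_n @ W_q) @ (M @ W_k).T / (W_q.size(-1) ** 0.5)  # (N,)
    P_n = torch.softmax(scores, dim=-1) @ M                     # (d,)

    # 3) Sensory memory: prepend the last k embeddings of the previous segment
    #    and wrap the segment with P_n (prepended and appended).
    augmented = torch.cat([P_n[None], H_prev[-k:], H_n, P_n[None]], dim=0)

    # 4) Decode and summarize: H_out feeds the LM head; the memory embedding
    #    M_n is cached (the deque evicts entries older than N segments).
    H_out, M_n = bbm(augmented)
    memory_cache.append(M_n)
    return H_out, M_n
```

Because only the query/key projections W_q and W_k are used, the retrieved prompt P_n is simply a similarity-weighted average of the cached memory embeddings, so it stays in the backbone's embedding space.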
Hierarchical Memorization:
- Sensory memory: local continuity via the last k embeddings of the prior segment.
- Short-term memory: per-segment summarization into M_n informed by P_n and segment content.
- Long-term memory: a cache of the N most recent memory embeddings, used by retrieval to rescale and recall relevant history.
Training and Efficiency: Representation encoding runs in parallel with inference; memory search is O(N) and parallelizable for small N (e.g., N=300), adding negligible runtime overhead. Typical hyperparameters are L=1024, j=512, N=300, and k=32, adjusted per backbone. HMT parameters are trained jointly with the backbone via multi-stage BPTT: Stage 1 trains without retrieval on 2-segment unrolls and saves a checkpoint; Stage 2 enables retrieval and trains with the maximum feasible unroll depth (up to 15 segments), using memory optimizations (CPU offload of intermediates, ZeRO stage 2) and, for larger models, LoRA. At inference, HMT maintains constant peak memory regardless of input length, keeping a fixed KV cache and a bounded set of memory embeddings.
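To illustrate the constant-peak-memory property, the following runnable toy (with a purely illustrative ToyBackbone standing in for a real decoder-only model, and the P_n retrieval from the sketch above elided) keeps only a bounded deque of at most N memory embeddings plus the last k embeddings of the previous segment, so the recurrent state does not grow with stream length.

```python
# Runnable toy showing HMT's bounded recurrent state: long-term memory is a
# deque of at most N embeddings and sensory memory is the last k embeddings of
# the previous segment. `ToyBackbone` is illustrative; retrieval is elided.
from collections import deque

import torch
import torch.nn as nn


class ToyBackbone(nn.Module):
    """Stand-in for a decoder-only backbone: returns per-token hidden states
    and the last hidden state as a one-embedding summary."""
    def __init__(self, d: int = 64):
        super().__init__()
        self.mix = nn.Linear(d, d)

    def forward(self, x):                      # x: (T, d)
        h = torch.tanh(self.mix(x))
        return h, h[-1]


def stream(embeddings: torch.Tensor, L: int = 1024, k: int = 32, N: int = 300):
    d = embeddings.size(-1)
    bbm = ToyBackbone(d)
    memory = deque(maxlen=N)                   # long-term memory: at most N embeddings
    prev_tail = torch.zeros(k, d)              # sensory memory: last k embeddings
    for start in range(0, embeddings.size(0), L):
        H_n = embeddings[start:start + L]
        augmented = torch.cat([prev_tail, H_n], dim=0)
        _, M_n = bbm(augmented)
        memory.append(M_n)                     # oldest entry evicted beyond N
        prev_tail = H_n[-k:]                   # recurrent state stays O(N + k)
    return memory


# A 2k-token and a 100k-token stream hold the same amount of recurrent state;
# only the number of segment passes changes.
# stream(torch.randn(100_000, 64))
```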
General language modeling:
- Consistent perplexity reductions across backbones on Wikitext-103 and PG-19 for inputs spanning 2k–100k tokens.
- Average test PPL (Wikitext-103): OPT 350M 15.11 → HMT+OPT 350M 14.28 (-5.8%); OPT 2.7B 12.12 → HMT+OPT 2.7B 8.61 (-28.9%); RWKV 430M 19.33 → HMT+RWKV 430M 16.10 (-16.6%); RWKV 3B 13.30 → HMT+RWKV 3B 9.93 (-25.3%).
- Qwen 2.5 14B on PG-19: HMT improves effectiveness by ~10%, whereas RMT reduces it.
Question answering (PubMedQA, multi-context):
- Long answers: PPL improved by 9.48% for samples with 2–10 contexts.
- Short answers: +1.0% accuracy overall; HMT shows larger gains with more contexts.
Comparison to RMT and other memory methods:
- Wikitext-103: HMT outperforms RMT by 13.0% (OPT) and 10.8% (OpenLlamaV2); for RWKV, HMT boosts effectiveness by 16.5% while RMT degrades it.
- PG-19: HMT exceeds RMT by 3.98% (OPT) and 6.85% (OpenLlamaV2); +9.96% for RWKV.
- Versus Memorizing Transformer, LongMem, and CCM-concat: HMT achieves lower PPL with similar or smaller backbones while keeping the lowest memory complexity (constant peak memory with respect to input length).
- Versus HOMER (+YaRN): HMT attains 9.9% lower perplexity on PG-19; benefits grow with input length.
LongBench subsets and parameter/memory efficiency:
- HMT on small models achieves comparable or better metrics than large long-context LLMs with 2–57× fewer parameters and 2.5–116× lower inference memory.
- Example: Yi-6B-200K on an MI210 GPU cannot handle 30k-token inputs, and a 5.2k sliding-window variant consumes 44.8 GB of VRAM; HMT+Yi-6B-200K processes 30k tokens within 33.9 GB of VRAM (512-token segments) with +2% effectiveness.
- Wikitext-103 (30k tokens): HMT+OpenLlamaV2 3B PPL 7.04; HMT+Mistral-7B 5.12; HMT+RWKV 3B 10.94; HMT+Mamba 370M 16.71; baselines: Mistral-7B 5.47, Yi-6B-SW-5.2K 6.89, RWKV 3B 13.13, Mamba 370M 87.08.
Comparison to LongMem (ArXiv subset of The Pile):
- LongMem (558M parameters, 1k-token segments, 65k tokens in memory) reaches PPL 10.08, vs. PPL 9.02 ± 0.04 for HMT+Qwen1.5-0.5B (463M parameters, 1k-token segments, N=300), indicating better effectiveness with fewer parameters and a shorter memory span.
Ablations and training behavior:
- Memory retrieval is essential for scalability; HMT without retrieval underperforms, and gains grow with input length.
- Summarizing half of a segment vs. the whole segment for representation encoding yields a negligible difference, enabling faster inference.
- Increasing cached memory embeddings improves effectiveness with diminishing returns beyond N≈300.
- HMT exhibits improved effectiveness with higher BPTT unroll depth (e.g., 2→5→15 segments: Wikitext-103 PPL 9.36→9.15→8.20), unlike RMT which suffers gradient issues.
HMT directly addresses the need for adaptive long-context processing by combining hierarchical memory with a retrieval mechanism that identifies and strengthens relevant historical information while suppressing irrelevant context. The empirical reductions in perplexity across multiple datasets and architectures demonstrate improved long-range coherence and reasoning. Improvements in multi-context QA accuracy and reduced memory consumption validate that HMT filters and recalls the right information efficiently, enabling smaller models to rival or surpass larger long-context LLMs. Unlike prior methods that require architectural changes or store large KV caches, HMT maintains constant peak inference memory by condensing segments into fixed-size embeddings and retrieving via cross-attention, making it scalable and deployable across future decoder-only backbones. Its gradient stability with higher BPTT unroll depth further supports robustness for learning extended dependencies, aligning with the goal of lifelong and streaming language tasks.
The paper introduces HMT, a plug-and-play framework that augments decoder-only LLMs with hierarchical memory (sensory, short-term, long-term) and a cross-attention-based memory retrieval mechanism. HMT consistently improves generation quality for long inputs across diverse backbones, achieves comparable or superior results to specialized long-context models with far fewer parameters and significantly lower inference memory, and outperforms recent memory-augmented/hierarchical baselines. It enables efficient, adaptive processing of infinitely long streams without altering core backbone architectures. Future directions include intelligent memory prefetching to handle larger N across memory hierarchies, more efficient training to extend BPTT depth, and exploring multi-level long-term memory structures to further improve information access.
- Retrieval window size N: While small N (e.g., 300) keeps overhead negligible and suffices for ~100k-token inputs, larger N and heterogeneous memory hierarchies can introduce latency; intelligent prefetching is needed.
- Training memory demands: Multi-segment BPTT for tuning HMT parameters can be memory-intensive, limiting larger-scale experiments; more efficient methods to extend BPTT depth are desirable.
- Single-level long-term memory: Current design uses one level; multi-level long-term memory may further improve access efficiency and scalability.