"Exploiting Sparsity for Long Context Inference: Million Token Contexts on Commodity GPUs"
The podcast below on this paper was generated with Google's Illuminate.
https://arxiv.org/abs/2502.06766
LLMs struggle with very long inputs because attention computation and the key/value cache grow with context length, quickly exceeding the memory of typical hardware. This paper targets that inefficiency in transformer inference over long contexts.
It introduces a top-k attention mechanism that attends only to the most important tokens at each decoding step, cutting inference cost and making long contexts practical.
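As a rough illustration of the core idea (not the paper's implementation; tensor names and sizes are made up), a single top-k attention step in PyTorch might look like this:

```python
# Minimal sketch of top-k attention for one query vector (illustrative only).
# Only the k highest-scoring keys take part in the softmax; the rest are ignored.
import torch

def topk_attention(q, K, V, k):
    """q: (d,), K: (n, d), V: (n, d_v)  ->  (d_v,) attention output."""
    scores = K @ q / K.shape[-1] ** 0.5          # (n,) scaled dot products
    top_scores, top_idx = torch.topk(scores, k)  # keep only the k best keys
    weights = torch.softmax(top_scores, dim=-1)  # softmax over k scores only
    return weights @ V[top_idx]                  # weighted sum of k values

q = torch.randn(64)
K, V = torch.randn(100_000, 64), torch.randn(100_000, 64)
out = topk_attention(q, K, V, k=128)             # attends to 128 of 100k tokens
```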
-----
📌 By offloading the key/value cache to CPU memory and using top-k selection, this method drastically cuts GPU memory requirements, enabling million-token context inference on commodity GPUs (a back-of-the-envelope memory estimate follows below).
📌 The work empirically validates the inherent attention sparsity of LLMs. Top-k attention exploits this sparsity, focusing on the few crucial tokens while maintaining performance with minimal overhead.
📌 Architecturally, top-k attention shifts the key/value cache to the CPU and retrieves relevant entries with approximate nearest-neighbor search. This decouples context-length scaling from GPU memory limits, making practical deployment feasible.
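To see why offloading matters, here is a rough KV-cache size estimate for a Llama-3-8B-style configuration (32 layers, 8 KV heads, head dimension 128, fp16); these figures are assumptions for illustration, not numbers reported in the paper:

```python
# Back-of-the-envelope KV-cache size (illustrative assumptions: 32 layers,
# 8 KV heads, head_dim 128, fp16 entries in the cache).
layers, kv_heads, head_dim, bytes_per_elem = 32, 8, 128, 2
tokens = 1_000_000

bytes_per_token = 2 * layers * kv_heads * head_dim * bytes_per_elem  # K and V
total_gb = bytes_per_token * tokens / 1e9

print(f"{bytes_per_token / 1024:.0f} KiB per token, "
      f"{total_gb:.0f} GB for {tokens:,} tokens")
# ~128 KiB per token -> ~131 GB for 1M tokens, far beyond a 24 GB commodity GPU,
# which is why the cache lives in CPU RAM and only top-k entries reach the GPU.
```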
----------
Methods Explored in this Paper 🔧:
→ The paper proposes a top-k attention mechanism.
→ It selects only the most relevant key-value pairs for attention computation at each layer.
→ Key and value vectors are stored in CPU memory within a vector database.
→ Approximate k-Nearest Neighbor search is used to retrieve top-k keys from the CPU cache for each query during decoding.
→ The retrieved keys and values are moved to the GPU for attention computation (see the sketch after this list).
→ This method reduces computation and memory overhead by focusing on crucial tokens.
→ The value of k can be adjusted per layer to optimize performance and efficiency.
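Below is a minimal sketch of one decoding step with the cache offloaded to the CPU. It assumes a single attention head, substitutes exact inner-product search for the paper's approximate k-NN index, and all class and variable names are illustrative rather than taken from the paper's code:

```python
# Sketch of one decoding step with the KV cache kept in CPU RAM (illustrative).
import torch

class OffloadedTopKAttention:
    def __init__(self, k: int):
        self.k = k
        self.K_cpu = []   # key vectors stored in CPU memory
        self.V_cpu = []   # value vectors stored in CPU memory

    def append(self, k_vec: torch.Tensor, v_vec: torch.Tensor):
        # New tokens' keys/values go straight to the CPU-side cache.
        self.K_cpu.append(k_vec.cpu())
        self.V_cpu.append(v_vec.cpu())

    def attend(self, q: torch.Tensor) -> torch.Tensor:
        K = torch.stack(self.K_cpu)                      # (n, d) on CPU
        V = torch.stack(self.V_cpu)                      # (n, d_v) on CPU
        scores = K @ q.cpu() / K.shape[-1] ** 0.5        # search happens on CPU
        k = min(self.k, K.shape[0])
        _, idx = torch.topk(scores, k)                   # indices of top-k keys
        K_gpu = K[idx].to(q.device)                      # move only k entries
        V_gpu = V[idx].to(q.device)
        w = torch.softmax(K_gpu @ q / K.shape[-1] ** 0.5, dim=-1)
        return w @ V_gpu                                 # attention over k tokens

device = "cuda" if torch.cuda.is_available() else "cpu"
attn = OffloadedTopKAttention(k=64)
for _ in range(1000):                                    # simulate a long prefix
    attn.append(torch.randn(64), torch.randn(64))
out = attn.attend(torch.randn(64, device=device))
```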
-----
Key Insights 💡:
→ Modern LLMs exhibit sparse attention patterns.
→ Only a small fraction of tokens contributes significantly to the attention output (a measurement sketch follows this list).
→ Top-k attention effectively exploits this sparsity.
→ Models can maintain performance even when attending to a very small percentage of input tokens.
→ This sparsity is observed across different model sizes, architectures, and training regimes.
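One way to check this sparsity yourself is to inspect the attention weights of an off-the-shelf model. The sketch below uses GPT-2 via Hugging Face transformers purely as a convenient example; it is not taken from the paper's evaluation setup:

```python
# Measure what fraction of attention mass lands on the top 2% of keys (illustrative).
import torch
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModel.from_pretrained("gpt2", output_attentions=True).eval()

text = "The quick brown fox jumps over the lazy dog. " * 40   # a longer prompt
inputs = tok(text, return_tensors="pt", truncation=True)

with torch.no_grad():
    attentions = model(**inputs).attentions  # per layer: (batch, heads, seq, seq)

k_frac = 0.02                                # look at the top 2% of keys
for layer, attn in enumerate(attentions):
    probs = attn[0]                          # (heads, seq, seq)
    seq = probs.shape[-1]
    k = max(1, int(k_frac * seq))
    # Average attention mass captured by the k highest-weighted keys per query.
    top_mass = probs.topk(k, dim=-1).values.sum(-1).mean().item()
    print(f"layer {layer:2d}: top {k} of {seq} keys hold {top_mass:.1%} of attention")
```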
-----
Results 📊:
→ Achieves over 95% of full-attention performance on benchmarks such as LM-Eval, AlpacaEval, and RULER while attending to fewer than 2% of input tokens.
→ On the RULER benchmark with k=2, the method maintains over 60% performance across context lengths from 8K to 131K tokens.
→ For the Needle In A Haystack task with a 1-million-token context, k=1 is sufficient.