Parallel prefetching: The secret sauce for faster LLM inference.
PRESERVE speeds up LLM inference by prefetching model weights and KV-cache data from off-chip memory while collective communication operations are still in flight, hiding memory-read latency behind communication and easing both bottlenecks.
-----
https://arxiv.org/abs/2501.08192
Original Problem 😕:
LLM inference faces two critical challenges: HBM bandwidth bottlenecks that limit single-device performance, and significant communication overhead in distributed, multi-device settings. Together, these severely impact inference speed and cost-efficiency.
-----
Solution in this Paper 🔧:
→ PRESERVE introduces a novel prefetching framework that overlaps memory reads with collective communication operations
→ The framework automatically inserts prefetch operators into computational graphs without requiring code modifications
→ It intelligently manages L2 cache utilization to prevent cache pollution
→ The system runs prefetching on a separate stream in parallel with communication, with explicit synchronization so downstream compute waits for both (see the sketch after this list)
→ A graph optimization algorithm dynamically tracks prefetched data and estimates L2 cache usage during compilation
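A minimal PyTorch-style sketch of the overlap idea, not the paper's implementation: the allreduce and a dummy read of the next layer's weights run on separate CUDA streams, and a crude size check stands in for PRESERVE's compile-time L2 usage estimate. The function names, tensor sizes, and the cache budget here are all illustrative assumptions.

```python
# Illustrative sketch only (assumes torch.distributed is already initialized
# with an NCCL backend). PRESERVE inserts dedicated prefetch operators into
# the computational graph rather than using hand-written code like this.
import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()
prefetch_stream = torch.cuda.Stream()

# Hypothetical cache budget used to avoid prefetching more than fits on chip.
L2_BUDGET_BYTES = 100 * 1024 * 1024

def l2_footprint(tensors):
    """Rough stand-in for PRESERVE's compile-time L2 usage estimate."""
    return sum(t.numel() * t.element_size() for t in tensors)

def allreduce_with_prefetch(activations, next_weights):
    """Overlap an allreduce on `activations` with a read of `next_weights`."""
    # Launch the collective on its own stream ...
    with torch.cuda.stream(comm_stream):
        work = dist.all_reduce(activations, async_op=True)

    # ... while the next layer's weights are read from HBM in parallel,
    # as long as they do not overflow (pollute) the on-chip cache.
    if l2_footprint(next_weights) <= L2_BUDGET_BYTES:
        with torch.cuda.stream(prefetch_stream):
            for w in next_weights:
                # A strided read touches the weights while the network
                # transfer is in flight; a real prefetch operator would
                # do this without producing a result.
                _ = w.view(-1)[::64].float().sum()

    # The next layer's compute must wait for both the collective and the prefetch.
    work.wait()
    torch.cuda.current_stream().wait_stream(prefetch_stream)
    return activations
```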
-----
Key Insights 💡:
→ Optimal L2 cache size increases from 8 MB to 104 MB when prefetching is considered
→ Maximum speedup occurs when allreduce and prefetch latencies match (see the back-of-the-envelope sketch after this list)
→ Network bandwidth significantly impacts prefetching effectiveness
→ Cluster size influences the balance between prefetch and communication latencies
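A back-of-the-envelope illustration of the latency-matching point, using assumed bandwidths and tensor sizes (none of these numbers come from the paper): prefetching is only "free" up to the duration of the allreduce, so the benefit peaks when the two latencies are roughly equal.

```python
# Toy arithmetic with assumed hardware numbers, for intuition only.
GB = 1e9

hbm_bw = 3000 * GB           # bytes/s of HBM read bandwidth (assumed)
link_bw = 400 * GB           # bytes/s of inter-device bandwidth (assumed)
devices = 8

prefetch_bytes = 0.5 * GB    # weights + KV-cache touched for the next block
allreduce_bytes = 0.05 * GB  # activation tensor being reduced

prefetch_lat = prefetch_bytes / hbm_bw
# A ring allreduce moves roughly 2*(N-1)/N of the message per device.
allreduce_lat = 2 * (devices - 1) / devices * allreduce_bytes / link_bw

# Only the overlapped portion is hidden; any prefetch time beyond the
# allreduce latency spills back onto the critical path.
hidden = min(prefetch_lat, allreduce_lat)
print(f"prefetch {prefetch_lat*1e6:.0f} us, "
      f"allreduce {allreduce_lat*1e6:.0f} us, hidden {hidden*1e6:.0f} us")
```

With these assumed numbers the prefetch (~167 us) fits inside the allreduce (~219 us); a larger cluster or slower network shifts that balance, which is why cluster size and network bandwidth matter.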
-----
Results 📊:
→ Up to 1.6x end-to-end speedup on state-of-the-art LLMs
→ 1.25x improvement in performance per cost with optimal L2 cache configuration
→ Demonstrated scalability across various batch sizes and sequence lengths