Adjoint sharding lets LLMs train on million-token sequences without exploding GPU memory.
The paper introduces adjoint sharding, a technique that drastically reduces memory usage when training LLMs on long sequences by restructuring how gradients are computed.
-----
https://arxiv.org/abs/2501.00692
Original Problem 🤔:
→ Training LLMs on very long contexts (>1M tokens) hits severe GPU memory limitations and slow training speeds.
→ Current methods therefore train on short contexts (a few thousand tokens) and rely on context-extension workarounds at inference time.
-----
Solution in this Paper 💡:
→ Adjoint sharding splits gradient computation during training into independent vector-Jacobian products (VJPs) — see the sketch after this list.
→ It uses the adjoint method instead of standard backpropagation, computing mathematically equivalent gradients at a fraction of the memory.
→ A truncated variant, truncated adjoint sharding, speeds up training further while maintaining performance.
→ Distributed and parallel implementations accelerate training further across multiple GPUs.
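To make the VJP decomposition concrete, here is a minimal JAX sketch on a toy recurrent cell h_t = tanh(W·h_{t-1} + U·x_t), with a loss on the final state. The names (`step`, `adjoint_grads`) and the cell itself are illustrative assumptions, not the paper's code: the point is that after one forward pass, the backward pass reduces to one VJP per step, each needing only its local inputs and the adjoint state.

```python
# Minimal sketch of the adjoint/VJP view of backprop for a toy recurrent cell.
# Illustrative only; not the paper's implementation.
import jax
import jax.numpy as jnp

def step(params, h, x):
    W, U = params
    return jnp.tanh(W @ h + U @ x)

def loss_at(h):
    return jnp.sum(h ** 2)  # toy loss on the final hidden state

def adjoint_grads(params, h0, xs):
    # Forward pass: record the hidden-state trajectory once.
    hs = [h0]
    for x in xs:
        hs.append(step(params, hs[-1], x))

    lam = jax.grad(loss_at)(hs[-1])  # adjoint state at the final step
    gW = jnp.zeros_like(params[0])
    gU = jnp.zeros_like(params[1])
    for t in reversed(range(len(xs))):
        # One vector-Jacobian product per step. Each VJP only needs
        # (lam, hs[t], xs[t]), so per-step parameter gradients can be
        # computed independently — this is what makes them shardable.
        _, vjp_fn = jax.vjp(lambda p, h: step(p, h, xs[t]), params, hs[t])
        (dW, dU), dh = vjp_fn(lam)
        gW, gU = gW + dW, gU + dU
        lam = dh  # adjoint recursion: lam_{t-1} = lam_t · ∂h_t/∂h_{t-1}
    return gW, gU
```

On this toy cell the result matches `jax.grad` of the fully unrolled loss; the paper's contribution is doing the same decomposition at LLM scale, where avoiding one monolithic autograd graph is what saves the memory.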
-----
Key Insights 🔍:
→ Sharding gradient calculations cuts memory requirements by orders of magnitude.
→ The adjoint method enables parallel gradient computation, unlike strictly sequential backpropagation.
→ The truncated version achieves linear time complexity vs. quadratic for full adjoint sharding (sketched below).
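A hedged sketch of these two insights, reusing the toy `step` cell from above. Once the adjoint states λ_t are known (for the state space models the paper targets, the recurrence is linear, so these can also be obtained in parallel), the per-step parameter VJPs are fully independent: here they are batched with `vmap`, and truncation simply drops steps older than `window`. The helper name `sharded_param_grads` and the stacked-array inputs are assumptions for illustration.

```python
# Hypothetical sketch: batch the independent per-step parameter VJPs and keep
# only the last `window` steps (the truncation idea, simplified). Reuses
# `step` from the sketch above; hs_prev, xs, lams are arrays stacked over time.
def sharded_param_grads(params, hs_prev, xs, lams, window=2000):
    hs_prev, xs, lams = hs_prev[-window:], xs[-window:], lams[-window:]

    def one_step_vjp(h_prev, x, lam):
        # VJP w.r.t. parameters only; independent of every other step.
        _, vjp_fn = jax.vjp(lambda p: step(p, h_prev, x), params)
        ((dW, dU),) = vjp_fn(lam)
        return dW, dU

    # vmap evaluates all per-step VJPs in parallel; on a multi-GPU setup the
    # same map can be sharded across devices instead.
    dWs, dUs = jax.vmap(one_step_vjp)(hs_prev, xs, lams)
    return dWs.sum(axis=0), dUs.sum(axis=0)
```

Truncation drops gradient contributions older than the window, trading a small approximation for compute that grows linearly with sequence length rather than quadratically.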
-----
Results 📊:
→ 3X memory reduction with a 1.27B-parameter model at 1M-token context length.
→ Maximum training context increased from 35K to 100K tokens on 5 AWS P4 instances.
→ 64% less computation with the truncated version using a 2,000-token window.