
"Adjoint sharding for very long context training of state space models"


Adjoint sharding lets LLMs train on million-token sequences without exploding GPU memory.

The paper introduces adjoint sharding - a novel technique that drastically reduces the memory needed to train LLMs on long sequences by breaking gradient computation into shardable vector-Jacobian products.

-----

https://arxiv.org/abs/2501.00692

Original Problem 🤔:

→ Training LLMs on very long contexts (>1M tokens) runs into severe GPU memory limits and slow training.

→ Current methods are forced to train with short contexts (a few thousand tokens) and rely on workarounds at inference time.

-----

Solution in this Paper 💡:

→ Adjoint sharding splits the gradient computation during training into independent vector-Jacobian products (see the sketch after this list).

→ It replaces backpropagation with the adjoint method, computing the same gradients with far less memory.

→ The method introduces truncated adjoint sharding to speed up training while maintaining performance.

→ Distributed and parallel implementations further accelerate training across multiple GPUs.
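To make the decomposition concrete, here is a minimal PyTorch sketch of the adjoint recursion on a toy linear recurrence h_t = A h_{t-1} + B x_t with a per-step squared-error loss. The setup (matrices A and B, helper names like forward_states and loss_fn) is illustrative and not taken from the paper; the point is only that each parameter-gradient term is an independent outer product (a vector-Jacobian product) that could be evaluated on a separate shard and summed, and that the result matches ordinary backpropagation.

```python
# Minimal sketch (not the paper's full algorithm): the adjoint recursion for a toy
# linear recurrence h_t = A h_{t-1} + B x_t with a per-step squared-error loss.
# All names (A, B, forward_states, loss_fn) are illustrative assumptions.
import torch

torch.manual_seed(0)
d, T = 8, 64                          # state dimension, sequence length
A = torch.randn(d, d) * 0.05          # state-transition matrix
B = torch.randn(d, d) * 0.05          # input projection
xs = torch.randn(T, d)                # inputs
ys = torch.randn(T, d)                # per-step targets

def forward_states(A, B):
    """Run h_t = A h_{t-1} + B x_t (h_{-1} = 0) and return all states."""
    h, hs = torch.zeros(d), []
    for t in range(T):
        h = A @ h + B @ xs[t]
        hs.append(h)
    return torch.stack(hs)            # shape (T, d)

def loss_fn(hs):
    return 0.5 * ((hs - ys) ** 2).sum()

# Reference gradients: ordinary backpropagation through time (stores the full graph).
A_ref, B_ref = A.clone().requires_grad_(), B.clone().requires_grad_()
loss_fn(forward_states(A_ref, B_ref)).backward()

# Adjoint method: run the adjoint vector lam backwards; every parameter-gradient
# term is an independent outer product (a vector-Jacobian product) that could be
# evaluated on a different shard and summed afterwards.
with torch.no_grad():
    hs = forward_states(A, B)
    grad_A, grad_B = torch.zeros_like(A), torch.zeros_like(B)
    lam = torch.zeros(d)              # adjoint of the state after the last step
    for t in reversed(range(T)):
        lam = lam + (hs[t] - ys[t])                    # dL_t/dh_t
        h_prev = hs[t - 1] if t > 0 else torch.zeros(d)
        grad_A += torch.outer(lam, h_prev)             # independent VJP term for A
        grad_B += torch.outer(lam, xs[t])              # independent VJP term for B
        lam = A.T @ lam                                # propagate adjoint one step back

print(torch.allclose(grad_A, A_ref.grad, atol=1e-4),
      torch.allclose(grad_B, B_ref.grad, atol=1e-4))   # expected: True True
```

Note that this single-device sketch still materializes all T states; the paper's point is that the independent VJP terms can be distributed and parallelized across GPUs, which is where the memory reduction comes from.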

-----

Key Insights 🔍:

→ Memory requirements can be reduced by orders of magnitude by sharding the gradient calculation.

→ The adjoint method enables parallel computation, unlike sequential backpropagation.

→ The truncated version achieves linear time complexity in sequence length, versus quadratic for full adjoint sharding (see the sketch below).
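The linear-versus-quadratic tradeoff can be illustrated on the same toy recurrence as above. Writing the gradient as a double sum of independent terms, one per (loss step t, lookback k) pair, gives O(T²) terms in total; truncating the lookback to a fixed window K keeps only O(T·K) of them. The window size K and all names below are illustrative assumptions, not the paper's notation.

```python
# Illustrative sketch: gradient of the same toy recurrence h_t = A h_{t-1} + B x_t
# written as independent (loss step t, lookback k) terms, truncated to a window K.
# Full double sum: O(T^2) terms; truncated: O(T*K), linear in sequence length.
import torch

torch.manual_seed(0)
d, T, K = 8, 64, 8                    # K = truncation window (hypothetical value)
A = torch.randn(d, d) * 0.05
B = torch.randn(d, d) * 0.05
xs, ys = torch.randn(T, d), torch.randn(T, d)

# Forward pass: collect all states of h_t = A h_{t-1} + B x_t (h_{-1} = 0).
h, states = torch.zeros(d), []
for t in range(T):
    h = A @ h + B @ xs[t]
    states.append(h)
states = torch.stack(states)

grad_A, grad_B = torch.zeros_like(A), torch.zeros_like(B)
for t in range(T):                    # one group of terms per loss step
    v = states[t] - ys[t]             # dL_t/dh_t for the squared-error loss
    for k in range(min(K, t + 1)):    # truncated lookback: at most K steps back
        h_prev = states[t - k - 1] if t - k - 1 >= 0 else torch.zeros(d)
        grad_A += torch.outer(v, h_prev)   # each term is independent -> shardable
        grad_B += torch.outer(v, xs[t - k])
        v = A.T @ v                   # extend the Jacobian product one step back
```

Setting K = T recovers the exact gradient from the previous sketch at quadratic cost; a small fixed K trades a controlled approximation for linear-time training, which is the tradeoff the truncated variant exploits.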

-----

Results 📊:

→ 3X memory reduction with a 1.27B-parameter model at 1M context length.

→ Increased maximum training context from 35K to 100K tokens on 5 AWS P4 instances.

→ 64% computation reduction using the truncated version with a 2,000-token window.
