A new optimizer that makes LLM training both stable and memory-efficient.
SPAM (Spike-Aware Adam with Momentum Reset) enables stable LLM training by managing gradient spikes through momentum reset and adaptive clipping, while reducing memory usage via sparse momentum.
-----
https://arxiv.org/abs/2501.06842
Original Problem 🔍:
LLM training faces severe instability due to gradient spikes that can be 1000× larger than typical gradients. Current solutions rely on costly manual interventions like checkpoint restarts and data skipping.
-----
Solution in this Paper 🛠️:
→ SPAM introduces momentum reset: the accumulated first- and second-moment estimates are periodically zeroed so a spike's influence cannot propagate through later updates.
→ A spike-aware clipping mechanism detects anomalous gradient entries and scales them down while preserving their sign, keeping directional information intact.
→ Sparse momentum keeps moment estimates for only a subset of parameters, cutting optimizer-state memory.
→ SPAM adjusts clipping thresholds dynamically from running gradient statistics during training (see the sketch below).
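Putting these pieces together, here is a minimal, illustrative PyTorch-style sketch of a SPAM-like update on a single tensor. The class name `SpamSketch` and the hyperparameters `theta` (spike threshold) and `reset_interval` are placeholders for this example, not the authors' API or recommended values; details such as re-warming the learning rate after a reset and the sparse-momentum masking are left out.

```python
import torch

class SpamSketch:
    """Illustrative SPAM-style update for one tensor (not the paper's code)."""

    def __init__(self, param, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 theta=50.0, reset_interval=500):
        self.p = param
        self.lr, (self.b1, self.b2), self.eps = lr, betas, eps
        self.theta, self.reset_interval = theta, reset_interval
        self.m = torch.zeros_like(param)   # first moment
        self.v = torch.zeros_like(param)   # second moment
        self.t = 0                         # steps since the last reset

    @torch.no_grad()
    def step(self):
        g = self.p.grad                    # assumes backward() has run
        self.t += 1

        # Spike-aware clipping: an entry whose squared gradient far exceeds
        # its running second moment is scaled down to the threshold while
        # keeping its sign, so the update direction is preserved.
        spike = (self.v > 0) & (g.pow(2) > self.theta * self.v)
        g = torch.where(spike, g.sign() * (self.theta * self.v).sqrt(), g)

        # Adam-style moment updates; bias correction counts steps since the
        # last reset because the moments restart from zero.
        self.m.mul_(self.b1).add_(g, alpha=1 - self.b1)
        self.v.mul_(self.b2).addcmul_(g, g, value=1 - self.b2)
        m_hat = self.m / (1 - self.b1 ** self.t)
        v_hat = self.v / (1 - self.b2 ** self.t)
        self.p.add_(m_hat / (v_hat.sqrt() + self.eps), alpha=-self.lr)

        # Momentum reset: periodically zero both moments so a past spike
        # cannot keep steering future updates.
        if self.t % self.reset_interval == 0:
            self.m.zero_()
            self.v.zero_()
            self.t = 0
```

Usage would mirror any hand-rolled optimizer: run `loss.backward()`, call `step()`, then zero the gradient before the next iteration.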
-----
Key Insights from this Paper 🔬:
→ Gradient spikes primarily occur in LayerNorm layers despite being the smallest parameter group
→ Spikes persist across model sizes from 60M to 1B parameters
→ Randomly selecting which momentum entries to keep outperforms magnitude-based selection for sparse momentum (sketched below)
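A quick sketch of what that selection step could look like, assuming a simple coordinate-subset scheme; the function name, `density`, and `strategy` are illustrative, not names from the paper:

```python
import torch

def sample_momentum_indices(param, density=0.25, strategy="random"):
    """Pick which coordinates keep first/second moments (illustrative)."""
    n = param.numel()
    k = max(1, int(density * n))
    if strategy == "random":
        # Random subset: the choice the paper finds works best.
        return torch.randperm(n, device=param.device)[:k]
    # Magnitude-based alternative for comparison: keep the coordinates
    # with the largest |gradient| (requires param.grad to be populated).
    return param.grad.abs().flatten().topk(k).indices
```

Storing moments for only a fraction of coordinates (e.g., `density=0.25`) is where the optimizer-state memory savings reported below come from.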
-----
Results 📊:
→ Outperforms Adam by 3.63 perplexity points on LLaMA-60M pre-training
→ Reduces memory footprint by 75% while maintaining performance
→ Achieves 81.04% accuracy on ImageNet with ConvNeXt-T