"SPAM: Spike-Aware Adam with Momentum Reset for Stable LLM Training"

A podcast on this paper was generated with Google's Illuminate.

A new optimizer that makes LLM training both stable and memory-efficient.

SPAM (Spike-Aware Adam with Momentum Reset) enables stable LLM training by managing gradient spikes through momentum reset and adaptive clipping, while reducing memory usage via sparse momentum.

-----

https://arxiv.org/abs/2501.06842

Original Problem 🔍:

LLM training faces severe instability due to gradient spikes that can be 1000× larger than typical gradients. Current solutions rely on costly manual interventions like checkpoint restarts and data skipping.
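To make "spike" concrete: one simple flag is the ratio of an entry's magnitude to the tensor's average gradient magnitude. The sketch below uses that ratio with an illustrative threshold of 50; the exact spike criterion and threshold used in the paper may differ.

```python
import torch

def spike_mask(grad: torch.Tensor, ratio: float = 50.0) -> torch.Tensor:
    """Flag gradient entries far larger than the tensor's typical magnitude.

    The ratio-to-mean test and the threshold of 50 are illustrative
    assumptions for this sketch, not necessarily the paper's definition.
    """
    typical = grad.abs().mean().clamp_min(1e-12)   # average |g| over the tensor
    return grad.abs() > ratio * typical            # True where an entry "spikes"

# Example: one huge coordinate in an otherwise small gradient
g = torch.full((1000,), 1e-3)
g[7] = 1.0                                         # ~1000x the typical magnitude
print(spike_mask(g).nonzero().squeeze())           # tensor(7)
```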

-----

Solution in this Paper 🛠️:

→ SPAM introduces momentum reset, which periodically zeros the accumulated first and second moments so that a spike cannot keep propagating through Adam's exponential averages.

→ A spike-aware clipping mechanism identifies and scales down anomalous gradients while preserving their directional information.

→ Sparse momentum maintains first- and second-moment estimates for only a subset of parameters, reducing memory requirements.

→ SPAM dynamically adjusts clipping thresholds based on gradient statistics during training. A combined sketch of these mechanisms follows this list.

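The pieces above fit together in a single Adam-style update. Below is a minimal, single-tensor sketch under stated assumptions, not the authors' code: moments are zeroed every `reset_interval` steps, entries with g² > θ·v are rescaled to sign(g)·√(θ·v), and moments are kept only for a random fraction `density` of coordinates. The class name, defaults, and the exact spike rule are illustrative.

```python
import torch

class SpamSketch:
    """Single-tensor sketch of a SPAM-style Adam step (illustrative, not the official implementation)."""

    def __init__(self, param, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 theta=50.0, reset_interval=500, density=0.25):
        self.p = param
        self.lr, (self.b1, self.b2), self.eps = lr, betas, eps
        self.theta = theta                  # assumed spike rule: clip where g^2 > theta * v
        self.reset_interval = reset_interval
        self.density = density              # fraction of coordinates that carry momentum state
        self._reset_state()

    def _reset_state(self):
        # Momentum reset: discard accumulated moments and re-sample the sparse support.
        n = self.p.numel()
        k = max(1, int(self.density * n))
        self.idx = torch.randperm(n, device=self.p.device)[:k]   # random subset (paper: random beats magnitude-based)
        self.m = torch.zeros(k, device=self.p.device)
        self.v = torch.zeros(k, device=self.p.device)
        self.t = 0

    @torch.no_grad()
    def step(self):
        if self.t > 0 and self.t % self.reset_interval == 0:
            self._reset_state()
        self.t += 1

        g = self.p.grad.reshape(-1)[self.idx]                    # gradient restricted to the sparse support

        if self.t > 1:
            # Spike-aware clipping: rescale entries whose square exceeds theta * v,
            # keeping the sign so the update direction is preserved.
            spike = g.pow(2) > self.theta * self.v
            g = torch.where(spike, g.sign() * (self.theta * self.v).sqrt(), g)

        # Standard Adam moments, kept only for the selected coordinates.
        self.m.mul_(self.b1).add_(g, alpha=1 - self.b1)
        self.v.mul_(self.b2).addcmul_(g, g, value=1 - self.b2)
        m_hat = self.m / (1 - self.b1 ** self.t)
        v_hat = self.v / (1 - self.b2 ** self.t)

        update = self.lr * m_hat / (v_hat.sqrt() + self.eps)
        self.p.data.view(-1).index_add_(0, self.idx, -update)
```

One instance per parameter tensor would stand in for Adam's per-parameter state; the full optimizer also covers parameter groups and eases updates back in after a reset, which this sketch omits.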
-----

Key Insights from this Paper 🔬:

→ Gradient spikes occur primarily in LayerNorm layers, even though these are the smallest parameter group

→ Spikes persist across model sizes from 60M to 1B parameters

→ Random selection of the momentum subset outperforms magnitude-based selection (a short comparison sketch follows this list)

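The last insight is easy to picture: the momentum subset can be drawn uniformly at random or taken as the top-k coordinates by gradient magnitude, and the paper reports the random choice works better. The helpers below only contrast the two selection rules; the names and the 25% density are illustrative.

```python
import torch

def random_support(grad: torch.Tensor, density: float = 0.25) -> torch.Tensor:
    """Random subset of coordinates (the choice the paper finds works best)."""
    k = max(1, int(density * grad.numel()))
    return torch.randperm(grad.numel(), device=grad.device)[:k]

def magnitude_support(grad: torch.Tensor, density: float = 0.25) -> torch.Tensor:
    """Top-k coordinates by |gradient|: the magnitude-based alternative."""
    k = max(1, int(density * grad.numel()))
    return grad.abs().flatten().topk(k).indices
```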
-----

Results 📊:

→ Outperforms Adam by 3.63 perplexity points on LLaMA-60M pre-training

→ Reduces memory footprint by 75% while maintaining performance

→ Achieves 81.04% accuracy on ImageNet with ConvNeXt-T
