Cool idea in this paper from @Apple researchers.
The paper claims AdamW needs 95% more training tokens (i.e., 1.95x as many gradient updates) than their proposed optimizer to reach the same loss. 🤯
A 1.3B-parameter LLM trained with AdEMAMix on 101B tokens performs comparably to the same model trained with AdamW on 197B tokens (+95%).
📚 https://arxiv.org/abs/2409.03137
Results 📊:
• Consistently outperforms AdamW on language modeling and vision tasks
• Improves optimization stability and convergence speed
• Forgets training data more slowly than AdamW
Original Problem 🔍:
Current momentum-based optimizers use a single Exponential Moving Average (EMA) of gradients, which cannot simultaneously give high weight to recent gradients and non-negligible weight to older ones.
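To make the limitation concrete, here is a minimal sketch (my own illustration, not the paper's code) of the single-EMA momentum accumulator used in Adam/AdamW; one decay rate β has to serve both recency and memory:

```python
import numpy as np

def ema_update(m, grad, beta):
    """Single EMA of gradients, as in Adam/AdamW: m <- beta*m + (1-beta)*grad."""
    return beta * m + (1 - beta) * grad

# With beta = 0.9, a gradient's relative weight (beta**k) drops below 1%
# after ~44 steps, so older gradients are effectively forgotten; raising
# beta keeps them longer but blunts the response to recent gradients.
```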
-----
Key Insights from this Paper 💡:
• Gradients can remain relevant for tens of thousands of steps
• A single EMA cannot simultaneously emphasize recent and older gradients
• Combining two EMAs with different decay rates can leverage both recent and old gradient information (see the quick calculation after this list)
• Older gradients help converge faster and often to lower minima
• AdEMAMix significantly slows down model forgetting during training
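A quick back-of-the-envelope calculation (my own numbers, not from the paper's tables) shows why a slow EMA keeps very old gradients relevant: under an EMA with decay β, a gradient that is k steps old contributes with relative weight β**k.

```python
# Relative weight of a k-step-old gradient under an EMA with decay beta.
def relative_weight(beta, k):
    return beta ** k

print(relative_weight(0.9, 10_000))     # ~0.0  -> a fast EMA has forgotten it entirely
print(relative_weight(0.9999, 10_000))  # ~0.37 -> a slow EMA still gives it substantial weight
```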
-----
Solution in this Paper 🛠️:
• Introduces AdEMAMix optimizer with two momentum terms:
 - Fast EMA (low decay β1) for adapting to recent landscape changes
 - Slow EMA (high decay β3) for leveraging very old gradients
• Uses schedulers to gradually increase α (mixing coefficient) and β3 (slow EMA decay rate)
• Modifies the Adam update rule (see the sketch after this list):
θ(t) = θ(t-1) - η * ((m̂1(t) + α·m2(t)) / (√ν̂(t) + ε) + λθ(t-1))
• Initializes m2 to zero when switching from AdamW mid-training
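Below is a minimal NumPy sketch of one AdEMAMix step, following the update rule quoted above. The hyperparameter values (beta1=0.9, beta2=0.999, beta3=0.9999, alpha=8) and variable names are illustrative defaults I chose, not the paper's prescribed settings, and the α/β3 warm-up schedulers are omitted for brevity:

```python
import numpy as np

def ademamix_step(theta, grad, m1, m2, nu, t,
                  lr=1e-3, beta1=0.9, beta2=0.999, beta3=0.9999,
                  alpha=8.0, eps=1e-8, weight_decay=0.01):
    # Fast EMA: reacts quickly to recent gradients (bias-corrected, as in Adam).
    m1 = beta1 * m1 + (1 - beta1) * grad
    # Slow EMA: retains information from very old gradients (no bias correction).
    m2 = beta3 * m2 + (1 - beta3) * grad
    # Second-moment EMA of squared gradients, as in Adam.
    nu = beta2 * nu + (1 - beta2) * grad ** 2

    m1_hat = m1 / (1 - beta1 ** t)
    nu_hat = nu / (1 - beta2 ** t)

    # Mix the two momenta, normalize, and add decoupled weight decay (AdamW-style).
    update = (m1_hat + alpha * m2) / (np.sqrt(nu_hat) + eps) + weight_decay * theta
    theta = theta - lr * update
    return theta, m1, m2, nu
```

Starting from an AdamW checkpoint, m2 would simply be initialized to zeros (as the paper notes) while m1 and nu carry over.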