"MARS: Unleashing the Power of Variance Reduction for Training Large Models"
The podcast below on this paper was generated with Google's Illuminate.
https://arxiv.org/abs/2411.10438
The paper addresses the inefficient training of large models, especially LLMs, caused by the high variance inherent in stochastic gradient methods. Existing variance reduction techniques have not been widely adopted in deep learning.
This paper introduces MARS, a unified optimization framework. MARS combines preconditioned gradient methods with a scaled stochastic recursive momentum technique to reduce variance and accelerate training.
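As we read the paper, the general MARS update can be summarized as below. Here $\nabla f(x, \xi)$ is a stochastic gradient on minibatch $\xi$, $\beta_1$ is the momentum parameter, $\gamma_t$ is the STORM scaling parameter, $\eta_t$ is the learning rate, and $H_t$ is the preconditioner chosen by each instance (full-matrix for Shampoo-style methods, diagonal for AdamW-style methods). The notation, and the omission of instance-specific details such as gradient clipping and weight decay, are our simplifications rather than the paper's exact statement.

```latex
\begin{aligned}
c_t &= \nabla f(x_t, \xi_t) + \gamma_t \,\frac{\beta_1}{1-\beta_1}\,\bigl(\nabla f(x_t, \xi_t) - \nabla f(x_{t-1}, \xi_t)\bigr), \\
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, c_t, \\
x_{t+1} &= x_t - \eta_t\, H_t^{-1} m_t .
\end{aligned}
```

The first line is the variance-reduced gradient estimate, the second is ordinary exponential-moving-average momentum applied to it, and the third is the preconditioned step that approximates a Newton update.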
-----
📌 MARS tackles LLM training inefficiency by combining variance reduction with preconditioned gradient updates. This combination speeds up convergence and improves model quality relative to AdamW.
📌 The MARS framework offers practical benefits: it reduces the number of training tokens needed to reach a target loss. For GPT-2 large, MARS saves 22 billion tokens compared to AdamW.
📌 MARS is a versatile optimization framework. Rather than a single algorithm, it is a family of algorithms that can be adapted to various preconditioning techniques such as AdamW, Lion, and Shampoo.
----------
Methods Explored in this Paper 🔧:
→ The MARS framework is proposed. It integrates variance reduction into adaptive gradient methods.
→ A scaled version of Stochastic Recursive Momentum (STORM) provides a variance-reduced gradient estimator.
→ Preconditioned updates, which approximate second-order Newton's method, are incorporated. Combining them with the variance-reduced estimator aims to make optimization more efficient.
→ MARS is versatile: it supports both full-matrix and diagonal Hessian approximations.
→ Three MARS instances are introduced: MARS-AdamW, MARS-Lion, and MARS-Shampoo. These adapt MARS to existing preconditioning methods.
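A minimal pure-Python sketch of one MARS-AdamW-style step, based on our reading of the paper: it forms the scaled STORM estimate, clips it to unit L2 norm, and feeds it to standard bias-corrected AdamW moments with decoupled weight decay. The function name `mars_adamw_step`, the default hyperparameters (β1 = 0.95, β2 = 0.99, γ = 0.025), and the use of the previous step's gradient in place of a second gradient evaluation on the same minibatch are assumptions for illustration, not the authors' reference implementation.

```python
import math
from typing import List, Optional, Tuple

Vec = List[float]

def mars_adamw_step(x: Vec, grad: Vec, prev_grad: Optional[Vec], m: Vec, v: Vec, t: int,
                    lr: float = 1e-3, beta1: float = 0.95, beta2: float = 0.99,
                    gamma: float = 0.025, eps: float = 1e-8,
                    weight_decay: float = 0.0) -> Tuple[Vec, Vec, Vec]:
    """One MARS-AdamW-style update (hypothetical sketch). t is the 1-based step count."""
    # 1) Scaled STORM estimate: c = g + gamma * beta1/(1-beta1) * (g - g_prev).
    if prev_grad is None:
        c = list(grad)  # first step: no previous gradient, so no correction
    else:
        scale = gamma * beta1 / (1.0 - beta1)
        c = [g + scale * (g - pg) for g, pg in zip(grad, prev_grad)]
    # 2) Clip the estimate to unit L2 norm if it is larger than 1.
    norm = math.sqrt(sum(ci * ci for ci in c))
    if norm > 1.0:
        c = [ci / norm for ci in c]
    # 3) AdamW moments (a diagonal preconditioner) computed on c instead of the raw gradient.
    m = [beta1 * mi + (1.0 - beta1) * ci for mi, ci in zip(m, c)]
    v = [beta2 * vi + (1.0 - beta2) * ci * ci for vi, ci in zip(v, c)]
    bc1, bc2 = 1.0 - beta1 ** t, 1.0 - beta2 ** t
    # 4) Bias-corrected, preconditioned step plus decoupled weight decay.
    x = [xi - lr * ((mi / bc1) / (math.sqrt(vi / bc2) + eps) + weight_decay * xi)
         for xi, mi, vi in zip(x, m, v)]
    return x, m, v

# Toy problem f(x) = sum(x_i^2) with exact gradients. With no minibatch noise, the
# previous-step gradient equals the same-minibatch gradient at the previous iterate.
def grad_fn(x: Vec) -> Vec:
    return [2.0 * xi for xi in x]

x, m, v, prev = [1.0, -2.0], [0.0, 0.0], [0.0, 0.0], None
initial_loss = sum(xi * xi for xi in x)
for t in range(1, 201):
    g = grad_fn(x)
    x, m, v = mars_adamw_step(x, g, prev, m, v, t, lr=0.05)
    prev = g
final_loss = sum(xi * xi for xi in x)
print(initial_loss, final_loss)
```

Swapping step 3 for a different preconditioner (for example a sign-based or full-matrix update) is what turns this into other members of the family; those variants are not sketched here.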
-----
Key Insights 💡:
→ The MARS framework effectively combines variance reduction and preconditioning.
→ A scaling parameter in STORM makes the strength of variance reduction adjustable.
→ Theoretically, MARS achieves a better convergence rate than AdamW.
→ Empirical results show MARS outperforms AdamW in LLM training.
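The effect of the scaling parameter can be illustrated with the correction term as we read it from the paper, c_t = g_t + γ·β1/(1−β1)·(g_t − g_{t−1}), where g_t is the stochastic gradient. The helper names below are hypothetical. Setting γ = 0 turns the correction off, intermediate values apply partial variance reduction, and γ = 1 makes the momentum recursion algebraically identical to classic STORM.

```python
import math
from typing import List

def storm_corrected_gradient(grad: List[float], prev_grad: List[float],
                             gamma: float, beta1: float) -> List[float]:
    """Scaled STORM estimate c_t = g_t + gamma * beta1/(1-beta1) * (g_t - g_{t-1})."""
    scale = gamma * beta1 / (1.0 - beta1)
    return [g + scale * (g - pg) for g, pg in zip(grad, prev_grad)]

def momentum_update(m_prev: List[float], c: List[float], beta1: float) -> List[float]:
    """EMA momentum m_t = beta1 * m_{t-1} + (1 - beta1) * c_t."""
    return [beta1 * m + (1.0 - beta1) * ci for m, ci in zip(m_prev, c)]

grad, prev_grad = [1.0, 2.0], [0.5, 2.5]
beta1 = 0.9

# gamma = 0: no correction, so c_t is just the stochastic gradient.
c_off = storm_corrected_gradient(grad, prev_grad, gamma=0.0, beta1=beta1)
# gamma = 0.1: partial correction, approximately [1.45, 1.55].
c_mid = storm_corrected_gradient(grad, prev_grad, gamma=0.1, beta1=beta1)

# gamma = 1: the momentum update matches STORM, m_t = g_t + beta1 * (m_{t-1} - g_{t-1}).
m_prev = [0.2, -0.3]
m_full = momentum_update(m_prev, storm_corrected_gradient(grad, prev_grad, 1.0, beta1), beta1)
m_storm = [g + beta1 * (m - pg) for g, m, pg in zip(grad, m_prev, prev_grad)]

print(c_off, c_mid)
print(all(math.isclose(a, b) for a, b in zip(m_full, m_storm)))  # True
```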
-----
Results 📊:
→ MARS-AdamW reaches a validation loss of 2.58 on GPT-2 large using 28 billion tokens. AdamW requires 50 billion tokens for the same loss.
→ MARS achieves a final validation loss of 2.51 on GPT-2 large, lower than AdamW's 2.58.
→ On the Hellaswag downstream task, MARS improves accuracy to 44.64%. AdamW achieves 41.70%.
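A quick arithmetic check of the headline numbers: the 22 billion token saving follows from 50 − 28. The percentage, ratio, and point differences printed below are our own calculations from the reported values, not figures stated in the paper.

```python
# Reported GPT-2 large figures (AdamW vs. MARS-AdamW).
adamw_tokens_b = 50.0  # billions of tokens for AdamW to reach val loss 2.58
mars_tokens_b = 28.0   # billions of tokens for MARS-AdamW to reach val loss 2.58

tokens_saved_b = adamw_tokens_b - mars_tokens_b
fraction_saved = tokens_saved_b / adamw_tokens_b
token_ratio = adamw_tokens_b / mars_tokens_b

hellaswag_gain_pts = 44.64 - 41.70  # accuracy, percentage points
final_loss_gap = 2.58 - 2.51        # AdamW final val loss minus MARS final val loss

print(f"tokens saved: {tokens_saved_b:.0f}B ({fraction_saved:.0%} fewer)")  # 22B (44% fewer)
print(f"AdamW needs {token_ratio:.2f}x the tokens")                        # 1.79x
print(f"Hellaswag gain: {hellaswag_gain_pts:.2f} points")                  # 2.94 points
print(f"final val-loss gap: {final_loss_gap:.2f}")                         # 0.07
```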