Why Warmup the Learning Rate? Underlying Mechanisms and Improvements
LLMs can handle much larger learning rates by letting networks naturally reduce their loss landscape sharpness
This paper proposes that deep networks self-adjust their loss-landscape geometry during training, which can reduce or even eliminate the need for manual warmup steps.
Original Problem 🔍:
Learning rate warmup is widely used in deep learning, yet it remains unclear why it works and when it is truly necessary. Existing explanations vary and do not establish the extent to which warmup is actually needed.
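For context, warmup typically ramps the learning rate linearly from a small initial value to the target value over a fixed number of steps. A minimal sketch of such a schedule (the function name and signature are illustrative, not from the paper):

```python
def warmup_lr(step, warmup_steps, target_lr, init_lr=0.0):
    """Linear learning-rate warmup: ramp from init_lr to target_lr,
    then hold at target_lr once warmup is complete."""
    if step >= warmup_steps:
        return target_lr
    frac = step / warmup_steps
    return init_lr + frac * (target_lr - init_lr)
```

The question the paper asks is why this simple ramp helps at all, and whether the same effect can be achieved another way.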
Solution in this Paper 🛠️:
• Analyzed warmup across architectures (FCNs, ResNets, Transformers), optimizers (SGD, Adam), and datasets
• Identified two key mechanisms:
  - Natural Progressive Sharpening: Network increases sharpness while learning rate rises
  - Natural Sharpness Reduction: Network reduces sharpness early in training
• Proposed GI-Adam: Initializes Adam's variance using gradient information
• Used the "catapult mechanism" to estimate initial sharpness and set the number of warmup steps
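The GI-Adam idea above can be sketched concisely: instead of starting Adam's second-moment estimate at zero, initialize it from the first gradient, so the preconditioner is already at gradient scale from step one. This is a hedged sketch in NumPy, assuming the paper's variant sets v₀ = g² and therefore drops the zero-init bias correction on v (the paper's exact treatment may differ):

```python
import numpy as np

def gi_adam_init(grad):
    """GI-Adam: seed the second moment with the squared first gradient."""
    m = np.zeros_like(grad)
    v = grad ** 2  # standard Adam would use zeros here
    return m, v

def gi_adam_step(params, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    m_hat = m / (1 - b1 ** t)
    # No bias correction on v: it already starts at gradient scale,
    # so the zero-initialization bias that the correction compensates
    # for is absent.
    params = params - lr * m_hat / (np.sqrt(v) + eps)
    return params, m, v
```

With this initialization the very first update has magnitude close to lr rather than the inflated early steps standard Adam can take, which is one way to see why it tames the high initial preconditioned sharpness.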
Key Insights 💡:
• Warmup's primary benefit is enabling networks to handle larger target learning rates
• Makes hyperparameter tuning more robust by widening the range of viable learning rates
• Adam is particularly sensitive to large learning rates due to high initial preconditioned sharpness
• Small initializations (like μP) benefit less from warmup than large ones (like SP)
Results 📊:
• GI-Adam consistently outperforms standard Adam across datasets
• Reduced initial preconditioned sharpness by ~2 orders of magnitude
• Successfully eliminated warmup steps in some cases through proper initial learning rate selection
• Pushed training failure boundary to higher target learning rates
🔄 How warmup works: the underlying mechanisms
Two main mechanisms exist depending on initialization:
Natural Progressive Sharpening: Network naturally increases sharpness while learning rate increases, leading to repeated "catapult" cycles where loss spikes trigger sharpness reduction.
Natural Sharpness Reduction: Network naturally reduces sharpness early in training. If learning rate increases too quickly, it triggers isolated catapult events that dramatically reduce sharpness.
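Both mechanisms revolve around sharpness, the top eigenvalue of the loss Hessian: for a quadratic loss, gradient descent is stable only while the learning rate stays below 2/sharpness, and crossing that threshold triggers a catapult (a loss spike followed by a drop in sharpness). Sharpness can be estimated with power iteration on Hessian-vector products; a self-contained sketch (function names are mine, not from the paper):

```python
import numpy as np

def sharpness(hvp, dim, iters=50, seed=0):
    """Estimate the top Hessian eigenvalue (sharpness) via power
    iteration. `hvp(v)` must return the Hessian-vector product H @ v."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=dim)
    v /= np.linalg.norm(v)
    for _ in range(iters):
        hv = hvp(v)
        v = hv / np.linalg.norm(hv)
    return v @ hvp(v)

def max_stable_lr(lam):
    """Quadratic-model stability threshold: lr beyond this value
    triggers a catapult event."""
    return 2.0 / lam
```

Monitoring this estimate during training is one way to see the two regimes: in progressive sharpening the threshold 2/λ shrinks toward the rising learning rate, while in natural sharpness reduction it grows away from it.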