Simple learning-rate scaling at initialization beats complex adaptive methods like Adam.
SGD-SaI (SGD with Scaling at Initialization) enhances stochastic gradient descent by scaling learning rates at initialization based on gradient signal-to-noise ratios, eliminating the need for adaptive methods like Adam while using 50% less optimizer memory.
-----
https://arxiv.org/abs/2412.11768
🤔 Original Problem:
Adaptive gradient methods like Adam are essential for training Transformers but consume excessive memory by storing first- and second-moment states for every parameter, limiting model scalability. For a 7B-parameter model, Adam needs roughly 50GB just for optimizer states.
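A rough back-of-envelope check of that figure, assuming Adam keeps two fp32 moment tensors per parameter (the exact total depends on the true parameter count and the precision used for the states):

```python
n_params = 7e9          # ~7B parameters
states_per_param = 2    # Adam's first and second moments
bytes_per_value = 4     # fp32 optimizer states (assumed)

# ~56 GB, in the same ballpark as the ~50 GB quoted above
print(n_params * states_per_param * bytes_per_value / 1e9)
```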
-----
🔧 Solution in this Paper:
→ SGD-SaI calculates gradient signal-to-noise ratio (g-SNR) for each parameter group only at initialization.
→ These g-SNR values guide learning rate scaling for different parameter blocks throughout training.
→ The method maintains constant preconditioned learning-rate scales without computing adaptive second-moment estimates during training.
→ It works with PyTorch's default parameter partitioning, requiring no complex grouping strategies (see the sketch after this list).
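To make the recipe concrete, here is a minimal PyTorch-style sketch, not the authors' implementation: it assumes g-SNR is taken as the ratio of a gradient tensor's L2 norm to its element-wise standard deviation, and that the resulting per-tensor scales are normalized by the largest value; `compute_gsnr_scales` and `sgd_sai_step` are illustrative names.

```python
import torch

def compute_gsnr_scales(model, loss_fn, batch, eps=1e-8):
    """One backward pass at initialization -> fixed per-parameter LR scales.

    Assumed g-SNR definition (for illustration): ||g||_2 / (std(g) + eps),
    computed per parameter tensor, then normalized by the largest value.
    """
    loss = loss_fn(model, batch)
    loss.backward()
    gsnr = {}
    for name, p in model.named_parameters():
        if p.grad is None:
            continue
        g = p.grad.detach()
        gsnr[name] = (g.norm() / (g.std() + eps)).item()
    model.zero_grad()
    max_gsnr = max(gsnr.values())
    # These scales are computed once and reused unchanged for every step.
    return {name: s / max_gsnr for name, s in gsnr.items()}

def sgd_sai_step(model, scales, momentum_buffers,
                 lr=1e-3, momentum=0.9, weight_decay=0.0):
    """SGD-with-momentum update whose effective LR per tensor is lr * scale."""
    with torch.no_grad():
        for name, p in model.named_parameters():
            if p.grad is None:
                continue
            g = p.grad
            if weight_decay:
                g = g.add(p, alpha=weight_decay)
            buf = momentum_buffers.setdefault(name, torch.zeros_like(p))
            buf.mul_(momentum).add_(g)             # single state tensor per parameter
            p.add_(buf, alpha=-lr * scales[name])  # g-SNR scale fixed at initialization

# Usage (hypothetical): scales = compute_gsnr_scales(model, loss_fn, first_batch)
# then, after each backward pass: sgd_sai_step(model, scales, momentum_buffers={})
```

Because only a single momentum buffer is kept per parameter and no second-moment state is tracked, the optimizer footprint is roughly half of AdamW's, which is where the memory saving reported below comes from.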
-----
💡 Key Insights:
→ g-SNR values remain stable during training while varying across parameter groups
→ Initial gradient characteristics sufficiently capture parameter-specific learning dynamics
→ Adaptive methods' complexity can be replaced by smart initialization-time scaling
-----
📊 Results:
→ Reduces optimizer memory by 50% compared to AdamW (25.15GB savings for Llama2-7B)
→ Matches or exceeds AdamW performance on ViT ImageNet-1K classification (72.92% vs 73.04% top-1 accuracy)
→ 3x faster optimizer updates than Adam-mini for GPT2-small training