Pre-training Distillation for Large Language Models: A Design Space Exploration
Smart logit compression and loss scheduling lets small LLMs learn efficiently from bigger siblings
4,000x lighter: pre-training distillation with compressed logits
Original Problem 🎯:
Knowledge distillation for LLMs typically focuses on the post-training phase. There has been no systematic exploration of applying distillation during pre-training to enhance smaller models with a teacher model's knowledge.
Solution in this Paper 🔧:
• Introduces pre-training distillation (PD) and explores its design space across four key dimensions:
Logits processing: truncation and normalization of the teacher model's output logits
Loss selection: how to combine the distillation loss with the language-modeling loss
Scaling behavior: impact of student/teacher model sizes and training-data volume
Logits generation strategy: online vs. offline
• Uses a top-p-k truncation method to shrink stored teacher logits by ~4,000x (see the sketch after this list)
• Employs a Warmup-Stable-Decay (WSD) scheduler to control the proportion of distillation loss in the combined loss
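
A minimal sketch of what top-p-k truncation with temperature normalization could look like (PyTorch-style; the function name, default p/k values, and sparse return format are assumptions for illustration, not the paper's reference implementation):

```python
import torch

def top_p_k_truncate(logits: torch.Tensor, p: float = 0.95, k: int = 16,
                     temperature: float = 1.0):
    """Per token, keep at most k logits whose cumulative probability reaches p,
    and store only those (id, prob) pairs instead of the full vocabulary."""
    probs = torch.softmax(logits / temperature, dim=-1)            # [T, V]
    top_probs, top_ids = torch.topk(probs, k, dim=-1)              # [T, k], sorted desc
    cum = torch.cumsum(top_probs, dim=-1)
    keep = (cum - top_probs) < p                                   # mass before entry < p
    top_probs = torch.where(keep, top_probs, torch.zeros_like(top_probs))
    top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)    # renormalize kept mass
    return top_ids, top_probs

# Example: a toy batch of 4 token positions over a 32k vocabulary.
ids, probs = top_p_k_truncate(torch.randn(4, 32_000))
# Storing k (id, prob) pairs per token instead of a 32k-float vector is the
# source of the roughly-4,000x storage saving described above.
```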
Key Insights 💡:
• Larger student models (>10% of teacher size) benefit more from pre-training distillation
• Larger teacher models don't guarantee better results due to capacity gaps
• Online logits generation shows promise but performs slightly below the offline approach
• Temperature for normalization should be kept low for optimal results
• KL divergence and negative log-likelihood losses perform similarly well
Results 📊:
• Initial experiment with 1.9B student model showed 1.6% average improvement across datasets
• The WSD-α scheduler paired with a WSD learning-rate schedule achieved an 8.0% improvement
• 6.8B student model achieved 48.0% average score vs 44.9% baseline
• Logits truncation reduced storage from 58.6 PB to 15 TB (a quick arithmetic check follows this list)
• Consistent improvements maintained with 500B tokens of training data
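
As a rough consistency check (unit conversion only, no new numbers), the reported storage figures line up with the ~4,000x compression claim:

$$\frac{58.6\ \text{PB}}{15\ \text{TB}} = \frac{58{,}600\ \text{TB}}{15\ \text{TB}} \approx 3{,}900\times$$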
💡 The optimal loss configuration discovered for pre-training distillation
The best loss configuration involves (a minimal sketch follows this list):
Using Kullback-Leibler divergence or negative log-likelihood as the distillation loss
Avoiding MSE loss, which causes significant performance drops
Employing the Warmup-Stable-Decay (WSD) method to schedule the KD-loss proportion
Pairing it with a WSD learning-rate scheduler
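
A minimal sketch of how such a configuration could be wired together (PyTorch-style; the WSD phase fractions, α ceiling, and function names are illustrative assumptions, and the KD term is written as cross-entropy against the truncated teacher distribution, which matches KL divergence up to a constant):

```python
import torch
import torch.nn.functional as F

def wsd_alpha(step: int, total_steps: int, warmup_frac: float = 0.1,
              decay_frac: float = 0.2, alpha_max: float = 0.5) -> float:
    """Warmup-Stable-Decay schedule for the KD-loss proportion alpha."""
    warmup = int(warmup_frac * total_steps)
    decay = int(decay_frac * total_steps)
    if step < warmup:                                  # warmup: ramp alpha up
        return alpha_max * step / max(warmup, 1)
    if step < total_steps - decay:                     # stable: hold alpha
        return alpha_max
    return alpha_max * (total_steps - step) / max(decay, 1)  # decay: ramp alpha down

def pd_loss(student_logits, targets, teacher_ids, teacher_probs,
            alpha: float, temperature: float = 1.0):
    """Combine the language-modeling loss with a KD loss over truncated teacher logits."""
    lm = F.cross_entropy(student_logits, targets)
    log_q = F.log_softmax(student_logits / temperature, dim=-1).gather(-1, teacher_ids)
    kd = -(teacher_probs * log_q).sum(dim=-1).mean()   # CE vs. truncated teacher (= KL + const)
    return (1 - alpha) * lm + alpha * kd
```

In a training loop one would recompute `alpha = wsd_alpha(step, total_steps)` at every step and feed it into `pd_loss`, typically alongside a WSD learning-rate scheduler, as the results above suggest.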