Using small models to efficiently train better-performing large models.
Making LLM training faster by learning from smaller models first.
📚 https://arxiv.org/abs/2410.18779
🎯 Original Problem:
Training LLMs requires massive computational resources and time. The challenge is to make pre-training more efficient while maintaining or improving model quality.
-----
🔧 Solution in this Paper:
• SALT (Small model Aided Large model Training) - a two-stage pre-training approach (sketched in code after this list):
- Stage 1: Knowledge distillation from Small Language Model (SLM) to LLM
- Stage 2: Standard self-supervised training
• Novel data selection strategy using SLM to identify valuable training examples
- Focuses on challenging but learnable sequences
- Uses SLM's confidence scores to filter examples
• Theoretical framework showing how weaker teacher (SLM) enhances student (LLM) through:
- Variance reduction in easy regions
- Adaptive supervision balancing bias-variance tradeoff
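A minimal sketch of the two-stage recipe, assuming a PyTorch-style training loop. The function name, the blending weight `alpha`, the temperature `tau`, and the `stage1_steps` cutoff are illustrative assumptions, not the paper's actual implementation:

```python
# Illustrative two-stage SALT-style loss: distill from the SLM early on,
# then switch to standard self-supervised training.
import torch
import torch.nn.functional as F

def salt_loss(llm_logits, slm_logits, targets, step, stage1_steps,
              alpha=0.5, tau=1.0):
    """Per-step loss: KD + CE in stage 1, plain CE in stage 2 (hypothetical weighting)."""
    ce = F.cross_entropy(
        llm_logits.view(-1, llm_logits.size(-1)), targets.view(-1)
    )
    if step < stage1_steps:
        # Stage 1: blend the standard next-token loss with distillation
        # from the (frozen) small-model teacher.
        kd = F.kl_div(
            F.log_softmax(llm_logits / tau, dim=-1),
            F.softmax(slm_logits.detach() / tau, dim=-1),
            reduction="batchmean",
        ) * (tau ** 2)
        return alpha * kd + (1.0 - alpha) * ce
    # Stage 2: drop the teacher and continue with standard training.
    return ce
```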
-----
💡 Key Insights:
• Even smaller models can effectively teach larger models in early training stages
• Two-stage approach prevents negative impact from weaker teacher
• Data selection focusing on "challenging but learnable" examples is crucial (see the selection sketch after this list)
• Balancing between utilizing SLM knowledge and avoiding its limitations is key
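A hedged sketch of SLM-guided data selection, assuming a Hugging Face-style SLM whose forward pass returns `.logits`. Scoring sequences by the SLM's mean token log-likelihood and keeping a middle band is one plausible way to operationalize "challenging but learnable"; the thresholds `lo` and `hi` are made up for illustration and are not the paper's criterion:

```python
# Keep sequences the SLM finds neither trivially easy nor hopeless.
import torch
import torch.nn.functional as F

@torch.no_grad()
def select_examples(slm, token_ids, lo=-4.0, hi=-1.0):
    """Return a boolean mask over sequences judged 'challenging but learnable'."""
    logits = slm(token_ids).logits                      # (batch, seq_len, vocab)
    log_probs = F.log_softmax(logits[:, :-1], dim=-1)   # predictions for next tokens
    targets = token_ids[:, 1:]                          # shifted targets
    token_ll = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    seq_score = token_ll.mean(dim=-1)                   # mean log-likelihood per sequence
    return (seq_score > lo) & (seq_score < hi)          # keep the middle band
```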
-----
📊 Results:
• Achieved same performance as baseline with 30% fewer training steps
• Using a 1.5B-parameter SLM to train a 2.8B-parameter LLM:
- 28% reduction in wall-clock training time
- Improved few-shot performance across benchmarks
- Better downstream performance after fine-tuning on tasks like arithmetic reasoning and summarization