Making LLMs more reliable by bridging the training-inference gap
📚 https://arxiv.org/abs/2410.14655
🤖 Original Problem:
LLMs face a critical gap between training and inference: during training they condition on ground-truth tokens (teacher forcing), but at inference they must condition on their own previously generated tokens, which leads to unpredictable behavior and compounding errors.
-----
🔧 Solution in this Paper:
→ Batch-Scheduled Sampling (BASH): Stochastically mixes ground-truth and model-generated tokens in offline batches during training (sketched after this list)
→ Reference-Answer-based Correction (RAC): Trains model to self-correct by comparing its generations with reference answers
→ Both methods require zero architectural changes to transformer models
→ Implementation involves initial SFT training followed by BASH/RAC fine-tuning
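Below is a minimal PyTorch sketch of the BASH idea, assuming a Hugging Face-style causal LM. The function names, the per-position mixing rule, the fixed mixing probability `p_model`, and the loss masking are illustrative assumptions, not the paper's exact implementation.

```python
# Sketch of batch-scheduled sampling (BASH): generate offline with the current
# model, stochastically mix its tokens into the ground-truth answer, and train
# with the usual next-token loss on the ground-truth targets.
# Assumes a Hugging Face-style causal LM and unpadded prompts; names are illustrative.
import torch
import torch.nn.functional as F

def mix_tokens(ground_truth_ids, generated_ids, p_model):
    """Replace each ground-truth token with the model-generated token
    independently with probability p_model."""
    use_model = torch.rand_like(ground_truth_ids, dtype=torch.float) < p_model
    return torch.where(use_model, generated_ids, ground_truth_ids)

def bash_step(model, batch, p_model, optimizer):
    """One BASH training step on a batch of (prompt_ids, target_ids)."""
    prompts, targets = batch["prompt_ids"], batch["target_ids"]   # (B, Tp), (B, Tt)
    with torch.no_grad():                                         # offline generation pass
        gen = model.generate(prompts, max_new_tokens=targets.size(1), do_sample=True)
        gen = gen[:, prompts.size(1):]                            # keep only the new tokens
        gen = F.pad(gen, (0, targets.size(1) - gen.size(1)),      # pad if generation stopped early
                    value=model.config.pad_token_id or 0)
    mixed = mix_tokens(targets, gen, p_model)
    input_ids = torch.cat([prompts, mixed], dim=1)                # condition on the mixed sequence
    labels = torch.cat([torch.full_like(prompts, -100), targets], dim=1)  # loss only on the answer span
    loss = model(input_ids=input_ids, labels=labels).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```

In practice p_model would likely follow a schedule rather than stay fixed; the key property is that generation happens in offline batches, so the method adds only an extra forward pass per batch and no architectural changes.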
-----
💡 Key Insights:
→ Training-inference discrepancy can be bridged without complex model modifications
→ Self-correction capability can be built directly into models during training (see the RAC sketch after this list)
→ Offline batch processing makes scheduled sampling practical for LLM training
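And here is a minimal sketch of how a RAC training example could be constructed, again assuming a Hugging Face-style causal LM. The correction-cue string, helper names, and loss masking are assumptions for illustration, not the paper's exact template.

```python
# Sketch of Reference-Answer-based Correction (RAC) data construction:
# the model first drafts its own answer, then the training sequence appends
# the reference answer behind a correction cue, with the loss applied only
# to the reference-answer tokens. The cue and names are illustrative.
import torch

CORRECTION_CUE = "\n\nCorrected answer based on the reference:\n"  # assumed cue, not the paper's exact prompt

def build_rac_example(model, tokenizer, prompt, reference_answer, max_new_tokens=256):
    """Return (input_ids, labels) for one RAC example; labels are -100 (ignored)
    everywhere except the reference-answer span."""
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():                                          # the model's own (possibly wrong) draft
        draft_ids = model.generate(prompt_ids, max_new_tokens=max_new_tokens,
                                   do_sample=True)[:, prompt_ids.size(1):]
    cue_ids = tokenizer(CORRECTION_CUE, return_tensors="pt", add_special_tokens=False).input_ids
    ref_ids = tokenizer(reference_answer, return_tensors="pt", add_special_tokens=False).input_ids
    input_ids = torch.cat([prompt_ids, draft_ids, cue_ids, ref_ids], dim=1)
    labels = input_ids.clone()
    labels[:, : -ref_ids.size(1)] = -100                           # train only on the reference answer
    return input_ids, labels
```

Fine-tuning on examples built this way is what teaches the model to revise its own draft toward the reference, so at inference it can correct itself without any architectural change.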
-----
📊 Results:
→ RAC achieves 10.37% win rate vs 8.03% for standard SFT on AlpacaEval 2.0
→ BASH improves GSM8K math accuracy from a 56.76% baseline to 60.22%
→ RAC+DPO achieves 15.85% win rate vs 14.06% for SFT+DPO