Generate text up to 8x faster than GPT-2 while matching its quality through self-distillation of a discrete diffusion model.
📚 https://arxiv.org/abs/2410.21035
🤖 Original Problem:
Current autoregressive LLMs generate text one token at a time, which adds noticeable latency. This becomes a major bottleneck for applications that need search, planning, or reranking over multiple completions.
-----
🔧 Solution in this Paper:
• Introduces Self-Distillation Through Time (SDTT) for discrete diffusion models
• Generates 32+ tokens simultaneously while exceeding autoregressive model quality
• Uses the Kullback-Leibler Divergence (KLD) as the distillation loss (see the sketch after this list)
• Reduces inference steps by 32-64x through iterative distillation rounds
• Operates without activation caching, suggesting room for further speed gains
• Successfully scales to models up to 860M parameters
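For intuition, here is a minimal sketch of what one SDTT training step could look like, assuming a PyTorch masked-diffusion denoiser; `sample_unmasking_step` and `MASK_TOKEN_ID` are hypothetical stand-ins, not the paper's code. The teacher runs two small denoising steps, and the student is trained to match the resulting distribution in a single step via KLD:

```python
import torch
import torch.nn.functional as F

def sdtt_distillation_step(student, teacher, x_t, t, dt):
    """One SDTT training step (minimal sketch, hypothetical interfaces).

    student/teacher: discrete diffusion denoisers mapping (noisy tokens, time)
    to logits over the vocabulary, shape (batch, seq_len, vocab).
    x_t: partially masked token ids at noise level t.
    dt: step size used by the teacher; the student learns to cover 2*dt at once.
    """
    with torch.no_grad():
        # Teacher takes two small denoising steps of size dt.
        logits_1 = teacher(x_t, t)
        x_mid = sample_unmasking_step(x_t, logits_1, t, t - dt)  # hypothetical sampler
        logits_2 = teacher(x_mid, t - dt)
        target_logp = F.log_softmax(logits_2, dim=-1)            # teacher's 2-step distribution

    # Student covers the same interval in a single step of size 2*dt.
    student_logp = F.log_softmax(student(x_t, t), dim=-1)

    # KL(teacher || student), averaged over still-masked positions only.
    mask = (x_t == MASK_TOKEN_ID)                                # hypothetical mask token id
    kl = F.kl_div(student_logp, target_logp,
                  reduction="none", log_target=True).sum(-1)
    return (kl * mask).sum() / mask.sum().clamp(min=1)
```

Repeating this over several rounds, each time treating the previously distilled student as the new teacher, is how the step count is roughly halved per round and compounds into the 32-64x reduction.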
-----
💡 Key Insights:
• KLD outperforms Mean-Squared Error and Total Variation Distance as the distillation loss (the three are compared in the sketch after this list)
• Distilling more than 2 teacher steps into a single student step per round weakens student performance
• 5-10k training steps per distillation round are sufficient
• Preserves teacher model's natural language understanding capabilities
• Does not rely on deterministic mappings like DDIM
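To make the loss comparison concrete, here is a hedged sketch (illustrative names, not the paper's code) of the three candidate divergences applied to the per-token distributions:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_probs, kind="kld"):
    """Candidate distillation losses on per-token distributions (sketch).

    student_logits: (batch, seq_len, vocab) raw student outputs.
    teacher_probs:  (batch, seq_len, vocab) teacher target distribution.
    """
    student_probs = F.softmax(student_logits, dim=-1)
    if kind == "kld":   # KL(teacher || student); the best-performing choice per the paper
        return F.kl_div(F.log_softmax(student_logits, dim=-1),
                        teacher_probs, reduction="batchmean")
    if kind == "mse":   # mean-squared error between probability vectors
        return F.mse_loss(student_probs, teacher_probs)
    if kind == "tvd":   # total variation distance = half the L1 gap
        return 0.5 * (student_probs - teacher_probs).abs().sum(-1).mean()
    raise ValueError(kind)
```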
-----
📊 Results:
• Achieves roughly 8x faster generation than autoregressive models that use KV caching
• Matches or exceeds GPT-2 with nucleus sampling while using only 32 sampling steps
• Maintains teacher model's performance on LAMBADA benchmark
• Demonstrates effectiveness up to 860M parameter models
• Reduces inference steps from 1024 to 16-32 while preserving quality (see the sampling sketch below)
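As a rough illustration of what 16-32 step generation means in practice, here is a sketch of a few-step unmasking sampler, assuming a masked discrete diffusion denoiser and a hypothetical `MASK_TOKEN_ID`; the unmasking schedule shown is illustrative, not the paper's exact sampler:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, seq_len, vocab_size, num_steps=32, device="cpu"):
    """Few-step sampling from a masked discrete diffusion model (illustrative sketch).

    Starts from an all-[MASK] sequence and unmasks a growing fraction of positions
    over `num_steps` steps; a distilled SDTT student would use 16-32 steps
    where the undistilled teacher needs 1024.
    """
    x = torch.full((1, seq_len), MASK_TOKEN_ID, dtype=torch.long, device=device)
    for i in range(num_steps):
        t = 1.0 - i / num_steps                    # current noise level
        t_next = 1.0 - (i + 1) / num_steps         # target noise level after this step
        logits = model(x, torch.tensor([t], device=device))
        probs = F.softmax(logits, dim=-1)
        sampled = torch.multinomial(probs.view(-1, vocab_size), 1).view(1, seq_len)

        # Unmask just enough positions so that a fraction t_next stays masked.
        still_masked = (x == MASK_TOKEN_ID)
        num_to_unmask = int(still_masked.sum().item() - t_next * seq_len)
        if num_to_unmask > 0:
            candidates = still_masked.nonzero(as_tuple=False)
            chosen = candidates[torch.randperm(len(candidates))[:num_to_unmask]]
            x[chosen[:, 0], chosen[:, 1]] = sampled[chosen[:, 0], chosen[:, 1]]
    return x
```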