MrT5: Dynamic Token Merging for Efficient Byte-level Language Models
Finally, a way to make byte-level models efficient through learned token compression.
MrT5 makes byte-level models substantially faster (25-55% lower inference runtime) by dynamically merging tokens while preserving accuracy.
Basically, teaching the model to delete unnecessary bytes makes everything run much faster.
🎯 Original Problem:
Byte-level models like ByT5 avoid tokenization issues but operate on much longer sequences, which makes training and inference inefficient. The result is slower processing and higher computational cost than with subword-tokenized models.
🔧 Solution in this Paper:
• Introduces MrT5 (MergeT5), a modified ByT5 with a dynamic token deletion mechanism
• Implements a learnable delete gate at a fixed early encoder layer (layer 3)
• During training: uses soft deletion via attention masking
• During inference: applies hard deletion by removing tokens whose gate value falls below a threshold
• Features a tunable deletion regularizer to control the compression rate
• Preserves contextual information from deleted tokens in the remaining sequence
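The soft/hard deletion split above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not MrT5's implementation: the gate form `G_i = K * sigmoid(h_i · w + b)`, the hypothetical constant `K = -30`, the deletion threshold `K / 2`, and all helper names (`delete_gate`, `soft_deletion_attention`, `hard_deletion`) are assumptions made for illustration. It shows one query's attention over four byte tokens.

```python
import math

# Hypothetical constant: gate values lie in (K, 0); values near K mean "delete".
K = -30.0


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def delete_gate(hidden, w, b):
    """Hypothetical gate: one scalar per token, G_i = K * sigmoid(h_i . w + b)."""
    return [K * sigmoid(sum(h * wj for h, wj in zip(vec, w)) + b) for vec in hidden]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def soft_deletion_attention(scores, gates):
    """Training path (assumed): add each key's gate value to its attention logit.
    Keys with G_i near K get ~zero weight but stay in the sequence,
    so the whole operation remains differentiable."""
    return softmax([s + g for s, g in zip(scores, gates)])


def hard_deletion(tokens, gates, threshold=K / 2):
    """Inference path (assumed): drop tokens whose gate value is below the
    threshold, so every later layer runs on a shorter sequence."""
    return [t for t, g in zip(tokens, gates) if g >= threshold]


if __name__ == "__main__":
    tokens = list("the ")
    gates = [-0.1, -29.9, -0.2, -30.0]  # hand-picked gate values for illustration
    scores = [1.0, 1.0, 1.0, 1.0]       # raw attention logits for one query
    print(soft_deletion_attention(scores, gates))
    print(hard_deletion(tokens, gates))
```

The two paths agree on which tokens survive: soft deletion gives the two gated-out bytes near-zero attention weight, while hard deletion removes them outright, which is where the inference savings come from.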
💡 Key Insights:
• Token deletion is most effective when placed in early encoder layers
• The model learns meaningful, task-specific deletion patterns
• Zero-shot transfer is possible to other languages that use the Latin script
• Multilingual training significantly improves cross-script performance
• Can be added to pre-trained models with minimal fine-tuning
📊 Results:
• Reduces sequence lengths by up to 80% while maintaining ByT5-level accuracy
• Improves inference runtime by 25-55%
• Achieves over 50% sequence length reduction across 15 languages
• Maintains accuracy comparable to ByT5 on XNLI and character-level tasks
• Outperforms random and fixed deletion baselines
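Why does an 80% length reduction yield a smaller runtime gain? A back-of-the-envelope cost model helps. This is a rough multiply-add count, not the paper's measurement: the layer sizes, input length, and encoder depth below are hypothetical, and real runtime also depends on the decoder, hardware, and batching, which this sketch ignores. It shows that only layers after the delete gate get cheaper, which is also the compute side of why early deletion pays off (it says nothing about accuracy).

```python
def layer_cost(n: int, d_model: int = 512, d_ff: int = 2048) -> int:
    """Rough multiply-add count for one Transformer encoder layer on n tokens."""
    projections = 4 * n * d_model * d_model   # Q, K, V, and output projections
    feed_forward = 2 * n * d_model * d_ff     # two dense layers
    attention = 2 * n * n * d_model           # QK^T scores + weighted sum of V
    return projections + feed_forward + attention


def encoder_cost(n: int, num_layers: int, delete_layer=None, deletion_rate: float = 0.0) -> int:
    """Total encoder cost; if delete_layer is set, later layers see fewer tokens."""
    if delete_layer is None:
        return num_layers * layer_cost(n)
    kept = max(1, round(n * (1.0 - deletion_rate)))
    return delete_layer * layer_cost(n) + (num_layers - delete_layer) * layer_cost(kept)


if __name__ == "__main__":
    n, layers = 1024, 12  # hypothetical input length and encoder depth
    base = encoder_cost(n, layers)
    for layer in (1, 3, 6):
        ratio = encoder_cost(n, layers, delete_layer=layer, deletion_rate=0.8) / base
        print(f"delete after layer {layer}: {ratio:.2f} of baseline encoder cost")
```

The layers before the gate always run at full length, so their share of the cost sets a floor on the savings; moving the gate earlier lowers that floor.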
How MrT5's deletion mechanism works
• Uses a delete gate after a fixed encoder layer (typically layer 3)
• During training: soft deletion using attention masking
• During inference: hard deletion, removing tokens whose gate value falls below a threshold
• Deletion rate controlled by a tunable regularizer
• Preserves contextual information from deleted tokens in the remaining tokens
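The tunable regularizer can be sketched as a penalty on the mean gate value, weighted by a coefficient (`alpha` here) and added to the task loss. Treat this form, the hypothetical constant `K = -30`, and the `K / 2` deletion threshold as assumptions for illustration; the task-loss numbers in the example are made up to show how `alpha` shifts the trade-off, not results from the paper.

```python
K = -30.0  # hypothetical gate lower bound; gate values lie in (K, 0)


def deletion_regularizer(gates):
    """Mean gate value. Minimizing it pushes gates toward K, i.e. more deletion."""
    return sum(gates) / len(gates)


def total_loss(task_loss, gates, alpha):
    """Task loss plus the alpha-weighted regularizer; alpha is the tunable knob."""
    return task_loss + alpha * deletion_regularizer(gates)


def deletion_rate(gates, threshold=K / 2):
    """Fraction of tokens that hard deletion at inference would remove."""
    return sum(g < threshold for g in gates) / len(gates)


if __name__ == "__main__":
    keep_all = [-0.1] * 10
    delete_half = [-0.1] * 5 + [-29.0] * 5
    # Made-up task losses: deleting half the tokens costs a little accuracy.
    for alpha in (0.001, 0.01):
        a = total_loss(2.00, keep_all, alpha)
        b = total_loss(2.05, delete_half, alpha)
        print(f"alpha={alpha}: keep_all={a:.4f}, delete_half={b:.4f}")
    print(deletion_rate(delete_half))
```

With the smaller `alpha`, keeping every token gives the lower total loss; with the larger `alpha`, the deletion-heavy configuration wins despite its higher task loss. Raising the coefficient is how this sketch trades accuracy for compression.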
🚀 Key advantages of MrT5's approach
• Reduces sequence lengths by up to 80% while maintaining accuracy comparable to ByT5
• Improves inference runtime significantly (25-55% faster)
• Transfers zero-shot to other Latin-script languages when trained on English
• Can be added to pre-trained models with minimal fine-tuning
• Learns meaningful deletion patterns tailored to specific tasks