Want Qwen 2.5 performance on a single GPU? ARWKV distills Transformer power into an RNN.
Forget pretraining: distillation unlocks RNN potential to match Transformer LLMs like Qwen 2.5.
This paper introduces ARWKV, a novel RNN-based LLM distilled from Transformer models. It replaces self-attention with RWKV-7 time mixing to enhance RNN expressiveness and state tracking, achieving performance comparable to Qwen 2.5 with far fewer training resources.
RWKV (Receptance Weighted Key Value) is reborn from the Transformer's rib, offering comparable LLM performance with RNN efficiency.
-----
Paper - https://arxiv.org/abs/2501.15570
Original Problem 🤔:
→ Transformer-based LLMs like Qwen 2.5 demand extensive GPU resources for pretraining, hindering academic research.
→ Linear RNNs offer constant-memory inference efficiency but traditionally lack the expressiveness and state-tracking ability of Transformers, especially on long-context tasks.
-----
Solution in this Paper 💡:
→ The paper proposes ARWKV, an RNN-Attention based LLM.
→ ARWKV is distilled from Transformer models like Qwen 2.5.
→ It replaces Transformer self-attention with RWKV-7 time mixing modules.
→ Stage 1 performs attention alignment: the RWKV-7 time mixing module is trained to mimic the frozen teacher's self-attention outputs (a minimal sketch follows this list).
→ Stage 2 uses knowledge distillation to transfer knowledge from a larger Transformer LLM to ARWKV.
→ Stage 3 employs supervised fine-tuning and Direct Preference Optimization for context extension and alignment.
→ This method enables training a 7B parameter model on a single GPU, significantly reducing resource needs.
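Below is a minimal PyTorch sketch of the stage 1 idea, assuming a toy recurrent time-mixing block and an `nn.MultiheadAttention` stand-in for the frozen teacher attention. The class and function names, the MSE objective, and the tiny dimensions are illustrative assumptions, not the paper's actual RWKV-7 implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleTimeMixing(nn.Module):
    """Toy linear-recurrent block: a much-simplified stand-in for RWKV-7 time mixing."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.receptance = nn.Linear(hidden_size, hidden_size, bias=False)
        self.key = nn.Linear(hidden_size, hidden_size, bias=False)
        self.value = nn.Linear(hidden_size, hidden_size, bias=False)
        self.decay = nn.Parameter(torch.zeros(hidden_size))   # learned per-channel decay
        self.output = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x):
        # x: (batch, seq_len, hidden). A sequential recurrence replaces attention.
        r = torch.sigmoid(self.receptance(x))
        k, v = self.key(x), self.value(x)
        w = torch.exp(-torch.exp(self.decay))                 # decay factor in (0, 1)
        state = torch.zeros_like(x[:, 0])
        outs = []
        for t in range(x.size(1)):
            state = w * state + k[:, t] * v[:, t]             # running key-value state
            outs.append(r[:, t] * state)
        return self.output(torch.stack(outs, dim=1))

def align_to_teacher_attention(teacher_attn, student, hidden_states, steps=200, lr=1e-3):
    """Stage 1 idea: train the recurrent block to reproduce the frozen
    teacher attention's output on the same hidden states."""
    with torch.no_grad():
        target, _ = teacher_attn(hidden_states, hidden_states, hidden_states)
    opt = torch.optim.AdamW(student.parameters(), lr=lr)
    for _ in range(steps):
        loss = F.mse_loss(student(hidden_states), target)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss.item()

# Usage with a stand-in teacher (real setup: attention layers of a Qwen 2.5 checkpoint).
hidden = 64
teacher_attn = nn.MultiheadAttention(hidden, num_heads=4, batch_first=True).eval()
student = SimpleTimeMixing(hidden)
x = torch.randn(2, 16, hidden)
print(f"alignment loss: {align_to_teacher_attention(teacher_attn, student, x):.4f}")
```

Once the recurrent block tracks the teacher's attention outputs, it can be swapped into the checkpoint in place of self-attention before stage 2 distillation begins.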
-----
Key Insights from this Paper 🔑:
→ Attention alignment is crucial for successful Transformer-to-RNN distillation.
→ RWKV-7 time mixing can effectively capture Transformer attention patterns.
→ Distillation allows knowledge transfer from large LLMs to smaller, efficient RNN models (see the loss sketch after this list).
→ Running ARWKV inference in float16 improves benchmark performance relative to the bfloat16 precision used during training.
→ Directly distilling from a much larger teacher (e.g., 32B into 7B) without careful adaptation of the MLP layers can be suboptimal.
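For the stage 2 step, a standard formulation is a word-level KL divergence between teacher and student next-token distributions. The sketch below shows that loss in isolation; the temperature, masking scheme, and random tensors are illustrative assumptions rather than the paper's exact recipe:

```python
import torch
import torch.nn.functional as F

def word_level_kd_loss(student_logits, teacher_logits, attention_mask, temperature=1.0):
    """KL(teacher || student) per token, averaged over non-padded positions.
    logits: (batch, seq_len, vocab); attention_mask: (batch, seq_len)."""
    t = temperature
    log_p_student = F.log_softmax(student_logits / t, dim=-1)
    p_teacher = F.softmax(teacher_logits / t, dim=-1)
    kl = F.kl_div(log_p_student, p_teacher, reduction="none").sum(dim=-1)  # (batch, seq_len)
    mask = attention_mask.float()
    return (kl * mask).sum() / mask.sum() * (t ** 2)

# Usage with random tensors; in the real pipeline the teacher logits come from
# Qwen 2.5 and the student logits from the ARWKV model being distilled.
batch, seq, vocab = 2, 8, 32
student_logits = torch.randn(batch, seq, vocab, requires_grad=True)
teacher_logits = torch.randn(batch, seq, vocab)
mask = torch.ones(batch, seq)
loss = word_level_kd_loss(student_logits, teacher_logits, mask)
loss.backward()
print(f"KD loss: {loss.item():.4f}")
```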
-----
Results 📈:
→ After stage 2 distillation, ARWKV scores 62.41 on MMLU, 68.67 on WinoGrande, and 52.22 on ARC-Challenge.