By rethinking GPU memory patterns, MARLIN makes LLM inference almost 4 times faster
Through optimized mixed-precision kernels and careful GPU memory management,
MARLIN achieves a near-optimal 3.87x speedup at batch sizes up to 32 on an NVIDIA A10 GPU 🤯
https://arxiv.org/abs/2408.11743
Key Insights 💡:
• LLM inference remains memory-bound even at larger batch sizes (see the roofline sketch after this list)
• Careful pipelining and partitioning can preserve quantization speedups as batch size grows
• GPU hardware features can be leveraged for efficient mixed-precision computations
• Sparsity can provide additional speedups on top of quantization
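To see why, here is a rough roofline-style sketch in Python. The 4096x4096 layer size and the A10 peak numbers are illustrative assumptions, not figures from the paper; the point is only that decode-time matmuls sit far below the compute-bound threshold, so shrinking weight bytes by 4x can buy close to a 4x speedup.

```python
# Rough roofline estimate: is a decode-time matmul memory-bound?
# Assumed (illustrative) layer: 4096 x 4096 weight matrix, batch of tokens b.

def arithmetic_intensity(k=4096, n=4096, batch=1, bytes_per_weight=2.0):
    """FLOPs per byte of weight traffic for y = x @ W with a (k x n) weight matrix.

    At small batch, weight bytes dominate memory traffic, so activations are ignored.
    bytes_per_weight: 2.0 for FP16 weights, 0.5 for 4-bit quantized weights.
    """
    flops = 2.0 * batch * k * n              # one multiply-accumulate per (batch, k, n)
    weight_bytes = k * n * bytes_per_weight
    return flops / weight_bytes

# Assumed A10-class figures: ~125 TFLOP/s FP16 tensor-core peak, ~600 GB/s DRAM bandwidth,
# so a kernel needs roughly 125e12 / 600e9 ~= 200 FLOPs per weight byte to become compute-bound.
for b in (1, 16, 32, 128):
    fp16 = arithmetic_intensity(batch=b, bytes_per_weight=2.0)
    int4 = arithmetic_intensity(batch=b, bytes_per_weight=0.5)
    print(f"batch={b:4d}  FP16: {fp16:6.1f} FLOPs/byte   INT4: {int4:6.1f} FLOPs/byte")

# Small batches land far below ~200 FLOPs/byte (memory-bound), so moving 4x fewer
# weight bytes can give close to 4x speedup -- the advantage shrinks once the batch
# grows large enough that the kernel becomes compute-bound.
```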
Solution in this Paper 🛠️:
• MARLIN: Mixed-precision Auto-Regressive LINear kernels
- Optimized for 4-bit quantized weights and FP16 activations
- Utilizes Tensor Cores and asynchronous memory operations
- Implements multi-level pipelining for latency hiding
- Employs striped partitioning for efficient workload distribution
- Supports group-wise quantization for accuracy preservation (see the reference sketch after this list)
• Sparse-MARLIN: Extension supporting 2:4 structured sparsity
- Leverages Sparse Tensor Core Units (SPTCs)
- Introduces specialized data layouts for sparse computations
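For intuition, here is a minimal NumPy reference for what such a group-wise 4-bit kernel computes, not how MARLIN computes it: the group size of 128 and symmetric per-group FP16 scales are assumptions in the spirit of common GPTQ-style setups, while the real kernel fuses dequantization into Tensor Core tiles and hides weight loads behind asynchronous copies.

```python
import numpy as np

# Reference semantics of a group-wise 4-bit x FP16 matmul:
# y = x_fp16 @ dequant(W_int4, scales), with one scale per group of
# `group_size` input rows and per output column.

def pack_int4(w_fp16, group_size=128):
    """Quantize a (K, N) FP16 weight matrix to 4-bit ints with per-group scales."""
    K, N = w_fp16.shape
    w_groups = w_fp16.reshape(K // group_size, group_size, N)
    scales = np.abs(w_groups).max(axis=1, keepdims=True) / 7.0    # int4 range: -8..7
    q = np.clip(np.round(w_groups / scales), -8, 7).astype(np.int8)
    return q.reshape(K, N), scales.squeeze(1)                     # (K, N), (K//g, N)

def gemm_int4(x_fp16, q, scales, group_size=128):
    """Dequantize group by group and accumulate -- the GPU kernel does the same
    math but tiled for Tensor Cores, with async loads pipelined across stages."""
    K, N = q.shape
    y = np.zeros((x_fp16.shape[0], N), dtype=np.float32)
    for g in range(K // group_size):
        rows = slice(g * group_size, (g + 1) * group_size)
        w_g = q[rows].astype(np.float32) * scales[g]              # dequant one group
        y += x_fp16[:, rows].astype(np.float32) @ w_g
    return y.astype(np.float16)

# Tiny smoke test (sizes are illustrative only)
rng = np.random.default_rng(0)
W = rng.standard_normal((256, 64)).astype(np.float16)
x = rng.standard_normal((4, 256)).astype(np.float16)
q, s = pack_int4(W)
err = np.abs(gemm_int4(x, q, s) - x.astype(np.float32) @ W.astype(np.float32)).max()
print("max abs error vs FP16 reference:", err)   # small, dominated by 4-bit rounding
```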
Results 📊:
• Maintains a significant speedup (~1.5x) even at batch size 128
• End-to-end 2.8x speedup in vLLM integration for Llama-2-7B
• Sparse-MARLIN provides an additional 1.2x speedup over MARLIN
• Time Per Output Token (TPOT) reduced by 3.3x for Llama-2-7B on an NVIDIA RTX A6000