Slash LLM communication costs by 75% while keeping 98% performance
https://arxiv.org/abs/2411.07942
🎯 Original Problem:
Tensor parallelism in LLM inference incurs high communication costs between devices. As models grow larger and need more devices, this communication overhead becomes a major bottleneck, especially when serving models at scale.
-----
🔧 Solution in this Paper:
→ The paper introduces a hybrid quantization method combining 4-bit and BF16 precision for communicating features between devices.
→ A calibration pass estimates per-feature quantization parameters using an exponential moving average of each feature's range over a calibration dataset.
→ The method identifies features with large quantization ranges and keeps them in BF16 precision while quantizing others to 4 bits.
→ The selection of high-precision features remains fixed across all sequences and devices (see the sketch below).
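A minimal sketch of this calibrate-then-quantize flow, assuming PyTorch; the function names, EMA decay, and 1/64 keep fraction are illustrative assumptions, not the paper's exact implementation:

```python
import torch

# Minimal sketch of the calibrate-then-quantize flow described above.
# Function names, the EMA decay, and the 1/64 keep fraction are assumptions
# for illustration, not the paper's exact implementation.

def calibrate(batches, ema_decay=0.99, keep_fraction=1 / 64):
    """Track per-feature ranges with an EMA, then pick which features stay in BF16."""
    ema_min = ema_max = None
    for x in batches:                                   # x: [tokens, features]
        bmin, bmax = x.amin(dim=0), x.amax(dim=0)
        if ema_min is None:
            ema_min, ema_max = bmin, bmax
        else:
            ema_min = ema_decay * ema_min + (1 - ema_decay) * bmin
            ema_max = ema_decay * ema_max + (1 - ema_decay) * bmax
    ranges = ema_max - ema_min
    k = max(1, int(ranges.numel() * keep_fraction))
    keep_bf16 = torch.topk(ranges, k).indices           # fixed across all sequences and devices
    scale = (ranges / 15.0).clamp(min=1e-8)             # 4-bit grid: integer levels 0..15
    return keep_bf16, ema_min, scale

def quantize(x, keep_bf16, zero_point, scale):
    """Quantize most features to 4 bits; the outlier features are sent in BF16 instead."""
    q = ((x - zero_point) / scale).round().clamp(0, 15)  # int4 payload
    dq = q * scale + zero_point                          # dequantized view of the 4-bit features
    dq[:, keep_bf16] = x[:, keep_bf16]                   # outlier columns pass through (BF16 on the wire)
    return dq
```

In an actual serving stack, the 4-bit payload and the small BF16 slice would be what crosses the interconnect; the indices chosen at calibration time never change at inference.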
-----
💡 Key Insights:
→ A small number of features consistently show enormous ranges, causing large quantization errors
→ Tensor parallelism naturally counteracts feature quantization errors through synchronization
→ Quantizing partial sums before synchronization results in errors clustering around zero (see the toy example after this list)
→ Only 1/64th of features need to be kept in BF16 to maintain performance
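A toy illustration of that error-cancellation insight, assuming PyTorch; the device count and the simple uniform fake-quantizer are illustrative assumptions:

```python
import torch

# Toy demo: each "device" quantizes its partial sum independently, so the
# rounding errors are roughly independent and tend to cancel when the
# partial sums are added during synchronization.

torch.manual_seed(0)
num_devices, features = 8, 4096
partials = [torch.randn(features) for _ in range(num_devices)]

def fake_quant(x, bits=4):
    scale = (x.max() - x.min()) / (2**bits - 1)
    return ((x - x.min()) / scale).round() * scale + x.min()

exact = torch.stack(partials).sum(dim=0)
summed = torch.stack([fake_quant(p) for p in partials]).sum(dim=0)

err = summed - exact
print(f"mean error {err.mean().item():.4f}, std {err.std().item():.4f}")  # errors cluster near zero
```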
-----
📊 Results:
→ Reduces average bits per value from 16 to 4.2 (arithmetic below)
→ Gemma 2 27B: Preserves 98% of original performance
→ Llama 2 13B: Maintains 99.5% of original performance
→ Mistral NeMo 12B: Retains 97.1% of original performance
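The 4.2 figure is consistent with keeping 1/64 of the features in BF16: (63 × 4 + 1 × 16) / 64 ≈ 4.19 bits per value, roughly a 74% cut in communication volume versus BF16.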