
"Towards Low-bit Communication for Tensor Parallel LLM Inference"

The podcast on this paper is generated with Google's Illuminate.

Slash LLM communication costs by 75% while keeping 98% performance

https://arxiv.org/abs/2411.07942

🎯 Original Problem:

Tensor parallelism in LLM inference incurs high communication costs between devices. As models grow and are split across more devices, this communication overhead becomes a major bottleneck, especially when serving models at scale.

-----

🔧 Solution in this Paper:

→ The paper introduces a hybrid quantization method combining 4-bit and BF16 precision for communicating features between devices.

→ During calibration, it computes quantization parameters for each feature using exponential moving averages over a calibration dataset.

→ The method identifies features with large quantization ranges and keeps them in BF16 precision while quantizing others to 4 bits.

→ The selection of high-precision features remains fixed across all sequences and devices.
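
The steps above can be sketched roughly as follows. This is a minimal NumPy illustration, not the paper's implementation: the function names, the `momentum` value, the use of per-feature min/max as the quantization parameters, and the asymmetric 4-bit scheme are all assumptions made for the example. NumPy has no BF16 type, so the "high-precision" features are simply left unquantized.

```python
import numpy as np

def calibrate_ranges(batches, momentum=0.9):
    """Track an exponential moving average of per-feature min and max.

    Each batch is an array of shape (tokens, features).
    `momentum` is an assumed value, not taken from the paper.
    """
    ema_min = ema_max = None
    for x in batches:
        bmin, bmax = x.min(axis=0), x.max(axis=0)
        if ema_min is None:
            ema_min, ema_max = bmin, bmax
        else:
            ema_min = momentum * ema_min + (1 - momentum) * bmin
            ema_max = momentum * ema_max + (1 - momentum) * bmax
    return ema_min, ema_max

def select_high_precision(ema_min, ema_max, fraction=1 / 64):
    """Return sorted indices of the features with the largest ranges."""
    k = max(1, int(round(fraction * ema_min.size)))
    ranges = ema_max - ema_min
    return np.sort(np.argsort(ranges)[-k:])

def quantize_dequantize(x, ema_min, ema_max, keep_idx, bits=4):
    """Simulate what the receiving device sees after communication.

    All features are rounded to 2**bits levels using the calibrated
    ranges, then the selected features are restored unquantized
    (a stand-in for sending them in BF16).
    """
    levels = 2**bits - 1
    scale = np.maximum(ema_max - ema_min, 1e-8) / levels
    q = np.clip(np.round((x - ema_min) / scale), 0, levels)
    out = q * scale + ema_min
    out[:, keep_idx] = x[:, keep_idx]
    return out
```

Because `select_high_precision` runs once on calibration statistics, the same index set can be reused for every sequence and every device, which is the fixed-selection property described above.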

-----

💡 Key Insights:

→ A small number of features consistently show enormous ranges, causing large quantization errors.

→ Tensor parallelism naturally counteracts feature quantization errors through synchronization.

→ Quantizing partial sums before synchronization results in errors clustering around zero.

→ Only 1/64th of features need to be kept in BF16 to maintain performance.
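
A toy simulation can make the partial-sum insight more concrete. The sketch below is an illustration on random data, not a reproduction of the paper's experiments: the device count, tensor shapes, and the `fake_quant` helper are assumptions. Each simulated device quantizes its own partial sum to 4 bits, the results are summed (standing in for the synchronization step), and the total is compared with the exact sum.

```python
import numpy as np

def fake_quant(x, bits=4):
    """Per-feature asymmetric quantize-dequantize using x's own min/max."""
    levels = 2**bits - 1
    lo, hi = x.min(axis=0), x.max(axis=0)
    scale = np.maximum(hi - lo, 1e-8) / levels
    return np.round((x - lo) / scale) * scale + lo

rng = np.random.default_rng(0)
num_devices, tokens, features = 8, 256, 64

# One partial sum per device, shape (tokens, features).
partials = rng.normal(size=(num_devices, tokens, features))

exact = partials.sum(axis=0)                    # full-precision reduction
approx = sum(fake_quant(p) for p in partials)   # quantize first, then reduce
error = approx - exact

print("mean error:", error.mean())
print("error std: ", error.std())
```

In this setup the per-device rounding errors are roughly independent and signed, so they partly cancel when summed and the mean error stays close to zero.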

-----

📊 Results:

→ Reduces average bits per value from 16 to 4.2

→ Gemma 2 27B: Preserves 98% of original performance

→ Llama 2 13B: Maintains 99.5% of original performance

→ Mistral NeMo 12B: Retains 97.1% of original performance
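
The 4.2-bit figure is consistent with the 1/64 BF16 fraction mentioned under Key Insights. The short calculation below assumes those two numbers go together and ignores any overhead for quantization parameters; `average_bits` is a hypothetical helper written for this check.

```python
def average_bits(low_bits=4, high_bits=16, high_fraction=1 / 64):
    """Average bits per communicated value for a mixed-precision tensor."""
    return (1 - high_fraction) * low_bits + high_fraction * high_bits

avg = average_bits()
print(avg)                      # 4.1875, i.e. about 4.2 bits
print(f"{1 - avg / 16:.1%}")    # 73.8% fewer bits than 16-bit communication
```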
