"Matryoshka Quantization"
https://arxiv.org/abs/2502.06786
Quantizing LLMs for efficient deployment trades off model quality, especially at very low precisions like int2, and typically forces teams to train and maintain a separate model for each target precision.
This paper proposes Matryoshka Quantization (MatQuant), which trains a single model that can be served at multiple precision levels.
-----
📌 MatQuant exploits the nested structure of integer data types: joint training optimizes a single set of shared weights for int8, int4, and int2 at once, avoiding the cost of training a separate model per precision.
📌 A key advantage is on-demand precision. From one MatQuant model you can extract int8, int4, int2, and even int6/int3 versions by slicing the most significant bits (see the sketch below), giving deployment flexibility without retraining.
📌 Mix'n'Match and bit-width interpolation are highly practical: a deployment can select layer-wise bit-widths or interpolated precisions like int6/int3 to match its hardware and latency budget.
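A minimal sketch of the slicing idea, assuming signed int8 weights with a per-tensor scale; the `slice_msb` helper and its rounding rule are my own illustration, not the paper's exact recipe:

```python
import torch

def slice_msb(q_int8: torch.Tensor, scale: torch.Tensor, bits: int):
    """Keep the `bits` most significant bits of int8 weights (Matryoshka-style
    nesting) to obtain a lower-precision copy of the same model."""
    assert 2 <= bits <= 8
    shift = 8 - bits
    if shift == 0:
        return q_int8.to(torch.int32), scale
    # Round-to-nearest while dropping the low-order bits.
    q_low = torch.div(q_int8.to(torch.int32) + (1 << (shift - 1)),
                      1 << shift, rounding_mode="floor")
    q_low = q_low.clamp(-(1 << (bits - 1)), (1 << (bits - 1)) - 1)
    # The effective step size grows by 2^shift, so the scale grows with it.
    return q_low, scale * (1 << shift)

# One int8 parent yields int4 / int2 (and interpolated int6 / int3) children.
q8 = torch.randint(-128, 128, (4096, 4096), dtype=torch.int8)
scale = torch.tensor(0.01)
q4, s4 = slice_msb(q8, scale, bits=4)
q2, s2 = slice_msb(q8, scale, bits=2)
```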
----------
Methods Explored in this Paper 🔧:
→ MatQuant is introduced as a multi-scale training technique for model quantization.
→ It leverages the inherent nested ("Matryoshka") structure of integer data types: int4 and int2 live inside the most significant bits of int8.
→ MatQuant jointly optimizes model weights across multiple integer precision levels during training.
→ It does this by sharing the most significant bits of the quantized parameters across precision levels.
→ A combined loss function optimizes performance at every target precision simultaneously (a simplified training sketch follows this list).
→ Lower bit-width models, such as int4 or int2, can then be extracted directly from the int8-quantized model by keeping only the most significant bits.
→ MatQuant is designed to be compatible with learning-based quantization techniques like Quantization Aware Training (QAT) and OmniQuant.
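How the joint objective might look in code: a simplified sketch assuming QAT-style fake quantization with a straight-through estimator and equal loss weights per precision (these choices are my assumptions, not the paper's exact hyperparameters):

```python
import torch
import torch.nn.functional as F

def fake_quant(w: torch.Tensor, scale: torch.Tensor, bits: int) -> torch.Tensor:
    """Quantize-dequantize w at `bits` precision; the straight-through
    estimator lets gradients reach the shared full-precision weights."""
    step = scale * (1 << (8 - bits))  # MSB slicing doubles the step per dropped bit
    qmax = (1 << (bits - 1)) - 1
    q = torch.clamp(torch.round(w / step), -qmax - 1, qmax)
    return w + (q * step - w).detach()

def matquant_loss(forward, w, scale, x, y,
                  precisions=(8, 4, 2), weights=(1.0, 1.0, 1.0)):
    """Combined loss: every target precision is optimized on the same weights."""
    total = 0.0
    for bits, lam in zip(precisions, weights):
        logits = forward(x, fake_quant(w, scale, bits))
        total = total + lam * F.cross_entropy(logits, y)
    return total

# Toy usage with a single linear layer (hypothetical setup).
W = torch.randn(16, 32, requires_grad=True)
scale = W.detach().abs().max() / 127
x, y = torch.randn(8, 32), torch.randint(0, 16, (8,))
loss = matquant_loss(lambda inp, w: inp @ w.t(), W, scale, x, y)
loss.backward()
```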
-----
Key Insights 💡:
→ MatQuant enables the creation of a single quantized model capable of operating effectively at different bit-widths, offering a spectrum of accuracy-versus-cost options.
→ Int2 precision models extracted using MatQuant achieve up to 10% higher accuracy compared to standard int2 quantization methods like QAT or OmniQuant.
→ MatQuant shifts the quantized weight distribution towards higher values, which is particularly beneficial for improving int2 performance.
→ The technique facilitates bit-width interpolation, allowing accurate models at intermediate precisions like int6 and int3 to be extracted without explicit training for those precisions (a worked scalar example follows).
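To make the interpolation point concrete, a worked scalar example in plain Python (values chosen purely for illustration): slicing the top 6 or 3 bits of an int8 code gives an int6 or int3 code whose coarser scale still reconstructs roughly the same weight.

```python
def sliced(q_int8: int, bits: int) -> int:
    """Keep the `bits` most significant bits of a signed int8 code."""
    shift = 8 - bits
    q = (q_int8 + (1 << (shift - 1))) >> shift if shift else q_int8
    return max(-(1 << (bits - 1)), min((1 << (bits - 1)) - 1, q))

# int8 code 92 with scale s  ->  weight ~ 92*s
print(sliced(92, 6))   # 23, at scale 4*s  -> 23*4*s  = 92*s
print(sliced(92, 3))   # 3,  at scale 32*s -> 3*32*s  = 96*s (close to 92*s)
```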
-----
Results 📊:
→ Int2 models from MatQuant show up to 8% accuracy improvement on downstream tasks compared to baseline int2 quantization.
→ An int2 FFN-quantized Gemma-2 9B model using MatQuant is more accurate than an int8 FFN-quantized Gemma-2 2B model with the same recipe.
→ MatQuant achieves comparable accuracy to baseline methods for int8 and int4 quantization while significantly improving int2 accuracy, as shown in Figure 1b of the paper.