"QuEST: Stable Training of LLMs with 1-Bit Weights and Activations"
https://arxiv.org/abs/2502.05003
High computational costs limit the efficiency of LLMs. Quantization Aware Training struggles to maintain accuracy at very low bit-widths, and 8-bit weights and activations are often seen as the lowest precision at which training remains practical.
This paper introduces QuEST, a novel Quantization Aware Training method that achieves stable training of LLMs with weights and activations quantized to as little as 1 bit. QuEST also pushes the Pareto frontier down to 4-bit quantization, surpassing the accuracy of standard 16-bit formats at comparable model size.
-----
📌 QuEST's Hadamard Transform preprocessing is a key innovation. It reshapes the distributions of weights and activations so that they can be quantized effectively. This technique enables surprisingly stable and accurate training even at 1 bit, pushing the limits of low-precision Quantization Aware Training.
📌 The Trust Gradient Estimator in QuEST directly tackles the core challenge of noisy gradients in low-bit training. By selectively trusting gradients based on the quantization error, it stabilizes optimization, which is crucial for making 1-bit training viable.
📌 QuEST demonstrates in practice that 4-bit LLMs can surpass Brain Float 16 (BF16) performance. This redefines the Pareto frontier for efficient LLM training and enables smaller, faster, and cheaper models with high accuracy.
----------
Methods Explored in this Paper 🔧:
→ QuEST enhances Quantization Aware Training through two core innovations.
→ In the forward pass, it introduces a novel approach to distribution fitting: a Hadamard Transform is applied to weights and activations before they are quantized. This rotation spreads outlier values across many coordinates and brings the distribution closer to a Gaussian, which is easier to quantize accurately.
→ QuEST then uses Mean Squared Error (MSE) optimal fitting to choose the quantization grid for the transformed distributions, which keeps quantization both accurate and fast.
→ For the backward pass, QuEST introduces a Trust Gradient Estimator, which aims to minimize the discrepancy between gradients computed with quantized parameters and full-precision gradients.
→ The Trust Gradient Estimator selectively reduces the influence of gradients from weight components that incurred large quantization errors in the forward pass. This limits the impact of outlier quantization errors on gradient updates and stabilizes training.
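A minimal sketch of the Hadamard preprocessing step, assuming an orthonormal fast Walsh-Hadamard transform applied to a vector whose length is a power of two. The function names `fwht` and `l2_norm` and the toy vector are illustrative, not taken from the paper's code, and the block size and placement QuEST actually uses are not specified here. The example shows two properties that matter for quantization: the transform preserves the vector's norm (it is a rotation and its own inverse), and it spreads a single large outlier across all coordinates, shrinking the largest magnitude the quantizer must cover.

```python
import math


def fwht(x: list[float]) -> list[float]:
    """Orthonormal fast Walsh-Hadamard transform (length must be a power of two)."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    y = list(x)
    h = 1
    while h < n:
        # Butterfly step: combine pairs of entries that are h positions apart.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = y[i], y[i + h]
                y[i], y[i + h] = a + b, a - b
        h *= 2
    # Scaling by 1/sqrt(n) makes the transform orthonormal, hence self-inverse.
    scale = 1.0 / math.sqrt(n)
    return [v * scale for v in y]


def l2_norm(x: list[float]) -> float:
    return math.sqrt(sum(v * v for v in x))


# One large outlier among small values (hypothetical data).
x = [0.1, -0.2, 0.05, 8.0, 0.0, 0.1, -0.1, 0.2]
hx = fwht(x)

print(f"max |x|  = {max(abs(v) for v in x):.3f}")
print(f"max |Hx| = {max(abs(v) for v in hx):.3f}")
print(f"norms: {l2_norm(x):.4f} vs {l2_norm(hx):.4f}")
```

Because the transform is orthonormal, applying `fwht` a second time recovers the original vector, so in principle the rotation can be undone or folded into an adjacent matrix multiplication without changing the layer's full-precision output.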
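A minimal sketch of MSE-optimal grid fitting, under assumptions that are this sketch's own rather than the paper's: a symmetric uniform grid with 2**bits levels and no zero level (so 1 bit maps to ±level), and a step that is a fixed multiple of the tensor's root-mean-square value. If the Hadamard-transformed values are treated as approximately Gaussian, that multiple only has to be found once per bit-width, by minimizing the expected squared error under N(0, 1); quantizing a tensor then costs one RMS computation and one rounding pass. All function names (`gaussian_mse`, `mse_optimal_step`, `quantize_rms`) are hypothetical.

```python
import math
from functools import lru_cache

BIG = 40.0  # stands in for infinity; the Gaussian mass beyond it is numerically zero


def phi(z: float) -> float:
    """Standard normal density."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def Phi(z: float) -> float:
    """Standard normal CDF."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def grid_levels(step: float, bits: int) -> list[float]:
    """Symmetric uniform grid with 2**bits levels and no zero level."""
    n = 2 ** bits
    return [step * (k - (n - 1) / 2.0) for k in range(n)]


def gaussian_mse(step: float, bits: int) -> float:
    """Exact expected squared error of nearest-level rounding for z ~ N(0, 1)."""
    centers = grid_levels(step, bits)
    edges = [-BIG] + [(a + b) / 2.0 for a, b in zip(centers, centers[1:])] + [BIG]
    mse = 0.0
    for c, lo, hi in zip(centers, edges, edges[1:]):
        p = Phi(hi) - Phi(lo)                  # integral of phi over the cell
        m1 = phi(lo) - phi(hi)                 # integral of z * phi
        m2 = p + lo * phi(lo) - hi * phi(hi)   # integral of z**2 * phi
        mse += m2 - 2.0 * c * m1 + c * c * p   # integral of (z - c)**2 * phi
    return mse


@lru_cache(maxsize=None)
def mse_optimal_step(bits: int) -> float:
    """Grid-search the step (in units of sigma) that minimizes Gaussian MSE."""
    candidates = (i / 1000.0 for i in range(1, 3001))
    return min(candidates, key=lambda s: gaussian_mse(s, bits))


def quantize_rms(x: list[float], bits: int) -> list[float]:
    """Quantize x on a grid whose step is the Gaussian-optimal multiple of RMS(x)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    if rms == 0.0:
        return [0.0] * len(x)
    n = 2 ** bits
    step = mse_optimal_step(bits) * rms
    out = []
    for v in x:
        k = round(v / step + (n - 1) / 2.0)
        k = min(max(k, 0), n - 1)  # clip values beyond the outermost levels
        out.append(step * (k - (n - 1) / 2.0))
    return out


for b in (1, 2, 3, 4):
    s = mse_optimal_step(b)
    print(f"{b}-bit: step = {s:.3f} sigma, MSE = {gaussian_mse(s, b):.4f}")

print(quantize_rms([0.3, -1.2, 0.8, -0.1, 2.5, -0.6, 0.0, 1.1], bits=2))
```

For 1 bit, the search lands on a step of about 1.596, i.e. reconstruction levels of about ±0.798 ≈ ±sqrt(2/π), which is the classical MSE-optimal 1-bit quantizer for a standard Gaussian; the expected error shrinks as bits are added.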
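A minimal sketch contrasting a plain straight-through estimator (STE) with a trust-style estimator, assuming the "selective reduction" is a hard 0/1 mask that keeps the upstream gradient only where the forward-pass quantization error |q - x| is at most a threshold. The mask form, the threshold of half a grid step, and all function names are assumptions for illustration, not QuEST's exact rule.

```python
def quantize(x: list[float], step: float, bits: int) -> list[float]:
    """Nearest-level rounding onto a symmetric uniform grid with 2**bits levels."""
    n = 2 ** bits
    out = []
    for v in x:
        k = round(v / step + (n - 1) / 2.0)
        k = min(max(k, 0), n - 1)  # values outside the grid are clipped
        out.append(step * (k - (n - 1) / 2.0))
    return out


def ste_backward(grad_out: list[float]) -> list[float]:
    """Straight-through estimator: treat quantization as identity in the backward pass."""
    return list(grad_out)


def trust_backward(x: list[float], q: list[float], grad_out: list[float],
                   threshold: float) -> list[float]:
    """Trust-style estimator (sketch): keep the gradient only where |q - x| <= threshold."""
    return [g if abs(qi - xi) <= threshold else 0.0
            for xi, qi, g in zip(x, q, grad_out)]


step, bits = 1.0, 2                  # grid levels: -1.5, -0.5, 0.5, 1.5
x = [0.4, -0.7, 1.3, 4.0, -3.2]      # the last two fall far outside the grid
q = quantize(x, step, bits)
grad = [1.0] * len(x)

print(q)                                     # [0.5, -0.5, 1.5, 1.5, -1.5]
print(ste_backward(grad))                    # [1.0, 1.0, 1.0, 1.0, 1.0]
print(trust_backward(x, q, grad, step / 2))  # [1.0, 1.0, 1.0, 0.0, 0.0]
```

With the threshold at half a grid step, every in-range value is trusted (its rounding error is at most half a step), so only the heavily clipped outliers lose their gradient, whereas the STE passes their gradient through unchanged. A smaller threshold would also mask in-range values with large rounding error.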
-----
Key Insights 💡:
→ QuEST enables stable training of LLMs with just 1-bit weights and activations. This was not previously considered feasible with standard Quantization Aware Training methods.
→ Models trained with QuEST using 4-bit weights and activations achieve better accuracy than models in Brain Float 16 (BF16) format once model size and inference cost are taken into account. QuEST shifts the Pareto-optimal frontier for LLM training to lower bit-widths.
→ Hadamard Transform preprocessing significantly improves the effectiveness of the trust estimation method, contributing to stable training and better performance.
→ Analysis of efficiency factors reveals that 4-bit quantization offers the best balance between model size, computational cost, and accuracy in overtraining regimes, where models are trained on more tokens than would be compute-optimal for their size.
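The size comparisons above are made in terms of memory, where a parameter stored in b bits costs b/8 bytes. A small worked example, counting weights only and ignoring quantization scales, activations, and any layers kept in higher precision (all simplifying assumptions); the function names are hypothetical.

```python
def weight_memory_gb(num_params: float, bits: int) -> float:
    """Storage for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits / 8 / 1e9


def params_at_equal_memory(num_params: float, bits: int, ref_bits: int = 16) -> float:
    """How many ref_bits parameters fit in the memory of num_params `bits`-bit parameters."""
    return num_params * bits / ref_bits


n = 800e6  # an 800M-parameter model, the smaller size used in the paper's results
for bits in (16, 8, 4, 1):
    print(f"{bits:>2}-bit weights: {weight_memory_gb(n, bits):.2f} GB")

equal = params_at_equal_memory(n, 4)
print(f"BF16 model with the same footprint as a 4-bit 800M model: {equal / 1e6:.0f}M params")
```

So, at a fixed memory budget, a 4-bit model can hold 4 times as many parameters as a BF16 model, and a 1-bit model 16 times as many.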
-----
Results 📊:
→ QuEST trained models with 4-bit weights and activations outperform BF16 models of nearly 4 times larger size in terms of accuracy, as shown in Figure 1.
→ QuEST achieves lower perplexity than LSQ (Learned Step Size Quantization), especially at lower bit-widths, as shown in Figure 3. For instance, in the W4A4 configuration (4-bit weights and activations), QuEST reaches a perplexity of 29.08, while LSQ reaches 30.27.
→ QuEST's GPU kernels achieve per-layer inference speedups of 1.2× to 2.4× for an 800M-parameter model and 2.3× to 3.9× for a 7B-parameter model compared to BF16, as depicted in Figure 6.
→ End-to-end inference speedups of 1.3× to 1.5× are observed for QuEST INT4 over BF16 in the prefill stage, using an 800M-parameter model on an RTX 4090, as shown in Figure 7.
→ Zero-shot evaluation on the HellaSwag benchmark shows comparable accuracy between the 800M-parameter QuEST 4-bit model (39.22%) and its BF16 counterpart (39.52%), as indicated in Table 2, suggesting near-lossless quantization.