Memory-efficient LLM training achieved by directly updating quantized weights without full-precision storage.
Direct Quantized Training (DQT) lets LLMs train directly on low-precision weights, eliminating the need to store high-precision weights during training and potentially cutting weight memory by up to 95% compared to full-precision training.
-----
https://arxiv.org/abs/2412.04787
🤖 Original Problem:
Current quantization methods for LLMs still require storing high-precision weights during training, consuming substantial memory. Even BitNet, which uses binary/ternary weights, maintains full-precision matrices for gradient updates.
-----
🔧 Solution in this Paper:
→ Direct Quantized Training (DQT) maintains only low-precision weights throughout training, eliminating high-precision storage.
→ DQT uses stochastic rounding to directly update quantized weights during backpropagation.
→ Weight matrices stay at n-bit precision throughout training, with options ranging from 1.58-bit (ternary) up to 8-bit.
→ Stochastic rounding probabilistically maps each high-precision update result back to a nearby low-precision level, with probability based on its distance to the neighbouring quantization levels (sketched below).
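A minimal sketch of how such an update could look, assuming a uniform n-bit grid and a plain SGD step; the function names, clipping range, and grid step are illustrative, not taken from the paper:

```python
import numpy as np

def stochastic_round(x, step):
    """Round each value to a neighbouring multiple of `step`; the upper
    level is chosen with probability proportional to how close x is to it."""
    lower = np.floor(x / step) * step
    p_up = (x - lower) / step                 # in [0, 1): 0 at lower level, ->1 near upper
    go_up = np.random.random(x.shape) < p_up
    return lower + go_up * step

def dqt_style_update(w_q, grad, lr, step, w_min, w_max):
    """Apply one SGD step to quantized weights and immediately round the
    result back onto the low-precision grid, so no full-precision copy
    of the weights is ever kept."""
    w_new = w_q - lr * grad                   # transient high-precision values
    w_new = stochastic_round(w_new, step)
    return np.clip(w_new, w_min, w_max)       # stay within the n-bit value range

# Example: ternary (1.58-bit) weights in {-1, 0, +1}, grid step of 1
w = np.array([-1.0, 0.0, 1.0, 0.0])
g = np.array([0.3, -0.6, 0.1, 0.05])
print(dqt_style_update(w, g, lr=1.0, step=1.0, w_min=-1.0, w_max=1.0))
```

Because the rounding is unbiased in expectation, repeated updates can still move the quantized weights in the direction of the gradient even though each individual weight only ever takes low-precision values.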
-----
💡 Key Insights:
→ Models can converge even with ternary weights during training
→ Higher bit-width (8-bit) implementations show more stability
→ DQT models can run inference directly with ternary weights, enabling flexible deployment
→ Memory efficiency improves as no high-precision weights are stored
-----
📊 Results:
→ 8-bit DQT achieves only 5% loss degradation compared to BitNet b1.58
→ Successfully trains with ternary values (1.58-bit)
→ Weight storage for 1B parameters drops from roughly 4GB at full precision to about 0.2GB with ternary weights (back-of-envelope check below)
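A quick back-of-envelope check of those memory figures, assuming 32-bit floats for the full-precision baseline and ~1.58 bits per ternary weight:

```python
# Assumptions: 32-bit floats for the baseline, ~1.58 bits per ternary weight, 1e9 parameters.
params = 1e9
fp32_gb    = params * 32   / 8 / 1e9   # 4.0 GB
ternary_gb = params * 1.58 / 8 / 1e9   # ~0.2 GB
print(f"full precision: {fp32_gb:.1f} GB, ternary: {ternary_gb:.2f} GB")
```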