INT8 quantization meets FlashAttention to supercharge inference speed on Ampere GPUs
First-ever fully INT8 attention operator that's 72% faster than FP16
https://arxiv.org/abs/2409.16997
🤖 Original Problem:
Self-attention in LLMs scales quadratically in time and memory with sequence length. FlashAttention mitigates this by exploiting the GPU memory hierarchy, but existing quantization methods are not compatible with its workflow, especially on the widely used NVIDIA Ampere GPUs.
-----
🔧 Solution in this Paper:
INT-FlashAttention introduces a novel token-level quantization architecture that:
→ Keeps the Q, K, and V matrices entirely in INT8
→ Uses INT8 GEMM kernels for all matrix multiplications
→ Integrates seamlessly with FlashAttention's online softmax workflow
→ Employs token-level quantization for the Q and K matrices
→ Uses tensor-level quantization for the V matrix (see the sketch below)
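A minimal NumPy sketch of this quantization layout, assuming symmetric scaling to the INT8 range; the function names and example shapes are illustrative, not the paper's actual CUDA kernels:

```python
import numpy as np

def quantize_per_token(x):
    """Token-level (per-row) symmetric INT8 quantization: one scale per token, as used for Q and K."""
    scale = np.abs(x).max(axis=-1, keepdims=True) / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def quantize_per_tensor(x):
    """Tensor-level symmetric INT8 quantization: a single scale for the whole matrix, as used for V."""
    scale = np.abs(x).max() / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

# Example: quantize Q, K token-wise and V tensor-wise
seq_len, head_dim = 8, 64
Q, K, V = (np.random.randn(seq_len, head_dim).astype(np.float32) for _ in range(3))
Q_int8, q_scale = quantize_per_token(Q)
K_int8, k_scale = quantize_per_token(K)
V_int8, v_scale = quantize_per_tensor(V)

# INT8 GEMM accumulates in INT32; dequantize scores with the outer product of per-token scales
# (softmax scaling by 1/sqrt(head_dim) omitted here for brevity)
S_int32 = Q_int8.astype(np.int32) @ K_int8.astype(np.int32).T
S_fp = S_int32.astype(np.float32) * (q_scale @ k_scale.T)
```

Per-token scales for Q and K keep the dequantization of the score matrix a cheap rank-1 rescaling, while a single scale for V keeps the PV accumulation simple; this is the design trade-off the bullet list above describes.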
-----
💡 Key Insights:
→ First attention operator with fully INT8 input
→ Token-level quantization framework adaptable to other formats like INT4
→ Significant performance boost on Ampere GPUs, which account for roughly 20% of supercomputer compute power
→ Efficient scaling and dequantization within the online softmax workflow (see the sketch after this list)
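As a rough illustration of how dequantization can slot into FlashAttention's online softmax recurrence, here is a block-wise NumPy sketch that reuses the hypothetical quantization helpers above. The paper implements this as a fused INT8 GPU kernel; this only mirrors the math, not the kernel-level tiling or scheduling:

```python
import numpy as np

def int8_flash_attention_sketch(Q_int8, q_scale, K_int8, k_scale, V_int8, v_scale, block=4):
    """Illustrative online-softmax loop over K/V blocks with INT8 inputs (readability sketch)."""
    n, d = Q_int8.shape
    out = np.zeros((n, V_int8.shape[1]), dtype=np.float32)
    running_max = np.full((n, 1), -np.inf, dtype=np.float32)
    running_sum = np.zeros((n, 1), dtype=np.float32)

    for start in range(0, n, block):
        Kb = K_int8[start:start + block]
        Vb = V_int8[start:start + block].astype(np.float32) * v_scale  # tensor-level dequant of V
        kb_scale = k_scale[start:start + block]

        # INT8 GEMM with INT32 accumulation, then dequantize with per-token scales
        S = (Q_int8.astype(np.int32) @ Kb.astype(np.int32).T).astype(np.float32)
        S *= (q_scale @ kb_scale.T) / np.sqrt(d)

        # Standard FlashAttention online-softmax update (running max and sum)
        new_max = np.maximum(running_max, S.max(axis=-1, keepdims=True))
        correction = np.exp(running_max - new_max)
        P = np.exp(S - new_max)
        running_sum = running_sum * correction + P.sum(axis=-1, keepdims=True)
        out = out * correction + P @ Vb
        running_max = new_max

    return out / running_sum
```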
-----
📊 Results:
→ 72% faster inference speed vs FlashAttention-FP16
→ 82% smaller quantization error vs FlashAttention-FP8
→ Speed improvements: 31%, 52%, 66%, 72%, 73% for sequence lengths 1k, 2k, 4k, 8k, 16k
→ Up to 5.6x smaller Mean Relative Error under uniformly distributed activations