Want your transformer's attention to run 2x faster? Just plug in this 8-bit quantization trick
SageAttention introduces an efficient 8-bit quantization method for the attention computation in transformers. It runs about 2.1x faster than FlashAttention2 while maintaining accuracy across a wide range of models. The method is plug-and-play at inference time, requiring no retraining and no model modifications.
-----
https://arxiv.org/abs/2410.02367
🔍 Original Problem:
→ Attention becomes the dominant bottleneck in transformers because of its quadratic complexity, especially at sequence lengths of 8K-128K tokens.
→ Existing quantization methods focus mainly on linear layers, leaving the attention computation itself in high-precision formats like FP16.
-----
⚡ Solution in this Paper:
→ SageAttention quantizes the Q and K matrices of attention to INT8 so their matmul runs on fast INT8 tensor cores.
→ It smooths K by subtracting its per-channel mean across tokens, removing the channel-wise outliers that otherwise cause accuracy loss (see the sketch after this list).
→ For the PV computation, it keeps FP16 with an FP16 accumulator instead of quantizing to 8 bits.
→ The implementation fuses quantization into the RoPE kernel and uses FlashAttention-style tiling.
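Below is a minimal PyTorch sketch of that pipeline for a single head. It uses per-tensor INT8 quantization and plain matmuls for readability; the paper quantizes per block inside FlashAttention-style tiles and runs everything in a fused CUDA kernel, so treat this purely as an illustration of the dataflow.

```python
import torch

def int8_quantize(x):
    # Symmetric per-tensor INT8 quantization: returns the int8 tensor and its scale.
    scale = x.abs().amax() / 127.0
    x_i8 = torch.clamp((x / scale).round(), -127, 127).to(torch.int8)
    return x_i8, scale

def sage_attention_sketch(q, k, v):
    # q, k, v: (seq_len, head_dim) float tensors for a single attention head.
    d = q.shape[-1]
    # Smooth K: subtract its per-channel mean over tokens. This removes the
    # channel-wise outliers; the attention output is unchanged because the
    # shift only adds a per-row constant to the pre-softmax scores.
    k = k - k.mean(dim=0, keepdim=True)
    q_i8, q_scale = int8_quantize(q)
    k_i8, k_scale = int8_quantize(k)
    # Emulated INT8 matmul with integer accumulation; the real kernel runs this
    # on INT8 tensor cores. Dequantize by multiplying the two scales back in.
    scores = (q_i8.to(torch.int32) @ k_i8.to(torch.int32).T).float()
    scores = scores * (q_scale * k_scale) / d ** 0.5
    p = torch.softmax(scores, dim=-1)
    # The PV matmul stays in FP16 in the paper (with an FP16 accumulator inside
    # the kernel); a plain float matmul stands in for that step here.
    return p @ v

q, k, v = (torch.randn(1024, 64) for _ in range(3))
out = sage_attention_sketch(q, k, v)   # (1024, 64)
```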
-----
💡 Key Insights:
→ Naive 8-bit quantization of the attention matrices causes significant accuracy loss
→ The K matrix exhibits channel-wise outliers that need special handling (checked in the snippet after this list)
→ INT8 quantization is more accurate than FP8 for the Q and K matrices
→ Using an FP16 accumulator instead of FP32 doubles the speed of the FP16 PV matmul without accuracy loss
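The smoothing insight is easy to verify numerically: subtracting K's per-channel mean only shifts each row of QK^T by a constant, and softmax is invariant to per-row shifts, so accuracy is untouched while the quantization range shrinks. A small self-contained check (with a synthetic outlier channel, since real K matrices aren't reproduced here):

```python
import torch

torch.manual_seed(0)
q, k = torch.randn(256, 64), torch.randn(256, 64)
k[:, 7] += 20.0          # inject a channel-wise outlier like those observed in real K

k_smooth = k - k.mean(dim=0, keepdim=True)
p_ref    = torch.softmax(q @ k.T / 8.0, dim=-1)        # 8.0 = sqrt(head_dim)
p_smooth = torch.softmax(q @ k_smooth.T / 8.0, dim=-1)
print(torch.allclose(p_ref, p_smooth, atol=1e-4))      # True: attention unchanged

# The smoothed K is far friendlier to INT8: its max |value| sets the quantization
# scale, so a tighter range means the 256 integer levels resolve finer detail.
print(k.abs().amax().item(), k_smooth.abs().amax().item())
```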
-----
📊 Results:
→ 2.1x faster than FlashAttention2 and 2.7x faster than xformers
→ Achieves 340 TOPS on RTX4090, reaching 52% of theoretical INT8 throughput
→ Maintains accuracy across image generation, video generation, and language tasks
→ Negligible end-to-end metric loss in plug-and-play deployment (see the usage sketch below)
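In practice, plug-and-play means swapping a single attention call at inference time, with no retraining. The package name, import path, and function signature below follow the authors' public code release and should be treated as assumptions of this sketch rather than verified claims:

```python
# pip install sageattention   (assumed package name from the authors' repo)
import torch
from sageattention import sageattn   # assumed import path

# q, k, v: FP16 tensors of shape (batch, heads, seq_len, head_dim) on a CUDA GPU
q, k, v = (torch.randn(1, 8, 4096, 64, dtype=torch.float16, device="cuda")
           for _ in range(3))

# Drop-in replacement for torch.nn.functional.scaled_dot_product_attention(q, k, v)
out = sageattn(q, k, v, is_causal=False)   # assumed signature
```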