Cut Cross-Entropy (CCE): Trimming the fat from LLM memory consumption.
Memory-efficient cross-entropy: the key to scaling up LLM training.
This paper introduces Cut Cross-Entropy (CCE), a method to reduce memory usage in LLM training. CCE computes the cross-entropy loss without storing logits for all tokens, dramatically decreasing its memory footprint. It achieves this by computing logits on the fly and exploiting softmax sparsity, enabling larger vocabularies and batch sizes.
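The identity CCE builds on is easy to state: per token, the loss is logsumexp over the logits minus the logit of the correct token, so only two scalars are actually needed. Here is a minimal PyTorch sketch (toy sizes and names, not the paper's code) checking that identity:

```python
import torch

# Toy sizes for illustration only.
d, V = 64, 1000
torch.manual_seed(0)
h = torch.randn(d)        # hidden state of one token
W = torch.randn(V, d)     # classifier (unembedding) matrix
target = 42

# Standard route: materialize the full logit row, then take the loss.
logits = W @ h
loss_full = torch.logsumexp(logits, dim=0) - logits[target]

# The loss only needs two scalars: the target logit and the log-sum-exp.
# The latter can be accumulated blockwise without storing the full row.
loss_two_scalars = torch.logsumexp(W @ h, dim=0) - W[target] @ h

assert torch.allclose(loss_full, loss_two_scalars)
```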
-----
https://arxiv.org/abs/2411.09009
🔍 Original Problem:
The cross-entropy loss layer in LLM training consumes excessive memory, limiting vocabulary and batch sizes.
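A back-of-the-envelope calculation shows why (the token count and dtype below are assumptions for illustration, not the paper's exact setup):

```python
# Naive loss layer: materialize a (tokens, vocab) logit matrix.
vocab_size = 256_000        # roughly Gemma-2-scale vocabulary
tokens_per_batch = 16_384   # assumed batch_size * sequence_length
bytes_per_elem = 4          # fp32 logits

logit_bytes = vocab_size * tokens_per_batch * bytes_per_elem
print(f"forward logits alone: {logit_bytes / 2**30:.1f} GiB")   # ~15.6 GiB
# The backward pass needs a same-sized gradient buffer on top of this,
# which is why the loss layer can dominate total training memory.
```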
-----
💡 Solution in this Paper:
→ CCE computes cross-entropy loss without materializing logits for all tokens in global memory.
→ It computes only the logit for the correct token and evaluates the log-sum-exp over all logits on the fly.
→ A custom kernel performs the matrix multiplications and the log-sum-exp reduction over the vocabulary in fast on-chip (SRAM) memory, making global memory consumption for the loss negligible.
→ CCE leverages softmax sparsity to skip gradient contributions that are numerically negligible, improving throughput.
→ Gradient filtering and vocabulary sorting let the backward pass skip entire blocks of vanishing gradients and improve memory access patterns (a minimal forward-pass sketch follows this list).
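Below is a plain-PyTorch sketch of the forward pass under these ideas. The function name, block size, and explicit Python loop are illustrative assumptions; the paper fuses this computation into a custom GPU kernel that keeps each logit block on-chip rather than writing it to global memory.

```python
import torch

def cce_loss_forward(H, W, targets, block=4096):
    """Per-token cross-entropy without materializing the full (N, V) logit matrix.

    H: (N, d) hidden states, W: (V, d) classifier weights, targets: (N,) token ids.
    """
    N, d = H.shape
    V = W.shape[0]
    # Logit of the correct token: one dot product per token.
    target_logits = (H * W[targets]).sum(dim=-1)                  # (N,)
    # Running log-sum-exp over vocabulary blocks.
    lse = torch.full((N,), float("-inf"), device=H.device)
    for start in range(0, V, block):
        logits_blk = H @ W[start:start + block].T                 # (N, block); freed each iteration
        lse = torch.logaddexp(lse, torch.logsumexp(logits_blk, dim=-1))
    return lse - target_logits                                    # per-token loss
```

Only an (N, block) slab of logits exists at any one time, so peak logit storage scales with the block size rather than the full vocabulary.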
-----
🔑 Key Insights from this Paper:
→ The cross-entropy loss layer can account for up to 90% of the memory footprint when training large-vocabulary LLMs
→ Softmax sparsity can be exploited to skip negligible gradient work without loss of precision (a toy illustration follows this list)
→ Memory-efficient loss implementations enable larger vocabularies and batch sizes
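As a rough illustration of that sparsity (toy numbers, and the paper's exact filtering criterion may differ), the snippet below counts how many softmax entries exceed bf16's machine epsilon; entries far below it round away when gradients are accumulated in bf16, so their contributions can be skipped:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(8, 50_000) * 4           # 8 tokens, 50k-entry vocabulary (made up)
probs = torch.softmax(logits, dim=-1)

threshold = torch.finfo(torch.bfloat16).eps   # ~7.8e-3, an illustrative cutoff
kept = (probs > threshold).float().mean().item()
print(f"softmax entries above the cutoff: {kept:.4%}")
```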
-----
📊 Results:
→ Reduced memory footprint of loss computation from 24 GB to 1 MB for Gemma 2 (2B) model
→ Decreased total training-time memory consumption of classifier head from 28 GB to 1 GB
→ Achieved memory reduction without sacrificing training speed or convergence
→ Enabled 1.5x to 10x increase in batch size for various frontier models