"Cut Your Losses in Large-Vocabulary Language Models"

The podcast on this paper was generated with Google's Illuminate.

Cut Cross-Entropy (CCE): Trimming the fat from LLM memory consumption.

Memory-efficient cross-entropy: the key to scaling up LLM training.

This paper introduces Cut Cross-Entropy (CCE), a method to reduce memory usage in LLM training. CCE computes the cross-entropy loss without storing the logits for all tokens, dramatically decreasing the memory footprint. It achieves this by computing logits on the fly and leveraging softmax sparsity, enabling larger vocabularies and batch sizes.

-----

https://arxiv.org/abs/2411.09009

🔍 Original Problem:

The cross-entropy loss layer in LLM training consumes excessive memory, limiting vocabulary and batch sizes.
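
For intuition, a quick back-of-the-envelope estimate (illustrative numbers assumed here, not taken from the paper) shows why the logit tensor dominates:

```python
# Rough estimate of logit memory for the cross-entropy layer.
# Assumed illustrative values: 8,192 tokens per step, a ~256K-entry
# vocabulary, and fp32 logits for the loss computation.
tokens = 8_192                # batch size x sequence length
vocab = 256_000               # large-vocabulary model
bytes_per_logit = 4           # fp32

logit_bytes = tokens * vocab * bytes_per_logit
print(f"logits alone: {logit_bytes / 2**30:.1f} GiB")     # ~7.8 GiB

# The hidden states feeding the classifier are tiny by comparison
# (assumed hidden size 2,304, stored in bf16).
hidden_bytes = tokens * 2_304 * 2
print(f"hidden states: {hidden_bytes / 2**20:.1f} MiB")   # ~36 MiB
```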

-----

💡 Solution in this Paper:

→ CCE computes cross-entropy loss without materializing logits for all tokens in global memory.

→ It only computes the logit for the correct token and evaluates the log-sum-exp over all logits on the fly (a simplified version is sketched in code after this list).

→ A custom kernel performs matrix multiplications and log-sum-exp reduction over the vocabulary in flash memory.

→ CCE leverages softmax sparsity to skip negligible gradient computations, improving throughput.

→ The method uses gradient filtering and vocabulary sorting to optimize memory access patterns.
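
A minimal PyTorch sketch of the forward pass, assuming this chunked formulation captures the core idea (the paper's actual implementation is a fused custom kernel; the function name here is hypothetical): the loss for token i is -x_i·c_{y_i} plus the log-sum-exp of x_i·c_v over the vocabulary, so the full num_tokens × vocab logit matrix never needs to exist at once.

```python
import torch

def chunked_cross_entropy(hidden, classifier, targets, chunk=8_192):
    """Memory-lean cross-entropy (illustrative sketch, not the paper's kernel).

    hidden:     (N, D) token embeddings entering the classifier head
    classifier: (V, D) unembedding / classifier weights
    targets:    (N,)   indices of the correct tokens
    """
    # Logit of the correct token: x_i . c_{y_i}; no full matmul required.
    correct_logit = (hidden * classifier[targets]).sum(dim=-1)        # (N,)

    # Running log-sum-exp over the vocabulary, one chunk of columns at a time,
    # so at most (N, chunk) logits are live at once.
    lse = torch.full_like(correct_logit, float("-inf"))               # (N,)
    for start in range(0, classifier.shape[0], chunk):
        block = classifier[start:start + chunk]                       # (C, D)
        logits = hidden @ block.T                                     # (N, C)
        lse = torch.logaddexp(lse, torch.logsumexp(logits, dim=-1))

    # Cross-entropy = log-sum-exp minus the correct-token logit.
    return (lse - correct_logit).mean()

# Sanity check against the standard dense loss on a toy problem.
h = torch.randn(4, 8)
W = torch.randn(32, 8)
y = torch.randint(0, 32, (4,))
dense = torch.nn.functional.cross_entropy(h @ W.T, y)
assert torch.allclose(chunked_cross_entropy(h, W, y, chunk=8), dense, atol=1e-5)
```

Note that a plain autograd version of this loop would still cache activations for the backward pass; the paper's kernel instead recomputes logits during the backward pass and applies gradient filtering, which is where the large savings come from.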

-----

🔑 Key Insights from this Paper:

→ The cross-entropy loss layer can consume up to 90% of the memory footprint in modern LLM training

→ Softmax sparsity can be exploited to skip most of the gradient computation without loss of precision (see the sketch after this list)

→ Memory-efficient implementations can enable larger vocabularies and batch sizes
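
To make the sparsity insight concrete, here is a small hedged sketch (the logits are fully materialized only for illustration; the paper applies the same test blockwise inside its kernel, and the threshold below is an assumption, roughly the scale at which contributions vanish in low precision):

```python
import torch

def filtered_softmax_grad(logits, targets, eps=2 ** -12):
    """Gradient filtering, illustrated on a dense logit matrix.

    d(loss)/d(logits) = softmax(logits) - one_hot(targets). Entries whose
    softmax probability falls below `eps` (assumed threshold) contribute
    numerically negligible gradient and can be skipped; a block-sparse kernel
    would simply not compute them.
    """
    rows = torch.arange(len(targets))
    probs = torch.softmax(logits.float(), dim=-1)        # (N, V)
    grad = probs.clone()
    grad[rows, targets] -= 1.0                           # softmax - one_hot

    keep = probs >= eps                                  # non-negligible entries
    keep[rows, targets] = True                           # always keep the target column
    skipped = (~keep).float().mean().item()
    print(f"skippable gradient entries: {100 * skipped:.1f}%")
    return grad * keep                                   # zero out the negligible part

# Example: with a peaked predictive distribution, almost the entire
# vocabulary dimension of the gradient can be skipped.
logits = 5 * torch.randn(16, 50_000)
targets = torch.randint(0, 50_000, (16,))
_ = filtered_softmax_grad(logits, targets)
```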

-----

📊 Results:

→ Reduced the memory footprint of the loss computation from 24 GB to 1 MB for the Gemma 2 (2B) model

→ Decreased the total training-time memory consumption of the classifier head from 28 GB to 1 GB

→ Achieved this memory reduction without sacrificing training speed or convergence

→ Enabled a 1.5x to 10x increase in batch size for various frontier models
