"MemoryFormer: Minimize Transformer Computation by Removing Fully-Connected Layers"

The podcast on this paper is generated with Google's Illuminate.

Beautiful Paper 👏

MemoryFormer swaps expensive matrix multiplications for simple table lookups to make LLMs faster 🔥

Why calculate when you can just look it up?

MemoryFormer tackles the computational bottleneck in LLMs by replacing fully-connected layers with memory-based lookup tables. This novel approach uses locality-sensitive hashing to retrieve vectors from in-memory tables, reducing computational complexity while maintaining model performance.

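To make the lookup idea concrete, here is a minimal sketch of locality-sensitive hashing mapping a chunk of the hidden vector to a table index. It uses generic sign-based (random-hyperplane) LSH; the paper's exact hash function and the chunk size `tau = 8` are illustrative assumptions, not details from the paper.

```python
# Sketch: hash a tau-dim chunk to a bucket index in [0, 2**tau).
# Nearby chunks tend to land in the same bucket, which is what lets
# a table lookup stand in for a matrix multiplication.
import numpy as np

rng = np.random.default_rng(0)
tau = 8                                        # chunk size = number of hash bits (assumed)
hyperplanes = rng.standard_normal((tau, tau))  # one random hyperplane per bit

def bucket_index(chunk: np.ndarray) -> int:
    """Map a tau-dim chunk to an integer bucket index."""
    bits = (chunk @ hyperplanes.T) > 0                        # tau sign bits
    return int(bits.astype(int) @ (1 << np.arange(tau)))      # pack bits into an index

chunk = rng.standard_normal(tau)
print(bucket_index(chunk))                     # some value in [0, 255]
```
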
-----

https://arxiv.org/abs/2411.12992

🤔 Original Problem:

LLMs face massive computational demands, with fully-connected layers consuming most resources. Traditional optimization methods like pruning and quantization provide limited benefits, while existing efficient attention mechanisms only address a small part of the total computation.

-----

🔧 Solution in this Paper:

→ MemoryFormer replaces fully-connected layers with Memory Layers that use hash tables to store discrete vectors

→ Input vectors are split into chunks and processed using locality-sensitive hashing to find similar vectors in memory

→ Retrieved vectors are weighted and combined to approximate traditional matrix multiplication (see the sketch after this list)

→ The system uses CPU memory resources that typically remain unused during neural network inference

→ Memory Layers reduce computational complexity from O(sdh) to O(sd/τ), where s is sequence length, d is hidden size, h is the output dimension of the fully-connected layer, and τ is chunk size

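Below is a minimal sketch of a Memory Layer standing in for y = xW, under simplifying assumptions: sign-bit LSH, plain summation of the retrieved vectors, and made-up sizes (d_in = d_out = 64, tau = 8). The paper's actual weighting scheme and hash function may differ; this only illustrates the chunk → hash → retrieve → combine pipeline.

```python
# Sketch of a Memory Layer: split the input into chunks, hash each chunk
# to a bucket, fetch a learned vector from that chunk's table, and
# accumulate the fetched vectors in place of a dense matmul.
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, tau = 64, 64, 8                   # illustrative sizes
num_chunks = d_in // tau                       # d/tau chunks per token
num_buckets = 2 ** tau                         # table rows per chunk

# One lookup table per chunk: (num_chunks, 2**tau, d_out) learned vectors.
tables = rng.standard_normal((num_chunks, num_buckets, d_out)) * 0.02
hyperplanes = rng.standard_normal((num_chunks, tau, tau))

def memory_layer(x: np.ndarray) -> np.ndarray:
    """Approximate a fully-connected layer with table lookups (sketch)."""
    chunks = x.reshape(num_chunks, tau)
    y = np.zeros(d_out)
    for c in range(num_chunks):
        bits = (chunks[c] @ hyperplanes[c].T) > 0                  # LSH sign bits
        idx = int(bits.astype(int) @ (1 << np.arange(tau)))        # bucket index
        y += tables[c, idx]                                        # retrieve & accumulate
    return y

x = rng.standard_normal(d_in)
print(memory_layer(x).shape)                    # (64,)
```
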
-----

💡 Key Insights:

→ Most computation in transformers comes from fully-connected layers, not attention (see the rough FLOPs comparison after this list)

→ CPU memory resources are underutilized in current deep learning systems

→ Simple hash functions can effectively approximate complex matrix operations

→ Gradient sparsity requires higher learning rates for stable training

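A rough back-of-the-envelope check of the first insight, counting multiply-accumulates for one transformer layer. It assumes a standard layer with a 4x MLP expansion and uses the s = d = 2048 setting quoted in the Results below; this is an illustrative count, not the paper's accounting.

```python
# Multiply-accumulate counts for one transformer layer at s = d = 2048.
s, d = 2048, 2048

qkv_out_proj = 4 * s * d * d        # Q, K, V and output projections
mlp          = 2 * s * d * (4 * d)  # up- and down-projection, 4x expansion (assumed)
attention    = 2 * s * s * d        # QK^T scores and weighting of V

fc_total = qkv_out_proj + mlp
share = fc_total / (fc_total + attention)
print(f"fully-connected share of matmul compute: {share:.0%}")  # ~86%
```
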
-----

📊 Results:

→ Achieves 19% of baseline FLOPs when sequence length is 2048 and hidden size is 2048

→ Outperforms baseline Pythia models on benchmark tasks

→ Better average accuracy compared to other efficient transformer methods like Linformer and Performer
