Beautiful Paper 👏
MemoryFormer swaps expensive matrix multiplications for simple table lookups to make LLMs faster 🔥
Why calculate when you can just look it up?
MemoryFormer tackles the computational bottleneck in LLMs by replacing fully-connected layers with memory-based lookup tables. This novel approach uses locality-sensitive hashing to retrieve vectors from in-memory tables, reducing computational complexity while maintaining model performance.
-----
https://arxiv.org/abs/2411.12992
🤔 Original Problem:
LLMs face massive computational demands, with fully-connected layers consuming most resources. Traditional optimization methods like pruning and quantization provide limited benefits, while existing efficient attention mechanisms only address a small part of the total computation.
-----
🔧 Solution in this Paper:
→ MemoryFormer replaces fully-connected layers with Memory Layers that use hash tables to store discrete vectors
→ Input vectors are split into chunks and processed using locality-sensitive hashing to find similar vectors in memory
→ Retrieved vectors are weighted and combined to approximate the output of a traditional matrix multiplication (a minimal sketch follows this list)
→ The system uses CPU memory resources that typically remain unused during neural network inference
→ Memory Layers reduce computational complexity from O(sdh) to O(sd/τ), where s is sequence length, d is hidden size, and τ is chunk size
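To make the mechanism concrete, here is a minimal PyTorch sketch of the idea: split each token's hidden vector into chunks, hash each chunk to a bucket with random-hyperplane locality-sensitive hashing, look up one stored vector per chunk, and sum the weighted results. The hyperplane initialization, table sizes, and the norm-based weighting are illustrative assumptions, not the paper's exact formulation.

```python
# Minimal sketch of a memory-based layer: hash each input chunk to a bucket,
# look up a stored vector per chunk, and sum the weighted results.
# Hyperplanes, table init, and the weighting scheme are illustrative choices.
import torch
import torch.nn as nn

class MemoryLayerSketch(nn.Module):
    def __init__(self, d_in: int, d_out: int, tau: int = 8):
        super().__init__()
        assert d_in % tau == 0
        self.tau = tau
        self.num_chunks = d_in // tau                      # K = d_in / tau
        # One lookup table per chunk: 2^tau buckets, each holding a d_out vector
        self.tables = nn.Parameter(
            torch.randn(self.num_chunks, 2 ** tau, d_out) * 0.02
        )
        # Random hyperplanes for sign-of-projection LSH (one set per chunk)
        self.register_buffer("hyperplanes", torch.randn(self.num_chunks, tau, tau))
        # Powers of two to turn the tau sign bits into an integer bucket index
        self.register_buffer("bits", 2 ** torch.arange(tau))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, d_in) -> chunks: (batch, seq, K, tau)
        b, s, _ = x.shape
        chunks = x.reshape(b, s, self.num_chunks, self.tau)
        # Project each chunk onto its hyperplanes: (batch, seq, K, tau)
        proj = torch.einsum("bskt,ktu->bsku", chunks, self.hyperplanes)
        # Sign bits -> integer bucket id per chunk: (batch, seq, K)
        idx = ((proj > 0).long() * self.bits).sum(-1)
        # Fetch the stored vector for every chunk's bucket: (batch, seq, K, d_out)
        retrieved = self.tables[torch.arange(self.num_chunks), idx]
        # Weight each retrieved vector by its chunk's norm (illustrative choice)
        weights = chunks.norm(dim=-1, keepdim=True)
        return (weights * retrieved).sum(dim=2)            # (batch, seq, d_out)

# Usage: stands in for a fully-connected layer's matmul
layer = MemoryLayerSketch(d_in=256, d_out=256, tau=8)
out = layer(torch.randn(2, 16, 256))
print(out.shape)   # torch.Size([2, 16, 256])
```

The forward pass replaces the per-token O(d_in · d_out) multiply-adds of a fully-connected layer with K cheap hash computations and K table reads, which is where the FLOP savings come from.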
-----
💡 Key Insights:
→ Most computation in transformers comes from fully-connected layers, not attention
→ CPU memory resources are underutilized in current deep learning systems
→ Simple hash functions can effectively approximate complex matrix operations
→ Because only the table entries that get retrieved receive gradients, training is gradient-sparse and needs higher learning rates for stability (tiny demo below)
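A tiny demo of where that gradient sparsity comes from: in any lookup-based layer, only the table rows actually retrieved in a step get gradients, so most entries see no update. The table size and indices below are made up purely for illustration.

```python
# Only retrieved buckets receive gradients; all other rows stay at zero grad.
import torch

table = torch.randn(16, 4, requires_grad=True)   # 16 buckets, 4-dim vectors
idx = torch.tensor([3, 3, 9])                    # only buckets 3 and 9 retrieved
loss = table[idx].sum()
loss.backward()

touched = (table.grad.abs().sum(dim=1) > 0).nonzero().flatten()
print(touched.tolist())   # [3, 9] -- the other 14 rows get zero gradient
```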
-----
📊 Results:
→ Requires only 19% of the baseline model's FLOPs at sequence length 2048 and hidden size 2048
→ Outperforms baseline Pythia models on benchmark tasks
→ Better average accuracy compared to other efficient transformer methods like Linformer and Performer