New Transformer architecture modifications from NVIDIA researchers.
nGPT: A hypersphere-based Transformer achieving 4-20x faster training and improved stability for LLMs.
📚 https://arxiv.org/abs/2410.01131
Proposals in this Paper 🛠️:
• Normalized Transformer (nGPT) architecture
• All vectors normalized to unit norm on hypersphere
• Learnable eigen learning rates control hidden state updates (see the sketch after this list)
• Removal of LayerNorm/RMSNorm layers
• Introduction of scaling factors for logits, query/key vectors, and MLP states
• Elimination of weight decay and learning rate warmup
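For readers who want the mechanics, here is a minimal PyTorch sketch of the hyperspherical update with eigen learning rates. The module names, the 0.05 initialization, and the plain linear stand-ins for the attention/MLP sublayers are illustrative assumptions, not the paper's exact implementation.

```python
# Minimal sketch of an nGPT-style block: hidden states live on the unit
# hypersphere and each sublayer nudges them along it by a learnable amount.
import torch
import torch.nn as nn
import torch.nn.functional as F

def norm(x: torch.Tensor) -> torch.Tensor:
    # Project vectors back onto the unit hypersphere
    # (this re-projection replaces LayerNorm/RMSNorm).
    return F.normalize(x, dim=-1)

class NGPTBlockSketch(nn.Module):
    def __init__(self, d_model: int = 512):
        super().__init__()
        # Stand-ins for the real attention and MLP sublayers (assumption).
        self.attn = nn.Linear(d_model, d_model, bias=False)
        self.mlp = nn.Linear(d_model, d_model, bias=False)
        # Learnable per-dimension "eigen learning rates": how far each
        # update moves the hidden state toward the sublayer's proposal.
        # The 0.05 init is an arbitrary choice for this sketch.
        self.alpha_attn = nn.Parameter(torch.full((d_model,), 0.05))
        self.alpha_mlp = nn.Parameter(torch.full((d_model,), 0.05))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h is assumed to already lie on the unit hypersphere.
        h_attn = norm(self.attn(h))                   # normalized sublayer output
        h = norm(h + self.alpha_attn * (h_attn - h))  # step toward it, re-project
        h_mlp = norm(self.mlp(h))
        h = norm(h + self.alpha_mlp * (h_mlp - h))
        return h

# Usage: hidden states start on the sphere and stay there.
x = norm(torch.randn(2, 16, 512))
out = NGPTBlockSketch()(x)
print(out.norm(dim=-1))  # ~1.0 everywhere
```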
-----
Key Insights from this Paper 💡:
• nGPT learns 4-20x faster than standard Transformers
• Hyperspherical representation improves stability and embedding separability
• Transformer layers act as optimization steps on a hypersphere (written out after this list)
• Eigen learning rates control the impact of each block's updates
• nGPT handles longer contexts without modifying positional encodings
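The "optimization steps on a hypersphere" view can be written out compactly. The notation below is my paraphrase of the paper's update (symbols match the sketch above), not a quotation:

h ← Norm(h + α_A ⊙ (Norm(ATTN(h)) − h))
h ← Norm(h + α_M ⊙ (Norm(MLP(h)) − h))

Each sublayer proposes a target point on the sphere, and the eigen learning rates α set the per-dimension step size toward it, so a stack of layers behaves like an iterative optimizer constrained to the unit sphere.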
-----
Results 📊:
• 4x faster training (to the same validation loss) at 1k context length
• 10x faster at 4k context length
• 20x faster at 8k context length
• Similar or better performance on downstream tasks with less training
• More stable performance when extrapolating to longer sequences