Preventing floating-point errors unlocks natural generalization in neural networks
This paper reveals how numerical stability issues prevent neural networks from achieving grokking (sudden generalization after overfitting) and proposes solutions to enable grokking without regularization.
https://arxiv.org/abs/2501.04697
🤔 Original Problem:
→ Neural networks often require regularization to achieve grokking, but it's unclear why generalization is delayed and why grokking rarely happens without regularization.
⚡ Key Insights:
→ Without regularization, models hit numerical stability limits due to Softmax Collapse (SC): floating-point errors in the Softmax that zero out gradients
→ Once the training set is fit, gradients align with the Naïve Loss Minimization (NLM) direction, which scales up the logits without changing the model's predictions
→ This scaling eventually drives the Softmax into floating-point underflow, zeroing gradients and halting learning (see the sketch after this list)
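To make the mechanism concrete, here is a minimal, self-contained sketch (not code from the paper) of how scaling already-correct logits drives a float32 Softmax gradient to exactly zero; the logit values and scales are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# Illustrative logits for one example whose true class is 0.
# The direction [1, -1, -1, -1] already classifies correctly; scaling it up
# only increases confidence (the NLM effect described above).
base_logits = torch.tensor([1.0, -1.0, -1.0, -1.0])
target = torch.tensor([0])

for scale in [1.0, 10.0, 100.0]:
    logits = (scale * base_logits).float().requires_grad_(True)  # float32
    loss = F.cross_entropy(logits.unsqueeze(0), target)
    loss.backward()
    # Once the off-class terms of the float32 Softmax underflow to 0,
    # the gradient (softmax(z) - one_hot) becomes exactly zero and
    # learning stops: the Softmax Collapse scenario.
    print(f"scale={scale:>5}: loss={loss.item():.3e}, "
          f"grad_norm={logits.grad.norm().item():.3e}")
```

At scale 1 the gradient is healthy, at scale 10 it is tiny but nonzero, and at scale 100 the off-class probabilities underflow and the gradient norm prints as exactly zero.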
🛠️ Solution in this Paper:
→ StableMax: A numerically stable Softmax alternative that prevents Softmax Collapse by replacing the exponential with a gentler, more slowly growing transformation
→ ⊥Grad: A training algorithm that removes the gradient component aligned with the current weights, blocking the NLM direction
→ Together, these fixes enable quick generalization without requiring regularization (a minimal sketch of both follows this list)
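Below is a minimal sketch of both ideas. The piecewise s(x) used in place of exp follows our reading of the paper's StableMax definition; the per-tensor projection and the plain SGD step in perp_grad_step are illustrative assumptions, not the authors' exact training loop.

```python
import torch

def stablemax(logits, dim=-1):
    """StableMax-style normalization: replace exp(x) in Softmax with s(x),
    which grows only linearly for x >= 0 and decays like 1/(1 - x) for
    x < 0, so large logits cannot overflow or underflow."""
    s = torch.where(logits >= 0, logits + 1.0, 1.0 / (1.0 - logits))
    return s / s.sum(dim=dim, keepdim=True)

def stablemax_cross_entropy(logits, target):
    """Cross-entropy computed with stablemax instead of softmax."""
    probs = stablemax(logits)
    return -torch.log(probs.gather(1, target.unsqueeze(1))).mean()

@torch.no_grad()
def perp_grad_step(params, lr=1e-2):
    """One illustrative ⊥Grad-style update: remove from each parameter's
    gradient the component parallel to the parameter itself (the naive
    loss-minimization direction), then take a plain SGD step."""
    for p in params:
        if p.grad is None:
            continue
        w, g = p.view(-1), p.grad.view(-1)
        g_perp = g - (g @ w) / (w @ w + 1e-12) * w
        p.add_(g_perp.view_as(p), alpha=-lr)

# Tiny usage example on random data (shapes are arbitrary assumptions).
model = torch.nn.Linear(16, 7)
x, y = torch.randn(8, 16), torch.randint(0, 7, (8,))
loss = stablemax_cross_entropy(model(x), y)
model.zero_grad()
loss.backward()
perp_grad_step(model.parameters())
```

The projection keeps only the part of the gradient orthogonal to the current weights, so updates cannot simply rescale the logits the way the NLM direction does.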
📊 Results:
→ StableMax enables grokking without regularization on modular arithmetic tasks
→ ⊥Grad achieves faster generalization by avoiding the initial overfitting phase
→ Works on both MLPs and Transformers across different training-data fractions (40%, 60%, 70% splits)