"Exploring Grokking: Experimental and Mechanistic Investigations"
The podcast below on this paper was generated with Google's Illuminate.
https://arxiv.org/abs/2412.10898
The paper addresses the perplexing phenomenon of "grokking" in neural networks: a network first memorizes its training data and then, after extended training, suddenly generalizes to unseen data.
This paper experimentally investigates grokking, exploring how training data size, model architecture, and optimization method influence this delayed generalization.
-----
📌 The Transformer architecture exhibits grokking in this study, while the MLP and LSTM do not. This points to an architectural bias toward generalization on specific tasks such as modular arithmetic.
📌 Weight decay with AdamW is crucial for observing grokking. Regularization significantly shapes generalization behavior and should be accounted for when studying grokking.
📌 A training data fraction of around 50% is key for pronounced grokking. Dataset size is not merely a matter of task difficulty; it also enables the specific generalization dynamics behind grokking.
----------
Methods Explored in this Paper 🔧:
→ The paper uses modular addition as the task for studying grokking, with two encoding methods for the input data.
→ Experiments vary the fraction of training data used, from small to large portions of the total dataset. This helps observe how data size affects generalization.
→ Three neural network models are tested: Transformer, LSTM, and MLP. This comparison examines if model architecture plays a role in grokking.
→ Different optimization algorithms are applied, including Adam and AdamW. The impact of weight decay and regularization is investigated.
→ The Transformer is a decoder-only architecture with 2 layers and a width of 128; a simplified Transformer with a single MLP layer is also used. The LSTM has a hidden size of 20, and the MLP has 2 layers with a hidden size of 512 (a minimal setup sketch follows this list).
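The sketch below is a minimal, illustrative reconstruction of this setup in PyTorch, not the authors' code. The modulus of 97, the "simplified" two-token encoding [a, b], 4 attention heads, batch size 512, and the weight-decay value of 1.0 are assumptions not stated above; the 2-layer, width-128 decoder-only Transformer and the AdamW optimizer do come from the paper's description.

```python
# Minimal sketch of the modular-addition grokking setup.
# Assumptions (not from the paper): modulus p = 97, inputs encoded as the token
# pair [a, b], 4 attention heads, weight_decay = 1.0, batch size 512.
import torch
import torch.nn as nn

p = 97
pairs = torch.cartesian_prod(torch.arange(p), torch.arange(p))  # all (a, b) pairs
labels = (pairs[:, 0] + pairs[:, 1]) % p                        # target: (a + b) mod p

frac = 0.5                                   # training-data fraction, the key knob studied
perm = torch.randperm(len(pairs))
n_train = int(frac * len(pairs))
train_idx, val_idx = perm[:n_train], perm[n_train:]

class TinyDecoder(nn.Module):
    """Decoder-only Transformer with 2 layers and width 128, as reported in the paper."""
    def __init__(self, vocab=p, d_model=128, n_layers=2, n_heads=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, d_model)
        self.pos = nn.Parameter(torch.zeros(1, 2, d_model))      # two input positions: a, b
        layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=512,
                                           batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, n_layers)
        self.head = nn.Linear(d_model, vocab)

    def forward(self, x):                                        # x: (batch, 2) integer tokens
        h = self.embed(x) + self.pos
        causal = nn.Transformer.generate_square_subsequent_mask(x.size(1))
        h = self.blocks(h, mask=causal)                          # causal mask -> decoder-style attention
        return self.head(h[:, -1])                               # predict the sum from the last position

model = TinyDecoder()
# AdamW with decoupled weight decay -- the regularization the paper finds crucial.
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
loss_fn = nn.CrossEntropyLoss()

for step in range(10_000):                   # extended training: grokking shows up late
    batch = train_idx[torch.randint(len(train_idx), (512,))]
    loss = loss_fn(model(pairs[batch]), labels[batch])
    opt.zero_grad(); loss.backward(); opt.step()
```

Sweeping `frac` (e.g., 0.3 to 0.6) and tracking validation accuracy over steps is what surfaces the delayed generalization described in the results below.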
-----
Key Insights 💡:
→ Grokking is most prominent when roughly 50% of the available data is used for training. Smaller training fractions can make the task too difficult to generalize from.
→ Transformer models exhibit grokking, while the MLP and LSTM models in these experiments show more conventional generalization patterns, suggesting that architecture is a factor.
→ Weight decay, applied via the AdamW optimizer, improves generalization. Regularization techniques appear to influence grokking (a minimal optimizer comparison follows this list).
→ The paper discusses structured representation and implicit biases as potential underlying mechanisms for grokking.
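A minimal sketch of the optimizer comparison this insight points to, continuing the setup above; the learning rate and weight-decay value are illustrative assumptions, not the paper's settings.

```python
import torch

# Same model and learning rate; the only change is whether decoupled weight decay is applied.
adam_plain = torch.optim.Adam(model.parameters(), lr=1e-3)         # no weight decay
adamw_wd   = torch.optim.AdamW(model.parameters(), lr=1e-3,
                               weight_decay=1.0)                   # decoupled weight decay
# Per the paper's experiments (Figure 8), the AdamW-with-weight-decay run generalizes better.
```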
-----
Results 📊:
→ With 45% and 60% training data fractions, Transformer models showed a clear gap between training and validation accuracy, with generalization occurring around epoch 6000.
→ At 30% training data, using the simplified encoding and the Transformer, validation accuracy surged to 100% at around 9.1k steps, long after training accuracy had reached 100% by step 200 (a minimal logging sketch follows this list).
→ MLP models did not reach 100% training accuracy even after 20,000 steps, indicating Transformers perform better on this task.
→ Experiments with different optimizers (Figure 8) show that AdamW with weight decay achieves better generalization than Adam and its noise-injected variants.
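The kind of curve these results describe can be reproduced by logging both accuracies during training. A brief sketch continuing the code above; the evaluation helper and logging names are hypothetical, not from the paper.

```python
@torch.no_grad()
def accuracy(idx):
    # Exact-match accuracy on the indexed subset of the (a, b) pairs.
    preds = model(pairs[idx]).argmax(dim=-1)
    return (preds == labels[idx]).float().mean().item()

history = []   # (step, train_acc, val_acc) triples for plotting the grokking gap
def log_step(step):
    history.append((step, accuracy(train_idx), accuracy(val_idx)))
    # Grokking signature: train accuracy reaches 1.0 early (e.g., ~200 steps),
    # while validation accuracy only jumps to 1.0 thousands of steps later.
```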