Mathematical proof shows why bigger transformers get smarter: it's all about the data's hidden dimensions
Hidden geometry controlling transformer model performance
https://arxiv.org/abs/2411.06646
Original Problem 🤔:
The power-law scaling behavior of transformer models lacks a rigorous mathematical explanation, especially for the common case where input data lies on a low-dimensional manifold.
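For reference, here is the object being explained, written schematically. This is the generic scaling-law ansatz, not the paper's exact theorem; the paper's contribution is pinning down how the exponent depends on the data.

```latex
% Generic power-law scaling ansatz (schematic):
% test loss L decays polynomially in the number of training samples n,
% with an exponent alpha that this paper ties to the intrinsic dimension d
% of the data manifold rather than the ambient embedding dimension.
\mathcal{L}(n) \;\approx\; c\, n^{-\alpha(d)} + \mathcal{L}_{\infty}
```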
-----
Solution in this Paper 🧠:
→ Developed a mathematical framework showing that transformers can universally approximate Hölder continuous functions on d-dimensional manifolds using only O(log(d)) depth
→ Established a statistical estimation theory proving generalization error bounds that depend exponentially on the intrinsic data dimension d
→ Created a novel method for computing the covering number of transformer network classes (see the schematic bound after this list)
→ Validated the theoretical predictions experimentally on three datasets: OpenWebText, the SQL portion of The Stack, and TinyStories
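To connect the second and third bullets: a covering-number computation typically enters a generalization bound through uniform concentration, and balancing that statistical term against the manifold approximation error is what produces a rate governed by d. Schematically (this is the standard template such proofs follow, not the paper's exact statement):

```latex
% Standard template: covering numbers -> generalization bound.
% N(F, delta): delta-covering number of the transformer class F,
% n: number of i.i.d. training samples. Up to constants and log factors,
\text{generalization gap} \;\lesssim\; \sqrt{\frac{\log N(\mathcal{F}, \delta)}{n}} + \delta
% Trading this statistical term off against the approximation error on a
% d-dimensional manifold yields an overall excess risk of order n^{-\alpha(d)},
% i.e. a rate that depends on the intrinsic dimension d, not the ambient one.
```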
-----
Key Insights 🎯:
→ The intrinsic data dimension d is crucial in determining model performance and scaling behavior
→ Transformers need only O(log(d)) layers for function approximation, independent of the target accuracy
→ Scaling laws accurately reflect dataset complexity: simpler datasets show faster convergence
→ The estimated intrinsic dimension d remains stable across model architectures and sizes (one common way to estimate it is sketched below)
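The last insight presumes a way to measure d from raw embeddings. The snippet below is a minimal sketch using the Levina-Bickel k-nearest-neighbor MLE, a common estimator chosen here for illustration; it is an assumption that this matches the paper's exact procedure.

```python
# Sketch: estimate intrinsic dimension d of embedding vectors with the
# Levina-Bickel k-NN maximum-likelihood estimator (illustrative choice;
# not necessarily the estimator used in the paper).
import numpy as np

def intrinsic_dimension_mle(X: np.ndarray, k: int = 10) -> float:
    """Average Levina-Bickel MLE of intrinsic dimension over all points."""
    sq = (X ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T   # pairwise squared distances
    np.fill_diagonal(d2, np.inf)                     # exclude self-distances
    knn = np.sqrt(np.clip(np.sort(d2, axis=1)[:, :k], 0.0, None))
    logs = np.log(knn[:, -1:] / knn[:, :-1])         # log(T_k / T_j) for j < k
    return float(((k - 1) / logs.sum(axis=1)).mean())

# Toy check: a 2-D plane embedded in 50-D ambient space should give d close to 2.
rng = np.random.default_rng(0)
Z = rng.normal(size=(2000, 2)) @ rng.normal(size=(2, 50))
print(intrinsic_dimension_mle(Z))
```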
-----
Results 📊:
→ Close agreement (within ±0.02) between predicted and observed data scaling exponents, as illustrated in the sketch after this list
→ Achieved O(log(d)) depth efficiency, outperforming traditional ReLU networks, whose depth must grow with the target accuracy
→ Successfully validated on three datasets of varying complexity, with simpler datasets converging faster
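To make the ±0.02 claim concrete: the data-scaling exponent is read off as the slope of a log-log fit of (loss minus irreducible loss) against dataset size, then compared with the exponent the theory predicts from the estimated d. A minimal sketch with synthetic placeholder numbers (none of these values are the paper's measurements):

```python
# Sketch: fit an empirical data-scaling exponent and compare it with a
# theoretical prediction. All numbers below are synthetic placeholders.
import numpy as np

n = np.array([1e6, 3e6, 1e7, 3e7, 1e8])   # training set sizes (hypothetical)
loss = 5.0 * n ** -0.30 + 1.2             # synthetic losses, true alpha = 0.30
L_inf = 1.2                               # assumed irreducible loss

slope, _ = np.polyfit(np.log(n), np.log(loss - L_inf), 1)
alpha_observed = -slope                   # slope of the log-log fit
alpha_predicted = 0.32                    # hypothetical alpha(d) from the theory

print(f"observed alpha = {alpha_observed:.2f}, "
      f"predicted alpha = {alpha_predicted:.2f}, "
      f"gap = {abs(alpha_observed - alpha_predicted):.2f}")
```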