Self-attention is actually doing kernel PCA under the hood - now we can make it robust
So how does self-attention actually work? It's kernel PCA in disguise, as shown in this paper.
📚 https://arxiv.org/abs/2406.13762
🤔 Original Problem:
Self-attention in transformers has been developed largely through heuristics and trial and error, without a principled theoretical framework for understanding its inner workings or improving its robustness to data corruption.
-----
🔧 Solution in this Paper:
→ Derives self-attention from kernel Principal Component Analysis (kernel PCA), showing that attention projects query vectors onto the principal component axes of the key matrix in feature space
→ Introduces RPC-Attention (Attention with Robust Principal Components), which uses Principal Component Pursuit to handle corrupted data (see the sketch after this list)
→ Implements the PAP (Principal Attention Pursuit) algorithm, which iteratively recovers clean data from corrupted inputs
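For intuition, here is a minimal sketch of generic Principal Component Pursuit (robust PCA) solved with an ADMM-style loop: singular value thresholding recovers a low-rank part, entrywise soft thresholding isolates sparse corruption. This is the standard robust-PCA recipe that RPC-Attention builds on, not the paper's exact PAP iterations; the function names and hyperparameter defaults below are illustrative assumptions.

```python
import numpy as np

def svd_shrink(M, tau):
    # Singular value thresholding: shrink singular values toward zero by tau.
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ np.diag(np.maximum(s - tau, 0.0)) @ Vt

def soft_shrink(M, tau):
    # Entrywise soft thresholding: promotes sparsity in the corruption term.
    return np.sign(M) * np.maximum(np.abs(M) - tau, 0.0)

def principal_component_pursuit(X, lam=None, mu=None, n_iters=100):
    """Split X into a low-rank part L and a sparse part S (X ≈ L + S) via ADMM."""
    m, n = X.shape
    lam = lam if lam is not None else 1.0 / np.sqrt(max(m, n))
    mu = mu if mu is not None else (m * n) / (4.0 * np.abs(X).sum())
    S = np.zeros_like(X)
    Y = np.zeros_like(X)   # dual variable
    for _ in range(n_iters):
        L = svd_shrink(X - S + Y / mu, 1.0 / mu)    # low-rank update
        S = soft_shrink(X - L + Y / mu, lam / mu)   # sparse update
        Y = Y + mu * (X - L - S)                    # dual ascent step
    return L, S

# Toy usage: corrupt a rank-2 matrix with sparse noise and try to recover it.
rng = np.random.default_rng(0)
clean = rng.normal(size=(50, 2)) @ rng.normal(size=(2, 40))
noise = (rng.random((50, 40)) < 0.05) * rng.normal(scale=5.0, size=(50, 40))
L, S = principal_component_pursuit(clean + noise)
print(np.linalg.norm(L - clean) / np.linalg.norm(clean))
```

On the toy example, the recovered low-rank part L should sit close to the clean matrix even though about 5% of entries were heavily corrupted, which is the behavior RPC-Attention aims to get inside the attention layer.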
-----
💡 Key Insights:
→ Self-attention mathematically performs kernel PCA in feature space (see the sketch after this list)
→ Value matrix captures eigenvectors of Gram matrix of key vectors
→ Number of principal components used must be ≤ number of data points
→ Different value matrix parameterizations lead to different attention architectures
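To see the kernel PCA reading concretely: the eigenvectors of the key Gram matrix define principal axes in feature space, and projecting a query onto those axes is a weighted sum over keys, structurally the same computation as attending over keys. The toy sketch below uses an exponential (softmax-style) kernel and skips the centering and eigenvalue normalization of full kernel PCA; it is an illustration of the idea, not the paper's derivation.

```python
import numpy as np

def expo_kernel(A, b, scale):
    # Exponential (softmax-style) similarity: k(a, b) = exp(a·b / scale).
    return np.exp(A @ b / scale)

rng = np.random.default_rng(0)
N, D = 8, 4                          # N key vectors of dimension D
K = rng.normal(size=(N, D))          # keys
q = rng.normal(size=D)               # a single query
scale = np.sqrt(D)

# Gram matrix of the keys under the kernel (centering omitted for brevity).
G = expo_kernel(K, K.T, scale)

# Kernel PCA: principal axes in feature space come from eigenvectors of G.
# At most N components have nonzero eigenvalues (#components <= #data points).
eigvals, A = np.linalg.eigh(G)       # columns of A = eigenvector coefficients
A = A[:, ::-1]                       # order axes by decreasing eigenvalue

# Projection of the query onto axis i is sum_j A[j, i] * k(q, k_j),
# i.e. a weighted sum over keys -- the same "attend over keys" structure as
# self-attention, with the value matrix playing the role of these
# coefficients in the paper's reading.
k_q = expo_kernel(K, q, scale)       # similarity of the query to every key
proj = A.T @ k_q                     # query coordinates along principal axes

# Standard softmax attention over the same keys, for structural comparison.
V = rng.normal(size=(N, D))          # values
attn = (k_q / k_q.sum()) @ V
print(proj.shape, attn.shape)        # (N,) and (D,)
```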
-----
📊 Results:
→ RPC-Attention outperforms the baseline on ImageNet-1K by 1% in accuracy
→ Shows a 3% improvement in AUPR on ImageNet-O
→ Better robustness against PGD, FGSM, and SPSA adversarial attacks
→ 1 PPL improvement on WikiText-103 language modeling