A Common Pitfall of Margin-based Language Model Alignment: Gradient Entanglement
When teaching an AI right from wrong, both answers can grow stronger together - this paper tries to answer why.
When fine-tuning LLMs, gradients of good and bad responses get tangled, causing unexpected probability shifts.
Training LLMs to prefer good responses accidentally boosts bad ones due to shared token gradients.
Original Problem 🔍:
Margin-based language model alignment methods under-specify the ideal behavior on chosen and rejected responses individually, often causing the probabilities of both to increase or decrease in sync.
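For reference, here is a minimal sketch of a DPO-style margin loss in PyTorch (variable names are illustrative, not from the paper). The thing to notice: the loss constrains only the margin between the two responses, so it says nothing about where each individual log-probability should move.

```python
import torch
import torch.nn.functional as F

def dpo_margin_loss(logp_chosen, logp_rejected,
                    ref_logp_chosen, ref_logp_rejected, beta=0.1):
    """DPO-style loss: only the margin between chosen and rejected
    log-probabilities (relative to a reference model) is optimized,
    so each individual probability is under-specified."""
    margin = (logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected)
    return -F.logsigmoid(beta * margin).mean()
```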
Solution in this Paper 🛠️:
• Identifies "gradient entanglement" as the root cause
• Derives gradient inner product conditions for various margin-based algorithms
• Theoretically analyzes when gradient inner product becomes large
• Proposes two potential solutions:
  - Pairwise normalized gradient descent
  - Sparsity regularized token masking
Key Insights from this Paper 💡:
• Gradient entanglement couples changes in chosen and rejected probabilities
• Token-level gradient dynamics crucial for understanding entanglement
• Length-normalization and explicit regularization affect training dynamics (see the sketch after this list)
• Current margin-based paradigm may not suit all alignment scenarios
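The length-normalization point is easy to see in code. A minimal sketch (our own, in the style of length-normalized objectives; names are illustrative):

```python
def length_normalized_logp(token_logps, mask):
    """Length-normalized sequence log-probability: dividing the summed
    token log-probs by response length reweights per-token gradients,
    which changes how strongly chosen/rejected gradients entangle."""
    return (token_logps * mask).sum(-1) / mask.sum(-1)
```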
Results 📊:
• Empirically validates gradient entanglement across various algorithms
• Demonstrates synchronized probability changes in TL;DR dataset experiments
• Shows gradient cosine similarity patterns align with theoretical predictions
• Verifies token-level gradient correlations in sentiment analysis tasks
📌 The paper identifies a common pitfall in margin-based language model alignment methods: they under-specify the ideal behavior of the language model on chosen and rejected responses individually.
This often results in synchronized increases or decreases in the probabilities of both chosen and rejected responses, rather than increasing the probability of chosen responses while decreasing the probability of rejected ones.
📌 The underlying cause is an effect they term "gradient entanglement".
Margin-based losses couple the change in the chosen probability to the gradient of the rejected one, and vice versa, often preventing the chosen and rejected probabilities from changing independently. This coupling is mediated by the inner product between the gradients of the chosen and rejected log-probabilities.
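To make the mechanism concrete, here is a first-order sketch (our notation, not verbatim from the paper). Write $g_w = \nabla_\theta \log \pi_\theta(y_w)$ for the chosen-response gradient and $g_l = \nabla_\theta \log \pi_\theta(y_l)$ for the rejected one, and consider a generic margin loss $-f(\log \pi_\theta(y_w) - \log \pi_\theta(y_l))$ with $f$ increasing. A gradient step of size $\eta$ moves the parameters along $\eta f' (g_w - g_l)$, so to first order:

$$\Delta \log \pi_\theta(y_w) \approx \eta f' \left( \|g_w\|^2 - \langle g_w, g_l \rangle \right), \qquad \Delta \log \pi_\theta(y_l) \approx \eta f' \left( \langle g_w, g_l \rangle - \|g_l\|^2 \right)$$

Both changes share the inner-product term $\langle g_w, g_l \rangle$: if it exceeds $\|g_w\|^2$ the chosen probability falls, and if it exceeds $\|g_l\|^2$ the rejected probability rises.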
📌 Gradient entanglement becomes concerning when the inner product between the gradient of the chosen log-probability and the gradient of the rejected log-probability is large relative to their individual gradient norms.
The authors derive specific conditions for different margin-based algorithms that characterize when this occurs.
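A sketch of how one might probe this condition on a single example (the helper name and setup are ours, not the paper's):

```python
import torch

def entanglement_stats(model, logp_chosen, logp_rejected):
    """Compare <g_w, g_l> against ||g_w||^2 and ||g_l||^2.
    If the inner product exceeds either squared norm, a margin-loss
    gradient step moves both log-probabilities in the same direction."""
    params = [p for p in model.parameters() if p.requires_grad]
    g_w = torch.autograd.grad(logp_chosen, params, retain_graph=True)
    g_l = torch.autograd.grad(logp_rejected, params)
    g_w = torch.cat([g.flatten() for g in g_w])
    g_l = torch.cat([g.flatten() for g in g_l])
    inner = torch.dot(g_w, g_l)
    return {
        "inner": inner.item(),
        "norm_w_sq": g_w.dot(g_w).item(),
        "norm_l_sq": g_l.dot(g_l).item(),
        "cosine": (inner / (g_w.norm() * g_l.norm() + 1e-12)).item(),
    }
```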
The authors suggest two potential algorithm designs to mitigate gradient entanglement:
Pairwise normalized gradient descent - reweighting the gradients of the chosen and rejected log-probabilities in the margin-based loss so that both parts of the gradient condition are satisfied simultaneously.
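One way to realize this idea (a sketch under our own assumptions; the paper's exact reweighting may differ) is to normalize each gradient to unit length before stepping. To first order, this raises the chosen log-probability and lowers the rejected one whenever the two gradients are not perfectly aligned:

```python
import torch

def pairwise_normalized_step(model, logp_chosen, logp_rejected, lr=1e-4):
    """Ascend the chosen log-prob and descend the rejected one, with each
    gradient normalized to unit norm. First-order effect: the chosen
    probability rises and the rejected probability falls unless the two
    gradients have cosine similarity exactly 1."""
    params = [p for p in model.parameters() if p.requires_grad]
    g_w = torch.autograd.grad(logp_chosen, params, retain_graph=True)
    g_l = torch.autograd.grad(logp_rejected, params)
    norm_w = torch.sqrt(sum((g * g).sum() for g in g_w)) + 1e-12
    norm_l = torch.sqrt(sum((g * g).sum() for g in g_l)) + 1e-12
    with torch.no_grad():
        for p, gw, gl in zip(params, g_w, g_l):
            p += lr * (gw / norm_w - gl / norm_l)  # ascent on normalized margin
```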
Sparsity regularized token masking - a fine-grained margin-based loss that contrasts only the significant tokens, leveraging token-level information to reduce entanglement.
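A hedged illustration of the masking idea (the top-fraction criterion and names below are our stand-ins, not the paper's exact rule; it also assumes both responses are padded to a common length so token positions align):

```python
import torch
import torch.nn.functional as F

def masked_margin_loss(tok_logp_chosen, tok_logp_rejected, beta=0.1, top_frac=0.2):
    """Token-masked margin loss: contrast only the tokens with the
    largest per-token log-prob gaps instead of summing the margin over
    every token, so near-identical tokens stop feeding entanglement."""
    gap = tok_logp_chosen - tok_logp_rejected          # (batch, seq_len)
    k = max(1, int(top_frac * gap.shape[-1]))
    top_gap, _ = torch.topk(gap.abs(), k, dim=-1)      # k largest |gaps|
    threshold = top_gap[..., -1:]                      # k-th largest value
    mask = (gap.abs() >= threshold).float()            # sparse token mask
    margin = (gap * mask).sum(-1)
    return -F.logsigmoid(beta * margin).mean()
```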