LiNeS: POST-TRAINING LAYER SCALING PREVENTS FORGETTING AND ENHANCES MODEL MERGING
Fine-tuning often leads to catastrophic forgetting, where the model loses its ability to generalize across other tasks.
This paper proposes a simpler fix: scale each layer's parameter updates according to its depth, preventing forgetting while preserving target-task performance.
Original Problem 🎯:
Current solutions to catastrophic forgetting require complex training modifications or computational overhead.
Solution in this Paper 🛠️:
• LiNeS (Layer-increasing Network Scaling) applies linear scaling to parameter updates based on layer depth
• Shallow layers get minimal updates to preserve general features
• Deep layers retain full updates for task-specific learning
• Scaling factor: λ(ℓ) = α + β(ℓ-1)/(L-1) for layer ℓ of L total layers (see the sketch after this list)
• Works as a post-training technique; no additional training needed
• Requires only 1-2 hyperparameters (α, β)
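To make the schedule concrete, here is a minimal sketch of the per-layer scaling factors (not the authors' reference code; `alpha`, `beta`, and `num_layers` are illustrative values):

```python
import numpy as np

def lines_schedule(num_layers: int, alpha: float, beta: float) -> np.ndarray:
    """Per-layer scaling factors: lambda(l) = alpha + beta * (l - 1) / (L - 1)."""
    layers = np.arange(1, num_layers + 1)  # l = 1, ..., L
    return alpha + beta * (layers - 1) / (num_layers - 1)

# Example: a 12-block network with alpha=0.1, beta=0.9.
# The first block's update is scaled by 0.1, the last by 1.0.
print(lines_schedule(num_layers=12, alpha=0.1, beta=0.9))
```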
Key Insights 🔍:
• Shallow layer updates minimally impact target task performance
• Preserving shallow layer weights maintains generalization ability
• Simple linear scaling achieves results comparable to more complex optimization-based methods
• Method works across vision, NLP, and LLM domains
Results 📊:
• Maintains 99.8% of target-task performance while preserving 97.9% of pre-trained performance
• Improves multi-task merging baselines by 3.1-4.0%
• Enhances OOD generalization across 5 different distribution shifts
• Works with models from ViT-B/32 to LLaMA-7B scale
💡 The LiNeS method exploits the fact that shallow layers in neural networks capture general features while deeper layers contain task-specific representations.
By scaling down parameter updates in shallow layers post-training while preserving updates in deeper layers, LiNeS maintains broad generalization capabilities while retaining task-specific performance gains.
🔧 LiNeS layer scaling approach
Given a task vector τ (difference between fine-tuned and pre-trained weights) with L layers, LiNeS applies a linear scaling factor λ(ℓ) = α + β(ℓ-1)/(L-1) to each layer ℓ.
This progressively scales parameter updates from factor α for the first layer to α+β for the last layer, with intermediate layers scaled linearly based on depth.
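A minimal sketch of applying this post-training edit, assuming weights are stored as name-to-tensor dicts and a hypothetical helper `layer_index_of(name)` that maps each parameter to its 1-based block index:

```python
import torch

def apply_lines(pretrained: dict, finetuned: dict, layer_index_of,
                num_layers: int, alpha: float = 0.1, beta: float = 0.9) -> dict:
    """Scale the task vector tau = theta_finetuned - theta_pretrained
    layer-by-layer, then add it back onto the pre-trained weights."""
    merged = {}
    for name, w_pre in pretrained.items():
        tau = finetuned[name] - w_pre                    # per-tensor task vector
        l = layer_index_of(name)                         # hypothetical: block index of this tensor
        lam = alpha + beta * (l - 1) / (num_layers - 1)  # lambda(l) from the formula above
        merged[name] = w_pre + lam * tau                 # scaled update
    return merged
```

Because the scaling is purely post hoc, the same routine can be applied to a merged multi-task vector instead of a single task vector, which is how LiNeS enhances model-merging baselines.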