"EfficientLLM: Scalable Pruning-Aware Pretraining for Architecture-Agnostic Edge Language Models"
The podcast below on this paper was generated with Google's Illuminate.
https://arxiv.org/abs/2502.06663
Rising cloud costs, latency, and privacy concerns make efficient edge LLMs necessary, but directly pretraining smaller models struggles to match the performance of larger ones. This paper addresses the gap by introducing pruning-aware pretraining.
The method scales LLM compression techniques up into the pretraining phase, retaining the performance of larger models while producing efficient smaller ones.
-----
📌 Pruning-aware pretraining effectively shifts compute from data scaling to model optimization. It preserves the knowledge of a larger model within a smaller parameter budget, surpassing what standard scaling laws predict for edge LLMs.
📌 This method introduces an adaptive architecture search during pretraining. Saliency-driven pruning dynamically shapes the network, finding efficient configurations beyond fixed human-designed architectures.
📌 EfficientLLM demonstrates that scaling pruning during pretraining, not just post-training, is crucial. This unlocks higher compression ratios without significant performance drops for edge deployment.
----------
Methods Explored in this Paper 🔧:
→ Introduces "pruning-aware pretraining". This method integrates structural pruning into the LLM pretraining process itself.
→ It starts from a larger, pretrained LLM and progressively prunes less important parameters throughout pretraining, unlike post-training pruning, which is applied only after pretraining is complete.
→ The core idea is to identify and remove "minimal parameter groups" during pretraining. These groups are sets of interconnected parameters within the Transformer architecture. Three types of pruning are explored: per-head attention pruning, per-channel FFN pruning, and transformer stem pruning (see the sketches after this list).
→ Saliency detection guides the pruning process. A Taylor expansion approximates the change in loss caused by removing each parameter group, and the groups with the lowest saliency are pruned first (a saliency sketch follows this list).
→ The method is "architecture-agnostic". The model's architecture is not fixed beforehand. Instead, pruning dynamically shapes the architecture during pretraining based on saliency.
→ Second-order weight updating is employed. After each pruning step, this technique refines the remaining weights to minimize performance loss. For efficiency, it decouples the Hessian approximation used for saliency detection from the weight updates (an OBS-style formulation is sketched after this list).
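To make the saliency step concrete, here is a minimal sketch (not the authors' code) of a first-order Taylor criterion scored per attention head in PyTorch; the module layout, head-major weight ordering, and the choice of scoring the attention output projection are illustrative assumptions.

```python
import torch
import torch.nn as nn

def head_saliency(out_proj: nn.Linear, n_heads: int) -> torch.Tensor:
    """First-order Taylor estimate of |loss change| from removing each head:
    saliency(h) ~= | sum over head-h weights of w * dL/dw |.
    Assumes loss.backward() has already populated out_proj.weight.grad."""
    W, G = out_proj.weight, out_proj.weight.grad      # [d_model, n_heads * head_dim]
    head_dim = W.shape[1] // n_heads
    contrib = (W * G).view(W.shape[0], n_heads, head_dim)
    return contrib.sum(dim=(0, 2)).abs()              # one score per head

# Usage sketch: after a forward/backward pass on a pretraining batch,
#   scores = head_saliency(block.attn.out_proj, n_heads=16)
#   candidates = scores.argsort()[:2]   # the two least-salient heads
```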
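A companion sketch for one kind of "minimal parameter group": a per-channel FFN group, which couples one row of the up-projection with one column of the down-projection. Module names and the keep ratio are hypothetical.

```python
import torch
import torch.nn as nn

def prune_ffn_channels(up: nn.Linear, down: nn.Linear,
                       keep: torch.Tensor) -> tuple[nn.Linear, nn.Linear]:
    """Keep only the hidden FFN channels indexed by `keep`; each channel is a
    coupled group: row i of the up-projection and column i of the down-projection."""
    new_up = nn.Linear(up.in_features, keep.numel(), bias=up.bias is not None)
    new_down = nn.Linear(keep.numel(), down.out_features, bias=down.bias is not None)
    with torch.no_grad():
        new_up.weight.copy_(up.weight[keep])            # rows of the up-projection
        if up.bias is not None:
            new_up.bias.copy_(up.bias[keep])
        new_down.weight.copy_(down.weight[:, keep])     # columns of the down-projection
        if down.bias is not None:
            new_down.bias.copy_(down.bias)
    return new_up, new_down

# Usage sketch with an assumed Taylor-style per-channel score:
#   scores = (up.weight * up.weight.grad).sum(dim=1).abs()
#   keep = scores.topk(int(0.75 * scores.numel())).indices   # drop 25% of channels
#   up, down = prune_ffn_channels(up, down, keep)
```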
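The post does not spell out the second-order update itself; a standard optimal-brain-surgeon-style formulation consistent with the description (the paper's precise variant may differ) scores a pruned coordinate or group q and compensates the surviving weights as follows.

```latex
% Assumed OBS-style form, not verbatim from the paper: H approximates the loss
% Hessian w.r.t. the layer weights, w_q is the pruned coordinate, e_q the
% matching unit vector, S_q the saliency, \delta w the compensating update.
\[
S_q = \frac{w_q^{2}}{2\,[H^{-1}]_{qq}},
\qquad
\delta w = -\,\frac{w_q}{[H^{-1}]_{qq}}\; H^{-1} e_q .
\]
```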
-----
Key Insights 💡:
→ Scaling up LLM compression to the pretraining stage bridges the performance gap between direct pretraining of small LLMs and post-training compression of large LLMs.
→ Pruning-aware pretraining allows for creating high-performance edge LLMs that exceed the traditional scaling law limitations for smaller models.
→ Saliency-driven, auto-designed architectures achieved through pruning-aware pretraining can be competitive with or even surpass human-designed architectures for edge LLMs.
→ By continuously pruning during pretraining with large datasets, the method retains more of the original model's performance compared to post-training pruning or limited pruning during pretraining.
-----
Results 📊:
→ EfficientLLM outperforms baselines such as MobileLLM, SmolLM, Qwen2.5-0.5B, OLMo-1B, and Llama3.2-1B on common-sense benchmarks. For example, EfficientLLM-134M surpasses Pythia-410M by 4.13% in average accuracy.
→ EfficientLLM-469M, pretrained on 50B tokens, exceeds SmolLM-360M, which was trained on 600B tokens, on common-sense reasoning tasks, showing improved data efficiency.
→ EfficientLLM-1.1B, also pretrained on 50B tokens, outperforms OLMo-1B, TinyLlama-1.1B, and Llama3.2-1B in accuracy, demonstrating strong performance with limited pretraining data.