"SANA 1.5: Efficient Scaling of Training-Time and Inference-Time Compute in Linear Diffusion Transformer"
The podcast below on this paper was generated with Google's Illuminate.
This paper introduces efficient scaling methods for text-to-image generation using linear Diffusion Transformers.
It tackles the increasing computational costs of larger models by proposing techniques for efficient training, model compression, and inference scaling, making high-quality image generation more accessible.
-----
https://arxiv.org/abs/2501.18427
1. Model growth via partial preservation initialization shows strong practical value. It avoids training large diffusion models from scratch: reusing weights from a smaller model accelerates convergence by 2.5x. Normalizing queries and keys in the attention layers stabilizes training, and dropping the last two pre-trained blocks further helps the newly added blocks learn. The approach scales model capacity efficiently while saving compute (a minimal initialization sketch follows this list).
2. Block importance-based pruning offers a flexible deployment strategy. Measuring input-output similarity identifies the less critical blocks, and pruning them compresses the 4.8B model down to 1.6B with minimal quality loss; fine-tuning for just 100 steps recovers performance. This enables efficient deployment across diverse resource constraints and gives a practical way to balance model size against generation quality (a pruning sketch follows this list).
3. Inference-time scaling with VLM-guided selection significantly boosts performance. Generating 2,048 samples per prompt and selecting the best with NVILA raises the GenEval score to 0.80, showing that extra compute can compensate for a smaller model. The VLM judge filters out images that mismatch the prompt, and tournament-style pairwise comparison makes the selection robust. This challenges the assumption that larger models are always necessary for better quality (a selection sketch follows this list).
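A minimal sketch of the depth-growth idea in point 1, written in PyTorch style. The helper names (`grow_depth`, `make_block`) and the block-list structure are illustrative assumptions, not the paper's code.

```python
import copy
import torch.nn as nn

def grow_depth(small_blocks: nn.ModuleList, target_depth: int,
               make_block, drop_last: int = 2) -> nn.ModuleList:
    """Expand a trained stack of transformer blocks to `target_depth`.

    Pre-trained blocks are copied (partially preserved); the remaining
    slots are filled with freshly initialized blocks. Following the
    summary above, the last `drop_last` pre-trained blocks are discarded
    so the newly added blocks can learn freely.
    """
    kept = list(small_blocks)[: len(small_blocks) - drop_last]
    grown = [copy.deepcopy(b) for b in kept]        # preserve pre-trained weights
    while len(grown) < target_depth:
        grown.append(make_block())                  # random init for new capacity
    return nn.ModuleList(grown)
```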
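A minimal sketch of the block-importance scoring in point 2: each block is scored by how much it changes its input, and low-scoring blocks become pruning candidates. The function names and the flattening scheme are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def block_importance_scores(blocks, hidden_states):
    """Score each block as 1 - cosine(input, output), averaged over the batch.
    Blocks whose output is nearly identical to their input (score near 0)
    contribute little and are candidates for pruning."""
    scores, x = [], hidden_states
    for block in blocks:
        y = block(x)
        sim = F.cosine_similarity(x.flatten(1), y.flatten(1), dim=-1).mean()
        scores.append(1.0 - sim.item())
        x = y
    return scores

def prune_blocks(blocks, scores, keep: int):
    """Keep the `keep` highest-scoring blocks, preserving their original order."""
    keep_idx = sorted(sorted(range(len(blocks)), key=lambda i: -scores[i])[:keep])
    return torch.nn.ModuleList([blocks[i] for i in keep_idx])
```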
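A minimal sketch of the best-of-N selection in point 3. Here `generate` stands in for the diffusion sampler and `judge` for an NVILA-style VLM comparator that returns the preferred image for a prompt; both interfaces are assumptions, not the paper's API.

```python
import random

def best_of_n(prompt, generate, judge, n: int = 2048):
    """Generate n candidates and run a single-elimination tournament,
    letting the VLM judge pick the winner of each pairwise comparison."""
    candidates = [generate(prompt) for _ in range(n)]
    random.shuffle(candidates)
    while len(candidates) > 1:
        next_round = []
        for a, b in zip(candidates[::2], candidates[1::2]):
            next_round.append(judge(prompt, a, b))   # winner advances
        if len(candidates) % 2:                      # odd candidate gets a bye
            next_round.append(candidates[-1])
        candidates = next_round
    return candidates[0]
```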
-----
Methods in this Paper ✨:
→ Efficient Training Scaling is achieved through a depth-growth paradigm.
→ A smaller 1.6B parameter model is expanded to 4.8B parameters by strategically adding new blocks.
→ Partial Preservation Initialization preserves pre-trained layers and initializes new layers randomly.
→ This method reduces training time by 60% compared to training from scratch.
→ The memory-efficient CAME-8bit optimizer uses 8-bit quantization to further cut memory usage to roughly one-eighth of AdamW-32bit's (quantization sketched after this list).
→ The model depth pruning technique ranks block importance by input-output similarity.
→ Less important blocks are pruned, and model quality is recovered with minimal fine-tuning.
→ Inference-time Scaling involves repeated sampling and VLM-based selection.
→ Generating multiple samples and selecting the best using a Visual Language Model improves quality.
→ This allows smaller models to achieve quality comparable to larger models by trading compute for capacity.
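The 8-bit optimizer keeps its statistics in quantized form. As a rough illustration of blockwise 8-bit quantization (an assumed mechanism for illustration, not the paper's CAME-8bit implementation), a state tensor can be stored as int8 codes plus one fp16 scale per block; `quantize_state` and `dequantize_state` are hypothetical helper names.

```python
import torch

def quantize_state(t: torch.Tensor, block: int = 256):
    """Blockwise 8-bit quantization of an optimizer-state tensor:
    store int8 codes plus one fp16 scale per block of values,
    shrinking per-element storage from 32 bits to roughly 8 bits."""
    flat = t.detach().float().flatten()
    pad = (-flat.numel()) % block                     # pad to a whole number of blocks
    flat = torch.cat([flat, flat.new_zeros(pad)]).view(-1, block)
    scale = flat.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / 127.0
    codes = torch.clamp((flat / scale).round(), -127, 127).to(torch.int8)
    return codes, scale.half(), t.shape, pad

def dequantize_state(codes, scale, shape, pad):
    """Recover an fp32 approximation of the original state tensor."""
    flat = (codes.float() * scale.float()).flatten()
    if pad:
        flat = flat[:-pad]
    return flat.view(shape)
```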
-----
Key Insights from this Paper 🤔:
→ Efficient scaling can be achieved through optimization strategies, not just by increasing model size.
→ Reusing knowledge from smaller models via model growth significantly reduces training costs.
→ Block importance analysis allows for effective model compression through depth pruning.
→ Inference-time scaling demonstrates that computational resources can substitute for model capacity to enhance generation quality.
→ Partial Preservation Initialization strategy provides stable training dynamics for model growth.
-----
Results 🚀:
→ Achieves a text-image alignment score of 0.72 on GenEval.
→ Inference scaling further improves GenEval score to 0.80, establishing a new state-of-the-art.
→ Reduces training time by 60% compared to training from scratch using model growth.
→ CAME-8bit optimizer reduces memory consumption by 25% compared to AdamW.