
"Simplifying, Stabilizing and Scaling Continuous-Time Consistency Models"

The podcast on this paper is generated with Google's Illuminate.

Two-step image generation comes within 10% of the FID of hundred-step diffusion models while using roughly 90% less sampling compute.

The model generates an image in 0.11 seconds while maintaining diffusion-level quality.

Stabilizes continuous-time consistency models

📚 https://arxiv.org/pdf/2410.11081

Original Problem 🔍:

Continuous-time consistency models offer fast sampling but suffer from training instability issues. Existing discrete-time solutions introduce discretization errors and require careful timestep scheduling.

💡 A consistency model is a neural network trained to map noisy inputs directly to clean data in one or a few steps. Typically distilled from a diffusion model, it achieves similar quality while eliminating the need for numerous denoising iterations.
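As a rough illustration of the one- and two-step sampling this enables, here is a minimal NumPy sketch. The model `f`, the maximum noise level, and the intermediate re-noising level are placeholders and assumed values, not the paper's exact procedure.

```python
import numpy as np

rng = np.random.default_rng(0)
SIGMA_MAX = 80.0  # assumed maximum noise level (EDM-style convention)

def f(x, t):
    """Placeholder consistency model: maps a noisy sample at noise level t
    directly to an estimate of clean data. A real model is a trained
    neural network; here we just shrink toward zero for illustration."""
    return x / (1.0 + t)

def sample(shape, steps=(SIGMA_MAX, 0.8)):
    """Multistep consistency sampling (sketch): denoise pure noise in one
    shot, then optionally re-noise to an intermediate level and denoise
    again. The intermediate level 0.8 is an assumed value."""
    x = SIGMA_MAX * rng.standard_normal(shape)     # start from pure noise
    x0 = f(x, steps[0])                            # one-step generation
    for t in steps[1:]:                            # optional refinement
        x_t = x0 + t * rng.standard_normal(shape)  # re-noise to level t
        x0 = f(x_t, t)                             # denoise again
    return x0

img = sample((3, 32, 32))
```

Each extra entry in `steps` trades a little speed for quality, which is why two steps suffice to approach diffusion-level FID.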

-----

Solution in this Paper ⚡:

- Introduces the TrigFlow framework, unifying the EDM (Elucidating Diffusion Models) and flow-matching parameterizations

- Rearranges the Jacobian-vector product (JVP) computation for numerical stability in FP16 training

- Employs tangent normalization to control gradient variance

- Uses adaptive double normalization in network architecture

- Adds progressive annealing and adaptive weighting in training objectives

- Develops a FlashAttention kernel with JVP support for memory efficiency
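The TrigFlow parameterization and the tangent normalization above can be sketched as follows. The stand-in network `F`, the value of `SIGMA_D`, and the constant `c = 0.1` are illustrative assumptions, not the paper's exact implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
SIGMA_D = 0.5  # assumed data standard deviation

def F(x, t):
    """Stand-in for the raw network F_theta; a real model is a trained
    U-Net or transformer."""
    return -x * np.cos(t)

def trigflow_noise(x0, t):
    """TrigFlow forward process: trigonometric interpolation between data
    and Gaussian noise, x_t = cos(t) * x0 + sin(t) * z, with t in [0, pi/2]
    and z ~ N(0, sigma_d^2 I)."""
    z = SIGMA_D * rng.standard_normal(x0.shape)
    return np.cos(t) * x0 + np.sin(t) * z

def consistency_fn(x_t, t):
    """Consistency-model parameterization under TrigFlow:
    f(x_t, t) = cos(t) * x_t - sin(t) * sigma_d * F(x_t / sigma_d, t).
    At t = 0 it reduces to the identity, satisfying the boundary
    condition f(x0, 0) = x0."""
    return np.cos(t) * x_t - np.sin(t) * SIGMA_D * F(x_t / SIGMA_D, t)

def normalize_tangent(g, c=0.1):
    """Tangent normalization (sketch): rescale the JVP tangent g to
    g / (||g|| + c), bounding its magnitude to control gradient variance."""
    return g / (np.linalg.norm(g) + c)

x0 = SIGMA_D * rng.standard_normal((4, 2))
x_t = trigflow_noise(x0, 0.7)
```

Because `cos(0) = 1` and `sin(0) = 0`, the boundary condition holds by construction rather than by training, which is one source of the stability gains.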

-----

Key Insights from this Paper 💡:

- Continuous-time models outperform discrete-time versions across all settings

- Discrete-time sample quality degrades beyond N > 1024 discretization steps due to numerical precision issues

- Consistency distillation scales more predictably than consistency training

- Training stability improves with focused sampling near clean data distribution
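The last insight, concentrating training near the clean-data end, can be sketched as a log-normal noise-level proposal mapped into TrigFlow time. `P_MEAN`, `P_STD`, and `SIGMA_D` here are assumed values, not the paper's tuned settings.

```python
import numpy as np

rng = np.random.default_rng(0)
SIGMA_D = 0.5               # assumed data standard deviation
P_MEAN, P_STD = -1.0, 1.4   # assumed log-normal proposal parameters

def sample_timesteps(n):
    """Draw training timesteps from a log-normal noise-level proposal,
    sigma = exp(tau), mapped into TrigFlow time t = arctan(sigma / sigma_d)
    in (0, pi/2). A negative P_MEAN concentrates mass at small t, i.e.
    near the clean data distribution."""
    tau = rng.normal(P_MEAN, P_STD, size=n)
    return np.arctan(np.exp(tau) / SIGMA_D)

t = sample_timesteps(10_000)
```

Shifting `P_MEAN` downward biases training toward lightly-noised samples, which the paper identifies as improving training stability.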

-----

Results 📊:

- Achieves FID 2.06 on CIFAR-10, 1.48 on ImageNet 64x64, 1.88 on ImageNet 512x512

- Scales to 1.5B parameters on ImageNet 512x512

- Narrows FID gap with best diffusion models to within 10% using only 2 sampling steps

- Reduces effective sampling compute by 90% compared to standard diffusion models
