Table of Contents
Introduction
Causes of Training Instability
The Challenge with FP16 in More Detail
Learning Rate Strategies for Stability
Gradient Clipping
Optimizer Innovations (Adam, LAMB, Adafactor, etc.)
Initialization and Normalization Improvements
Stabilizing Reinforcement Learning-Based Training
Framework-Specific Practices (PyTorch vs JAX vs TensorFlow)
Hardware-Induced Instability and Distributed Training Challenges
Recent Research and Industry Insights
Introduction
Training large language models (LLMs) can be unstable – losses may spike and models sometimes diverge, wasting vast computational resources (HERE). Instability arises in all training paradigms (unsupervised pre-training, supervised fine-tuning, and reinforcement learning-based tuning) and across model scales. Causes include excessive learning rates, exploding gradients, poor initialization, and numerical issues in large-scale distributed setups. For example, loss spikes during transformer pre-training can degrade performance or even ruin a run . Ensuring stable convergence is therefore critical for both efficiency and model quality. This report reviews techniques to stabilize LLM training, including learning rate strategies, gradient clipping, optimizer innovations, initialization and normalization improvements, and hardware-specific best practices. We compare findings across PyTorch, JAX, and TensorFlow, highlight recent research (2024–2025), and note industry insights from major AI framework teams.
Causes of Training Instability
Gradient Explosions and Loss Spikes: A major source of instability is the sudden explosion of gradient norms, which manifests as sharp loss spikes (SPAM: Spike-Aware Adam with Momentum Reset for Stable LLM Training) . These often occur unpredictably in LLM training, causing the model to overshoot and diverge. Studies have observed that such spikes can reach magnitudes far beyond typical gradients, severely degrading model performance . In transformers, certain layers contribute disproportionately – e.g. extremely large updates in attention or feedforward layers can cause output norms to grow each step, eventually blowing up the loss ( Methods of improving LLM training stability).
Inadequate Normalization: Deep transformer networks are prone to instability if not properly normalized. Using the original Transformer’s configuration (post-layernorm) can lead to vanishing or exploding updates when many layers are stacked. Pre-layernorm designs (applying LayerNorm before each sub-layer) became standard because they are more stable for very deep models (HERE). Recent analyses found that even with Pre-LN, output norms of certain sub-layers (QKV projections, output projections, feed-forward layers) can grow without bound under high learning rates , indicating that existing normalization may be insufficient at extreme scales.
Learning Rate and Batch Size Issues: Aggressive training settings – large batch sizes and high learning rates – often trigger divergence. There is a known stability–efficiency dilemma: increasing batch size or LR improves throughput but can destabilize training (HERE) . Large batches reduce the noise in gradient estimates, which can make optimization less robust unless the learning rate is carefully adjusted. If the effective update is too large (product of LR and grad), the model can overshoot minima and diverge. Empirically, transformers have a narrow stable LR range that shrinks as model size and batch size grow, requiring schedules or scaling rules to avoid blow-up.
Reinforcement Learning Feedback Loops: When training LLMs with reinforcement learning (e.g. RLHF – Reinforcement Learning from Human Feedback), instability can arise from the model policy drifting too far from the data distribution. Without constraints, an RL-tuned policy might exploit the reward model in unintended ways, producing incoherent outputs (essentially “gibberish” that tricks the reward) (Illustrating Reinforcement Learning from Human Feedback (RLHF)). The policy optimization (e.g. via PPO) needs to be carefully regularized to prevent the model from diverging from its pre-trained behaviors. We discuss solutions like KL penalties and trust-region updates in a later section.
Numerical Precision and Hardware: Mixed-precision training (FP16/BF16) is ubiquitous for LLMs but introduces numerical challenges. FP16 has a limited dynamic range (max ~6.5e4), so activations or gradients that would be representable in FP32 can overflow to infinity in FP16, leading to NaNs (PyTorchのAMPはbf16を使え.多分nanが出なくなる. #AutomaticMixedPrecision - Qiita).
The Challenge with FP16 in More Detail:
Limited Dynamic Range: FP16 can only represent numbers up to roughly 6.5 × 10⁴ (about 65,000). In contrast, FP32 can handle much larger values.
Overflow Issue: During training, neural networks compute activations (intermediate outputs) and gradients (used to update weights). Sometimes these values are large. In FP32 they can be accurately represented, but if they exceed the FP16 maximum, they "overflow." In floating-point arithmetic, an overflow typically turns a value into “infinity.”
Propagation to NaNs: Once a value becomes infinity, further arithmetic operations can produce “NaN” (Not a Number), which effectively breaks the training process.
This is exacerbated in very deep networks and with large activations. Distributed training across many GPUs/TPUs can also cause instability if updates are not synchronized properly. (Fortunately, most LLM training uses synchronous all-reduce updates; asynchronous methods tend to be less stable without special tuning (Data Parallelism in Machine Learning Training | by Soonmo Seong).)
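To make the overflow-to-NaN chain concrete, here is a tiny PyTorch demonstration (PyTorch is simply a convenient way to show it; the numerics are a property of the formats, not the framework):

```python
import torch

x = torch.tensor(70000.0)                         # easily representable in FP32
print(x.to(torch.float16))                        # inf  -> exceeds FP16's ~6.5e4 maximum (overflow)
print(x.to(torch.bfloat16))                       # a finite value near 7e4 (coarser mantissa, FP32-sized exponent)
print(x.to(torch.float16) - x.to(torch.float16))  # inf - inf = nan, which then poisons subsequent computation
```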
Learning Rate Strategies for Stability
Careful management of the learning rate is one of the most effective ways to prevent divergence. Learning Rate Warmup – starting with a very small LR and gradually increasing it over the first steps – is standard in nearly all LLM training recipes to avoid early instabilities. A too-large initial LR can cause a burst of large gradients before the model has calibrated its weights. Warmup (often over a few hundred or a few thousand steps) lets the model start in a stable regime.
Lower Base Learning Rates: Large models typically require lower learning rates to remain stable. For instance, the original GPT-3 (175B) used a peak LR of only 2.8e-5 in its schedule (HERE) (very low compared to smaller models) to avoid divergence. Recent research has shown it’s possible to push LRs higher with other interventions (like better normalization, as discussed below), but without such measures, one must often sacrifice speed (lower LR) for stability.
Adaptive Schedules: Learning rate decay schedules (cosine decay, linear decay after warmup, etc.) also help by reducing the step size as training progresses, which can prevent late-training instabilities. Toward the end of training, when the model is near convergence, a high LR can knock it off course, so decaying it tends to stabilize the final phase.
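As a concrete illustration of warmup plus decay, the sketch below builds a linear-warmup/cosine-decay schedule with PyTorch's LambdaLR. The step counts, peak LR, and decay floor are illustrative placeholders, not recommended values:

```python
import math
import torch
from torch.optim.lr_scheduler import LambdaLR

# Dummy parameter/optimizer so the snippet runs stand-alone; in practice pass model.parameters().
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=3e-4)        # 3e-4 is an illustrative peak LR

def warmup_cosine(step, warmup_steps=2000, total_steps=100_000, floor=0.1):
    if step < warmup_steps:                           # linear warmup from ~0 up to the peak LR
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return floor + (1.0 - floor) * cosine             # cosine decay down to floor * peak LR

scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine)
# Call scheduler.step() once per training step, after optimizer.step().
```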
Sequence Length Warmup: An innovative strategy is to treat sequence length as part of the schedule. Li et al. (2022) introduced Sequence Length Warmup (SLW), starting training with shorter sequences and gradually increasing to full length (HERE). This addresses an observed cause of instability: long sequences early in training contribute to extreme gradient variance, leading to loss spikes. By beginning with short contexts, the model learns more stably and later handles long sequences once weights are better initialized. This method enabled dramatically larger training steps: in experiments replicating GPT-2 and GPT-3, SLW allowed using 8× larger batches and 4–40× higher learning rates without divergence, while maintaining equivalent or better accuracy. Essentially, SLW functions as a form of curriculum learning that improves the stability–efficiency tradeoff.
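A minimal sketch of the idea is shown below; the schedule shape and the specific lengths/steps are our own illustrative choices, not the exact recipe from the SLW paper:

```python
def max_seq_len(step, start_len=128, full_len=2048, slw_steps=20_000):
    # Linearly grow the allowed context length from start_len to full_len over slw_steps.
    frac = min(1.0, step / slw_steps)
    return int(start_len + frac * (full_len - start_len))

def apply_slw(batch, step):
    # Assumes batch tensors are shaped [batch_size, seq_len] (e.g. input_ids, attention_mask).
    limit = max_seq_len(step)
    return {name: tensor[:, :limit] for name, tensor in batch.items()}
```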
Dynamic LR Reduction on Instability: In practice, some training runs include heuristic checks for instability (e.g. if loss suddenly spikes or gradients explode) and automatically reduce the learning rate or even reset to a previous checkpoint. This is a last resort “safety net.” For example, Google’s PaLM training noted manual intervention: when rare loss spikes occurred, they restarted from a recent checkpoint and skipped the problematic batch (SPAM: Spike-Aware Adam with Momentum Reset for Stable LLM Training). Modern optimizers (discussed next) are attempting to handle such events more gracefully without manual resets.
Gradient Clipping
Gradient clipping is a straightforward yet powerful technique to prevent exploding gradients. In LLM training, it’s common to clip the gradient norm to a fixed threshold (e.g. 1.0 or 0.5) each step. This means if the gradients collectively exceed the norm limit, they are scaled down to that limit. Clipping constrains the magnitude of weight updates, providing an upper bound on how big a single step can be – crucial for avoiding the runaway updates that lead to NaNs or divergence (SPAM: Spike-Aware Adam with Momentum Reset for Stable LLM Training).
Gradient clipping has long been used to stabilize sequence-model training, and many large-scale models (BERT, T5, etc.) clip gradients by global norm. Without clipping, a single unusually difficult batch could produce a huge gradient that wrecks the model’s parameters. Clipping ensures outlier gradients don’t overwhelm the training process. It can be seen as enforcing a trust region on the update. The downside is that if the threshold is too low, it may slow learning by under-updating; so it’s about choosing a reasonable clip value that only affects rare spikes.
In practice, all major frameworks support easy gradient clipping. In PyTorch, for example, torch.nn.utils.clip_grad_norm_ is called between loss.backward() and the optimizer step. TensorFlow’s optimizers can be wrapped to clip gradients by value or norm. Clipping is essentially always recommended when training very deep networks or with high LRs. The SPAM optimizer (next section) even includes a form of adaptive clipping targeted at spiking gradients (SPAM: Spike-Aware Adam with Momentum Reset for Stable LLM Training).
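A typical PyTorch training step with global-norm clipping looks roughly like this (a sketch assuming `model`, `optimizer`, and `loader` already exist and that the model returns a Hugging Face-style `.loss`):

```python
import torch

CLIP_NORM = 1.0
for batch in loader:                                       # `model`, `optimizer`, `loader` assumed to exist
    optimizer.zero_grad(set_to_none=True)
    loss = model(**batch).loss                             # Hugging Face-style output; adapt to your model
    loss.backward()
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
    if not torch.isfinite(total_norm):                     # optional extra guard: skip a pathological batch
        optimizer.zero_grad(set_to_none=True)
        continue
    optimizer.step()
```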
Optimizer Innovations (Adam, LAMB, Adafactor, etc.)
The choice of optimizer significantly impacts training stability. Adam (adaptive moment estimation) became the default for language models due to its stability advantages over plain SGD. Adam’s adaptive learning rates per parameter help handle differing gradient scales, and its momentum terms can smooth out noise. However, even Adam can accumulate problematic momentum from gradient spikes. Researchers have developed optimizer modifications to further improve stability for LLMs:
Adaptive Moment with Reset (SPAM): Huang et al. (2025) introduced Spike-Aware Adam with Momentum Reset (SPAM), designed specifically to combat gradient spikes in LLM training (SPAM: Spike-Aware Adam with Momentum Reset for Stable LLM Training) . When a large gradient spike is detected, SPAM resets the optimizer’s first and second moment estimates (essentially clearing momentum) to prevent the spike from continuing to influence future updates . It also uses spike-aware gradient clipping – identifying and down-scaling only the spiked gradients (instead of clipping everything) to preserve as much useful signal as possible . In experiments on both pre-training and fine-tuning, SPAM outperformed vanilla Adam and memory-efficient variants, leading to fewer instabilities and better final accuracy . Notably, by preventing loss spikes, SPAM reduced the need for manual restarts and even enabled using sparser momentum states to save memory (only keeping momentum for a subset of parameters) .
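The released SPAM method involves more machinery than fits here, but the core momentum-reset idea can be sketched as follows. This is a toy illustration under our own assumptions – the spike detector, threshold, and window are hypothetical, and the state keys follow torch.optim.AdamW’s conventions rather than the authors’ code:

```python
import torch

def reset_moments_on_spike(optimizer, grad_norm, history, factor=5.0, window=100):
    # Toy spike detector: compare the current global gradient norm to a recent running average.
    if len(history) >= 10 and grad_norm > factor * (sum(history) / len(history)):
        for group in optimizer.param_groups:
            for p in group["params"]:
                state = optimizer.state.get(p, {})
                if "exp_avg" in state:                     # AdamW's first/second moment buffers
                    state["exp_avg"].zero_()
                    state["exp_avg_sq"].zero_()
    history.append(float(grad_norm))
    del history[:-window]                                  # keep only the recent window of norms
```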
Large Batch Optimizers (LAMB/LARS): To scale batch sizes into the tens of thousands without divergence, new optimizers like LAMB (Layer-wise Adaptive Moments for Batch Training) were created. LAMB (You et al., 2019) builds on Adam but adapts the learning rate for each layer based on its weight norm, which proved crucial for stability in BERT pre-training with huge batches ( Large Batch Optimization for Deep Learning: Training BERT in 76 minutes). With LAMB, the authors successfully trained BERT with a batch size of 32k without loss of convergence or final accuracy – something that plain Adam would fail at (it would diverge or underfit). The layer-wise LR scaling ensures no layer gets a disproportionately large update in a single step. Similarly, LARS was used for large-batch ImageNet training. These optimizers show that scaling up batch size (and hence, using a larger LR) can be done stably if the update magnitude is properly controlled per layer.
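The heart of LAMB – scaling each tensor’s step by a layer-wise trust ratio – can be sketched in a few lines (a simplified illustration that omits bias correction, weight-decay handling, and the exact clamping used in the paper):

```python
import torch

def lamb_style_step(param, adam_direction, base_lr, eps=1e-8):
    # adam_direction: the Adam update for this tensor, e.g. m_hat / (sqrt(v_hat) + eps).
    w_norm = param.data.norm()
    u_norm = adam_direction.norm()
    # Trust ratio: scale the step so no single layer moves disproportionately far in one update.
    trust_ratio = (w_norm / (u_norm + eps)).clamp(max=10.0) if w_norm > 0 else torch.tensor(1.0)
    param.data.add_(adam_direction, alpha=-float(base_lr * trust_ratio))
```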
Adafactor: For very large models, memory usage of Adam can be problematic (Adam keeps two momentum tensors the size of parameters). Adafactor (Shazeer et al., 2018) is an adaptive optimizer used in T5 that reduces memory by not storing the full second moment for all parameters (factored approximation). A side benefit observed is that Adafactor with appropriate learning rate schedule can be quite stable for large transformer training, possibly due to its relative update scale handling. It’s been used to train 11B+ parameter models reliably.
SGD with Momentum: Historically, SGD was tough to use for training transformers from scratch – it often required a very low learning rate and was slow to converge, though it is stable in the sense of not diverging easily. Adam’s adaptation is usually preferred for faster convergence. That said, for fine-tuning an already pre-trained model on a supervised task, sometimes a simple optimizer (SGD or AdamW with a low LR) is stable and effective, since the model is mostly at a good solution already.
In summary, most LLM training runs use Adam or a close variant. The latest research is addressing Adam’s weaknesses in handling rare large gradients (SPAM: Spike-Aware Adam with Momentum Reset for Stable LLM Training). By resetting momentum and selectively clipping spikes, methods like SPAM keep training on track without sacrificing the benefits of Adam. Furthermore, layerwise-adaptive methods (LAMB) enable stable training at batch scales that naive Adam would not reach. Choosing an optimizer thus goes hand-in-hand with the desired batch size, learning rate, and memory constraints, all with the goal of keeping training smooth.
Initialization and Normalization Improvements
Stability can often be improved by how we initialize weights and design normalization in the model architecture:
Smaller Initial Weights: Large initial weights can lead to layer outputs of high variance, which in turn produce very large gradients at the start of training. Recent theoretical work by Takase et al. (2024) found that to avoid loss spikes, it’s important to initialize sub-layers with small parameter scales (HERE) . By reducing the standard deviation of initial weight distributions, the model’s Jacobian (gradient) norms are bounded more tightly, preventing early explosions. Many LLM implementations already use scaled Xavier or Gaussian inits (e.g. Megatron-LM’s initialization) that yield small activations . Ensuring those best practices (no accidentally large init) is key. Techniques like FixUp initialization or Zero-init residuals have also been proposed to stabilize very deep networks without normalization, by careful weight scaling.
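For illustration, a GPT-2/Megatron-style initializer uses a small base standard deviation and additionally shrinks the projections that feed residual additions by a depth-dependent factor. In the sketch below, the is_residual_proj attribute is a hypothetical marker you would set on the attention output projection and the second MLP projection; the constants are commonly used values, not a prescription:

```python
import math
import torch.nn as nn

def init_transformer_weights(module, n_layers, base_std=0.02):
    if isinstance(module, nn.Linear):
        std = base_std
        if getattr(module, "is_residual_proj", False):     # hypothetical flag on out_proj / mlp.fc2
            std = base_std / math.sqrt(2 * n_layers)        # shrink layers that feed residual adds
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=base_std)

# Usage (assuming a `config` object with num_layers):
# model.apply(lambda m: init_transformer_weights(m, n_layers=config.num_layers))
```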
Residual Structure and DeepNorm: Transformers rely on residual connections (adding the input to the output of sub-layers) – if these are not balanced, instability can occur, especially in very deep models. Microsoft’s DeepNorm (Wang et al., 2022) introduced a modified residual scaling to enable stable training of Transformers up to 1000 layers. Instead of outputting x + f(x) at a residual block, DeepNorm outputs a·x + f(x), effectively up-weighting the residual (skip) connection by a factor a > 1 (and adjusting initial weights accordingly) (DeepNorm Allows Transformers to Accommodate More Layers). This keeps the magnitude of updates more constant across layers, preventing the diminishing or exploding update problem that standard LayerNorm setups can cause. With DeepNorm, extremely deep Transformers no longer diverged – in trials, models of 200 layers that normally diverged during training converged successfully with this change. The intuition is that a stronger skip connection (and/or smaller sub-layer weights) means each layer only perturbs the representation a little, so gradients don’t blow up as they backprop through many layers. This method yielded improved results and stable training where previous architectures failed.
Additional Layer Normalization: As noted, even Pre-LN Transformers can see certain components’ outputs grow. Rybakov et al. (2024) showed that applying extra layer normalization after specific Transformer sub-layers can rein in exploding norms (Methods of improving LLM training stability). In their experiments, adding LayerNorm after the attention output projection and after the feed-forward network (in addition to the normal places) stopped the unbounded growth of those outputs under high learning rates. This allowed them to raise the learning rate 1.5× higher than baseline without divergence, and it improved final perplexity as well. Essentially, they introduced more normalization points to keep activations in check. Another trick they explored was “softmax capping” – limiting the magnitude of the QK attention logits before softmax (clipping extreme values). That also helped stabilize attention. These kinds of adjustments show that sometimes the network architecture itself can be tweaked to be more robust to large gradients.
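Softmax (logit) capping in particular is simple to express; the sketch below shows the general idea, with an illustrative cap value rather than the one used in any specific paper:

```python
import torch

def capped_attention_logits(scores, cap=50.0):
    # Smoothly squash QK^T logits into (-cap, cap) before softmax, so no single
    # position can produce an extreme logit that destabilizes the softmax or its gradients.
    return cap * torch.tanh(scores / cap)
```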
Mixing Normalization Schemes: Some recent papers explore hybrid normalization (e.g. a mix of pre- and post-layernorm, or combining RMSNorm with LayerNorm) to get the best of both worlds in stability and performance (arXiv:2412.13795v1 [cs.LG] 18 Dec 2024). For instance, Mix-LN (arXiv:2412.13795, 2024) proposes a combination that avoids the instabilities seen in purely post-LN Transformers while improving the quality of deep layers. While details vary, the general goal is to ensure normalization keeps activations well-behaved at all depths.
In summary, a stable LLM training run often begins with a well-initialized model (no layer starts with too large weights or outputs) and an architecture calibrated for stability (using Pre-LayerNorm, possibly additional normalization like DeepNorm or Rybakov’s method for extreme cases). By controlling the scale of activations from the start, we make it much less likely for gradients to ever reach catastrophic levels.
Stabilizing Reinforcement Learning-Based Training
Training LLMs via reinforcement learning – for example, optimizing a policy to maximize a reward model as in RLHF – introduces unique stability challenges. The objective is no longer a fixed dataset; the model’s own outputs influence the training data, creating a feedback loop that can spiral out of control if not checked.
KL-Divergence Regularization: A common solution, used by OpenAI, Anthropic, and others in RLHF, is to add a penalty term to the reward that measures how far the new policy distribution has strayed from the original model. In practice this is implemented as a KL-divergence term: the policy’s output distribution is penalized for diverging from the pre-trained model’s distribution (Illustrating Reinforcement Learning from Human Feedback (RLHF)). By subtracting a term r_KL = β · KL(π_new || π_old) from the reward, the agent is incentivized to stay close to the pre-trained behavior while still improving reward. This greatly stabilizes training – it prevents the policy from going off into nonsensical regions of output space that would “fool” the reward model but be useless or incoherent as language. Essentially, KL regularization acts like a leash, keeping the RL update steps moderate.
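In code, the KL-shaped reward used in common RLHF recipes looks roughly like the sketch below, which uses the usual approximation that the per-token KL is estimated from the log-probabilities of the sampled tokens; the β value and the convention of adding the sequence-level reward at the final token are illustrative choices:

```python
import torch

def kl_shaped_rewards(task_reward, logprobs_policy, logprobs_ref, beta=0.02):
    # task_reward: shape [batch], the reward-model score for each generated sequence.
    # logprobs_*: log-probabilities of the sampled tokens under the current policy and the
    # frozen reference (pre-trained) model, shape [batch, seq_len].
    approx_kl = logprobs_policy - logprobs_ref       # per-token KL estimate
    per_token = -beta * approx_kl                    # KL penalty applied at every token
    per_token[:, -1] += task_reward                  # sequence-level reward added at the last token
    return per_token
```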
Trust-Region Policy Optimization (PPO): The PPO algorithm used in many RLHF implementations inherently includes stability measures. PPO limits how much the policy can change in a single update via a clipped objective – it’s a trust-region method that says “don’t move too far from the current policy.” This helps avoid instability because it ensures each policy update is a small adjustment, rather than a wild jump . In effect, PPO with a KL penalty means the policy will evolve slowly and smoothly, which is crucial when fine-tuning large language models with RL. Other algorithms like DeepMind’s use of A2C for Gopher’s alignment also report needing careful constraints to maintain stability .
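PPO’s clipped surrogate objective is compact enough to show directly (a generic sketch of the standard formulation, not any particular RLHF library’s implementation):

```python
import torch

def ppo_policy_loss(logprobs_new, logprobs_old, advantages, clip_eps=0.2):
    ratio = torch.exp(logprobs_new - logprobs_old)    # how far the policy has moved on these tokens
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()      # pessimistic (clipped) surrogate objective
```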
Reward Model Accuracy vs Stability: Interestingly, very “sharp” reward models (that give extremely high rewards to very specific behaviors) can induce instability. If the reward model is too stringent or overfit, the policy might chase some narrow trajectory and collapse diversity. Recent analyses (e.g. OpenAI’s InstructGPT paper) found that mixing in some supervised learning signal or using a moderate reward model (not over-optimized) led to more stable training. In essence, some entropy or exploration helps avoid mode-collapse of the policy.
Curriculum in RLHF: Just as in supervised training, one can apply a curriculum to the RL process. For instance, starting with a smaller KL penalty (allowing more exploration) and then tightening it gradually can let the policy learn faster early on but then rein it in so that final answers stay reasonable. This must be tuned carefully to avoid instability early on, though.
In summary, stabilizing RL-based training of LLMs hinges on not straying too far from known-good behavior. Penalties like KL keep the model’s updates grounded (Illustrating Reinforcement Learning from Human Feedback (RLHF)), and algorithms like PPO inherently cap the update size . These, combined with careful reward model design, have enabled large-scale RLHF (e.g. training GPT-4 with human feedback) to succeed without the model collapsing or diverging. Without these measures, an RL-trained LLM could easily produce degenerate outputs to game the reward function, so these techniques are now standard practice in the field.
Framework-Specific Practices (PyTorch vs JAX vs TensorFlow)
All major deep learning frameworks provide tools to mitigate training instability, but there are some differences in defaults and best practices:
PyTorch: Researchers using PyTorch often leverage automatic mixed precision (AMP) with dynamic loss scaling to prevent FP16 underflow/overflow. PyTorch’s AMP has multiple optimization levels (O0 = FP32, O1 = conservative mixed precision, O2 = nearly full FP16). Empirically, Meta/Facebook found that AMP level O2 with BFloat16, or keeping LayerNorm in FP32, achieves the best throughput without sacrificing stability (Scaling Vision Model Training Platforms with PyTorch | PyTorch). In fact, they note full FP16 training (everything in half precision) was challenging to converge due to numeric issues, but using a mixed mode where LayerNorm (and other sensitive ops) run in FP32 avoided those instabilities (Scaling Vision Model Training Platforms with PyTorch | PyTorch). PyTorch also offers Fully Sharded Data Parallel (FSDP) for distributed training, which has internal checks for NaNs and can automatically free/reload weights to handle huge models. The PyTorch team’s blogs emphasize careful sharding and accumulation so as not to blow past memory limits (which can indirectly cause instabilities if the OS starts swapping, etc.) (Scaling Vision Model Training Platforms with PyTorch | PyTorch). Gradient clipping in PyTorch is manual (the trainer must call the function), but many high-level libraries (e.g. PyTorch Lightning, Hugging Face Trainer) integrate it as a one-line setting.
TensorFlow (Keras): TensorFlow’s tf.keras.mixed_precision API makes it easy to train in mixed precision, and by default it applies dynamic loss scaling to combat underflow. This means it will automatically increase the loss scale until gradients are in a good range and decrease it if infinities are encountered (PyTorchのAMPはbf16を使え.多分nanが出なくなる. #AutomaticMixedPrecision - Qiita). This mostly hides the manual work, but as some user reports note, it’s not foolproof – extreme cases might still produce NaNs (Mixed Precision Questions - Faceswap Forum). In such cases, the TensorFlow documentation suggests trying bfloat16 (on TPUs) or reducing the global learning rate. TensorFlow also tends to favor bfloat16 on TPUs. Bfloat16 has a larger exponent range (same as FP32), so it dramatically reduces overflow issues. Google’s TPU-based LLM trainings (T5, PaLM, etc.) predominantly use bfloat16 for this reason – it gives the speed/memory benefits of 16-bit precision with far better numerical stability (often no need for loss scaling at all). Many of the instability problems seen with FP16 on GPUs disappear or lessen with bfloat16. With GPUs now supporting bfloat16, frameworks like PyTorch and JAX also allow its use there. In short, TensorFlow’s guidance is usually: use mixed precision + loss scaling, prefer bfloat16 if available for stability, and clip gradients if needed (TensorFlow optimizers have clipping parameters, or one can manually clip in the training loop).
JAX/Flax: JAX excels on TPUs, so bfloat16 is the norm in that ecosystem. Projects in the JAX ecosystem have reported very few stability issues when using bfloat16 – EleutherAI’s GPT-J (6B), for example, was trained with JAX on TPUs in bfloat16, and the BigScience team likewise chose bfloat16 for the 176B-parameter BLOOM model (trained in PyTorch on GPUs) after earlier FP16 experiments ran into frequent instabilities. JAX’s functional paradigm means the user controls things like update ordering explicitly, so one must be careful to e.g. apply gradient clipping functionally before the update (see the optax sketch below). JAX also has a bit less black-box automation – e.g., you explicitly decide on a loss-scaling implementation if using FP16 on GPUs. However, its compilation to XLA can sometimes perform operations in higher precision internally. One notable difference: some JAX users have observed that enabling deterministic algorithms (to debug reproducibility) can avoid certain races or non-deterministic summation orders that might lead to one replica diverging. Generally, though, the same principles apply: use adaptive optimizers, clipping, and proper init.
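For example, with optax (the optimizer library commonly used with JAX/Flax), clipping is composed explicitly into the update transform; the hyperparameters below are placeholders:

```python
import optax

# Gradient clipping is chained in front of the optimizer update in JAX/optax.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),                  # clip *before* the AdamW update
    optax.adamw(learning_rate=3e-4, weight_decay=0.1),
)
# Inside the (typically jit-compiled) train step:
#   updates, opt_state = tx.update(grads, opt_state, params)
#   params = optax.apply_updates(params, updates)
```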
In summary, PyTorch and TensorFlow have largely converged on mixed-precision training as the standard, with built-in measures to maintain stability (dynamic loss scaling, etc.). PyTorch gives a bit more manual control (and responsibility) to the researcher, whereas TensorFlow’s Keras API automates more of it. JAX, living in the TPU world, leans on bfloat16 which sidesteps many FP16 pitfalls (PyTorchのAMPはbf16を使え.多分nanが出なくなる. #AutomaticMixedPrecision - Qiita). All frameworks provide solutions for distributed training (DDP, Strategy, pjit, etc.) that, if used correctly, do not sacrifice stability versus single-GPU training. The key is that regardless of framework, one should abide by the best practices: enable mixed precision wisely, watch out for any NaNs/infs (and react by lowering LR or switching precision), and use the clipping and regularization tools available.
Hardware-Induced Instability and Distributed Training Challenges
Mixed-Precision Arithmetic: Modern accelerators often use 16-bit floats to speed up training. As discussed, FP16 (half precision) is prone to overflow/underflow, which can introduce instability. Overflow typically shows up as Inf or NaN values in activations or gradients, which, if unchecked, will propagate and break training. Underflow (very small values flushed to zero) can silently dampen gradients. The solution is twofold: (1) use loss scaling – multiply the loss (and hence gradients) by a scale factor to keep them in representable range, then divide the updates by the scale. This is done dynamically to adapt to the range of gradients (PyTorchのAMPはbf16を使え.多分nanが出なくなる. #AutomaticMixedPrecision - Qiita). (2) Use BFloat16 where possible – BF16 has 8 exponent bits (like FP32) vs FP16’s 5 bits, so it can represent extremely large and small values without overflow in most training scenarios . BF16 essentially eliminates the Inf/NaN overflow issue in practice, at the cost of slightly lower precision in mantissa. NVIDIA’s newer GPUs support BF16 in PyTorch and TensorFlow; enabling it often improves stability notably (many report “no more NaNs” when switching from FP16 to BF16).
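Both remedies are straightforward to wire up in a PyTorch-style loop; the sketch below assumes `model`, `optimizer`, and `loader` exist and simply switches between the BF16 path (no scaler needed) and the FP16 path with a dynamic GradScaler:

```python
import torch

use_bf16 = torch.cuda.is_bf16_supported()                 # prefer BF16 when the GPU supports it
scaler = torch.cuda.amp.GradScaler(enabled=not use_bf16)  # dynamic loss scaling only needed for FP16

for batch in loader:                                       # `model`, `optimizer`, `loader` assumed
    optimizer.zero_grad(set_to_none=True)
    dtype = torch.bfloat16 if use_bf16 else torch.float16
    with torch.autocast(device_type="cuda", dtype=dtype):
        loss = model(**batch).loss
    scaler.scale(loss).backward()                          # no-op scaling when disabled (BF16 path)
    scaler.unscale_(optimizer)                             # unscale so clipping sees true gradient norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)                                 # skips the step if grads contain inf/NaN
    scaler.update()                                        # grows/shrinks the loss scale dynamically
```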
Even with loss scaling, one must monitor for any NaN occurrence. Typically, if a NaN gradient is detected, frameworks will either stop or skip that batch. It’s good to log these events – a single NaN infrequently might be recoverable, but frequent ones indicate a systemic issue (too high LR or unstable configuration).
Distributed Training Synchronicity: In large-scale training, the model is split across many devices. Almost always this is done with synchronous updates (all gradients are averaged each step). Synchronous data parallelism is stable as long as every worker sees a similar data distribution; if one worker encounters a problematic batch that causes a spike, the average gradient will still be affected. There’s no easy fix there except techniques already discussed (clip the gradient on that worker, etc.). Asynchronous training (where workers update a parameter server at different times) is less common for LLMs now, because while it can improve hardware utilization, it tends to be harder to converge. Asynchronous updates introduce stale gradients which can act like a too-large effective learning rate. Research in the 2010s on deep nets found that async SGD could diverge if the learning rate isn’t low enough to account for staleness. Thus, most LLM pipelines stick to synchronous all-reduce, which “feels” like single-GPU SGD (just with a larger batch). This is more predictable and empirically more stable (Data Parallelism in Machine Learning Training | by Soonmo Seong).
All-Reduce and Numerical Sum Order: A subtle hardware issue: when summing gradients from, say, 16 GPUs, the order of summation can slightly change the floating-point result (due to finite precision). In theory, this could cause tiny differences between runs that break bitwise reproducibility. In practice, it shouldn’t cause divergence (just minuscule noise), but for paranoia, frameworks can enforce deterministic reductions (at some performance cost). Some HPC teams ensure the reduction tree is balanced to minimize rounding error.
Memory and Compute Errors: Occasionally, non-deterministic hardware errors (e.g. bit flips from cosmic rays or unstable overclocking) could cause a weight to become NaN. These are rare, but in ultra-long training (trillions of operations) it’s not impossible. ECC memory and stability testing mitigate this. It’s worth noting if one GPU in a cluster has a faulty module, it could repeatedly cause errors that manifest as training instability – so hardware health is a consideration.
Distributed Optimizer States: In sharded training, optimizer states (like momentum) might be split across processes. It’s important that updates to those states are applied consistently. Tools like DeepSpeed and FairScale handle this. A bug in state synchronization could definitely cause one shard’s weights to diverge from the rest. Thus, correctness in distributed optimizer implementation is critical (and has improved greatly in recent releases of these libraries).
Gradient Accumulation vs Large Batch: An alternative to scaling the batch with more GPUs is gradient accumulation on fewer GPUs (simulate a large batch by accumulating gradients over multiple forward passes before updating). This is generally stable, as it is mathematically equivalent to a larger batch (just split over time). However, one must then be cautious with learning rate schedules (adjusting for the effective batch size) and ensure no drift occurs between accumulation steps (e.g. if using mixed precision, one should update the loss scale across the whole accumulated batch, not per micro-batch).
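A minimal accumulation loop looks like this (assuming `model`, `optimizer`, and `loader`; the accumulation factor is illustrative):

```python
import torch

ACCUM_STEPS = 16                                           # effective batch = micro-batch size × ACCUM_STEPS
optimizer.zero_grad(set_to_none=True)
for step, batch in enumerate(loader):                      # `model`, `optimizer`, `loader` assumed
    loss = model(**batch).loss / ACCUM_STEPS               # average the loss over the accumulation window
    loss.backward()                                        # gradients accumulate in .grad across micro-batches
    if (step + 1) % ACCUM_STEPS == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```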
In summary, hardware and distribution bring potential instability mostly through numerical precision issues and synchronization complexity. By using appropriate precision formats (BF16 where possible, or FP16 with robust loss scaling) and sticking to synchronous training, these issues can be managed. Industry practitioners (e.g. NVIDIA’s guides) strongly recommend monitoring for overflow/underflow warnings and using provided utilities to combat them (PyTorchのAMPはbf16を使え.多分nanが出なくなる. #AutomaticMixedPrecision - Qiita). When everything is set up correctly, one can reliably train multi-billion-parameter models on hundreds of GPUs without divergence – something that was a major challenge just a few years ago.
Recent Research and Industry Insights
Stabilizing LLM training is an active area of research, with 2024-2025 seeing several notable advances. We have already touched on some (SPAM optimizer, Rybakov’s LayerNorm tweaks, sequence length warmup). To contextualize: much of this research is driven by the enormous cost of unstable runs. A single divergence in a 100B+ model run could mean weeks of lost time, so both academia and industry are highly motivated to find stability tricks.
Some highlights include:
“Spike No More” (Takase et al., 2024): A theoretical analysis of loss spikes in transformer pre-training (HERE) . They derive conditions for stability – specifically that sub-layers should be “small” (in parameter norm) and residual connections “large” (dominant) . These conditions echo what DeepNorm and good init practices achieve. Their work helped explain why techniques like scaling embedding layers and using small initializations were empirically successful. Notably, they demonstrate that methods satisfying these conditions allow training with larger learning rates and lead to better outcomes .
Normalization Innovations: Aside from DeepNorm and Mix-LN, another 2024 idea is BranchNorm (Zhang et al., 2024) which normalizes across residual branches to ensure balanced contributions. While details get complex, the trend is exploring new normalization schemes to enable deeper or more aggressive training without divergence.
Industry Scale Experiments: Meta AI’s LLaMA models and Google’s PaLM both provided insight into stability at extreme scale. PaLM (540B) required a few manual restarts due to loss spikes; the team found that restarting from a recent checkpoint and skipping a few hundred batches around the spike avoided a recurrence, suggesting the spikes arose from unlucky combinations of optimizer state and particular data rather than any single “bad” example. Even so, such experiences have pushed companies to pre-screen training data to remove pathologically difficult samples or outliers that could trigger instabilities. On the software side, Google’s TensorFlow team standardized on BF16 for TPUs given its clear stability edge, and Meta’s PyTorch-based pipeline for LLaMA-65B combined sharded, model-parallel training with gradient checkpointing to manage memory – gradient checkpointing (recomputing activations instead of storing them) can introduce slight numerical differences, but it did not harm stability in practice.
Official Framework Blogs: The PyTorch team’s blogs on scaling models highlight numerical stability as a key challenge, noting that instability becomes “severe and hard to deal with when we scale up model sizes, data, batch sizes, learning rate, etc.” (Scaling Vision Model Training Platforms with PyTorch | PyTorch). They had difficulty training even a 630M-parameter ViT in FP16 without extra augmentation, underscoring how much trickier stability becomes at scale (Scaling Vision Model Training Platforms with PyTorch | PyTorch). Their solution involved using bfloat16 and ensuring certain layers (like LayerNorm) remained in high precision (Scaling Vision Model Training Platforms with PyTorch | PyTorch). The TensorFlow team, via the Cloud TPU blogs, heavily promotes mixed precision with BF16, demonstrating cases where models trained faster and no loss spikes occurred once BF16 was used instead of FP16. NVIDIA’s research blogs on large-batch training emphasize using the LAMB optimizer to avoid instabilities at high batch sizes (Large Batch Optimization for Deep Learning: Training BERT in 76 minutes). NVIDIA also released the Megatron-LM library, which incorporates many of these best practices (e.g., sophisticated LR schedules, gradient clipping, and distributed training tricks) so that researchers can train multi-billion parameter models more reliably.
Community Knowledge: It’s worth noting that a lot of “tricks” get spread in the community (through blogs, forums, repos) outside of formal papers. For example, the Hugging Face forum and Reddit have countless discussions on issues like “my 6B model training loss exploded, what to do?” and the answers often compile the wisdom we’ve described: reduce your LR, use grad clipping, try BF16, check for NaNs, etc. Over time, these have coalesced into default settings in libraries (many Hugging Face Trainer defaults include LR warmup, grad clipping at 1.0, etc., precisely to protect against instability).
In conclusion, the landscape of LLM training stability is one of continuous improvement. By combining theoretical insights from recent papers with the hard-earned lessons of industry-scale training, the community has made training runs far more robust than before. Techniques like learning rate warmups, careful normalization, and optimizer tuning are now standard arsenal. As models and data scales push ever higher (trillion-parameter models, multi-trillion token datasets), ensuring stability will remain a moving target – but the work surveyed here provides a strong foundation, and new research in 2024–2025 is actively extending it. With these advances, we can train larger and better LLMs with fewer costly interruptions due to divergence (HERE), making the pursuit of ever more capable models more feasible and efficient than ever before.