Optimize Your Training Like You’re Down Bad

Introduction

If you’re serious about Stable Diffusion LoRA training, you’ve probably experienced the frustration of long training times, high VRAM usage, and limited batch sizes. This guide covers specialized optimizers and full BF16 precision training techniques that can significantly speed up your training while reducing memory consumption.

Understanding BF16 Precision

BFloat16 (BF16) is a 16-bit floating-point format that keeps FP32’s 8 exponent bits but trims the mantissa to 7 bits (FP32 has 23). This gives it several advantages for deep learning:

  • Memory efficiency: Half the size of FP32, allowing larger batch sizes
  • Training speed: Faster computation with minimal precision loss
  • Dynamic range: Maintains the same range as FP32, unlike FP16
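
You can verify the size and range claims directly in PyTorch; the snippet below is a minimal check that assumes nothing beyond a standard torch install.

# Comparing the dtypes PyTorch exposes
import torch

print(torch.tensor(0.0, dtype=torch.bfloat16).element_size())  # 2 bytes per value
print(torch.tensor(0.0, dtype=torch.float32).element_size())   # 4 bytes per value

# BF16 keeps FP32's 8 exponent bits, so its largest finite value is on the
# same order as FP32 (~3.4e38), while FP16 tops out at 65504.
print(torch.finfo(torch.bfloat16).max)
print(torch.finfo(torch.float32).max)
print(torch.finfo(torch.float16).max)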

However, BF16 has a potential issue: stale gradients. With only 7 mantissa bits, the gap between adjacent representable values is relatively large, so very small weight updates can round away entirely and never register during training. This is where specialized optimizers come in.
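
Before looking at those optimizers, here is a tiny illustration of the problem using nothing but plain PyTorch: near 1.0, adjacent BF16 values are about 0.0078 apart, so a typical small update simply rounds away.

# Stale update demo: the small step disappears in BF16 but survives in FP32
import torch

w_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)
w_fp32 = torch.tensor(1.0, dtype=torch.float32)
update = 1e-4  # a plausible learning_rate * gradient product

print(w_bf16 + update)  # tensor(1., dtype=torch.bfloat16) -- the update is lost
print(w_fp32 + update)  # tensor(1.0001) -- FP32 keeps it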

Specialized Optimizers for BF16 Training

SPARKLES Optimizer

SPARKLES (Stochastic Parameter Adjustment with Randomized Kick for Learning Enhancement Strategy) combines multiple advanced techniques for improved neural network training. SPARKLES builds upon the foundation laid by Lodestone’s Compass Optimizer¹, extending it with stochastic elements and adaptive mechanisms.

Key features:

  • Adaptive normalization: Normalizes gradients using their standard deviation
  • Stochastic update strategy: Helps escape local minima with randomized operations
  • Stochastic BF16 rounding: Enhances precision conversion with controlled randomness (sketched after the example config below)
# Example configuration in a training script
--optimizer_type=SPARKLES
--optimizer_args="centralization=0.5 normalization=0.5 amp_fac=2 stochastic_threshold=1e-6"
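
The stochastic rounding part deserves a closer look, since it is also what Torchastic (below) relies on: instead of always rounding to the nearest BF16 value, the discarded low bits decide the rounding direction at random, so tiny updates survive on average. The following is a minimal sketch of the general technique in plain PyTorch, not SPARKLES’ actual implementation, and it ignores edge cases such as infinities and NaNs.

# Stochastic FP32 -> BF16 rounding: add uniform noise to the 16 low bits that
# truncation would discard, so a value rounds up with probability proportional
# to how close it already is to the next representable BF16 number.
import torch

def stochastic_round_to_bf16(x_fp32: torch.Tensor) -> torch.Tensor:
    bits = x_fp32.view(torch.int32)               # reinterpret the raw FP32 bits
    noise = torch.randint_like(bits, 0, 1 << 16)  # uniform in [0, 2^16)
    rounded = (bits + noise) & ~0xFFFF            # carry decides round-up vs. round-down
    return rounded.view(torch.float32).to(torch.bfloat16)

# Averaged over many values, the tiny update from the earlier example is preserved
w = torch.full((10_000,), 1.0) + 1e-4
print(stochastic_round_to_bf16(w).float().mean())  # close to 1.0001, not 1.0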

Torchastic

Torchastic is a stochastic BFloat16-based optimizer library that directly addresses the stale gradient problem in BF16 training.

Key features:

  • Stochastic rounding: Ensures small updates don’t get lost during long training
  • BF16-native: Designed specifically for BF16 precision, reducing memory usage by 50%
  • Gradient accumulation hooks: Maintains precision during gradient accumulation
# Example usage with Torchastic
import torch
import torch.nn as nn
from torchastic import AdamW, StochasticAccumulator

# Init model -- Model and model_args stand in for your own nn.Module and its arguments
model = Model(*model_args)
model.to(torch.bfloat16)

# Torchastic's AdamW is BF16-native and applies stochastic rounding to its updates
optimizer = AdamW(model.parameters(), lr=0.01, weight_decay=1e-2)

# Apply stochastic grad accumulator hooks so small gradient contributions
# aren't lost during gradient accumulation
StochasticAccumulator.assign_hooks(model)
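
For context, a training loop around that setup might look like the sketch below. The names loader, loss_fn, and accum_steps are placeholders, and depending on the Torchastic version you may need an extra call to flush the accumulated gradients into .grad before stepping; check the library’s README for the exact gradient-accumulation API.

# Hypothetical loop (placeholder names): the hooks registered above keep
# accumulated gradients from losing small contributions between steps.
for step, (inputs, targets) in enumerate(loader):
    loss = loss_fn(model(inputs.to(torch.bfloat16)), targets)
    loss.backward()
    if (step + 1) % accum_steps == 0:
        # See the Torchastic docs: a gradient buffer flush call may belong here.
        optimizer.step()
        optimizer.zero_grad()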

¹ Lodestone’s Compass Optimizer is a modification of AdamW with gradient centralization, adaptive step sizing, and momentum amplification features. SPARKLES extends these concepts with additional stochastic mechanisms and some other weirdness.

Enabling Full BF16 Training in Kohya’s sd-scripts

Kohya’s sd-scripts supports mixed precision training, but for maximum speed and memory efficiency, you can enable full BF16 training with these patches:

Patch 1: Enable BF16 for everything (Preliminary)

This patch enables BF16 precision for all operations including gradient computation.

Patch 2: Full BF16 Optimization

This patch completes the BF16 implementation by fixing all the stupid mistakes I farted in with Claude, ensuring all aspects of training use BF16 precision.

Application Instructions

  1. Clone Kohya’s sd-scripts:

    git clone https://github.com/kohya-ss/sd-scripts.git
    cd sd-scripts
    
  2. Apply the patches:

    git remote add wizard https://github.com/rakki194/sd-scripts.git
    git fetch wizard
    git cherry-pick f56e7d856ee03fbbc0892e483cc6d38f7190659f
    git cherry-pick 7de84161bd2e82da1e5de8f445713a0921ea819a
    
  3. Use in training with the --full_bf16 parameter:

    python train_network.py --full_bf16 [other parameters]
    

Benchmarks and Comparison

If you don’t believe fucking Meta, why should I even try?

Hardware Requirements

For optimal BF16 training performance:

  • GPU: NVIDIA RTX 30 series or newer (Ampere architecture or later)
  • CUDA: Version 11.0 or later
  • NCCL: Version 2.10 or later (for distributed training)
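
Before committing to a long run, it is worth confirming that your setup actually meets these requirements; the quick check below uses only standard PyTorch calls.

# Quick BF16 capability check
import torch

print(torch.cuda.is_available())        # need a CUDA GPU at all
print(torch.cuda.get_device_name(0))    # should report an RTX 30xx or newer
print(torch.cuda.is_bf16_supported())   # True on Ampere (compute capability 8.0) and later
print(torch.version.cuda)               # 11.0 or later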

Recommended Configuration

For the best balance of speed, efficiency, and quality:

# Optimizer config
--optimizer_type=SPARKLES
--train_batch_size=14
--max_grad_norm=1
--gradient_checkpointing

# BF16 Precision optimization
--torch_compile
--dynamo_backend=inductor
--no_half_vae
--sdpa
--mixed_precision="bf16"
--full_bf16

# Additional memory optimizations
--cache_latents
--cache_latents_to_disk

Conclusion

Optimizing your Stable Diffusion LoRA training with specialized BF16 optimizers can dramatically reduce memory usage and training time. While this approach is more advanced and requires newer hardware, the benefits in training efficiency make it worth considering for serious LoRA creators.

Remember that these techniques are experimental and may not be suitable for all training scenarios. Always monitor your training progress carefully and be prepared to adjust parameters as needed.