Add Custom Optimizers to sd-scripts for LoRA Training

Introduction

Optimizers play a crucial role in training Stable Diffusion LoRA models effectively. While standard optimizers like AdamW work well in many scenarios, custom optimizers can significantly improve training convergence, reduce training time, and enhance the quality of your LoRA models.

The sd-scripts library by Kohya, one of the most popular toolkits for LoRA training, supports several optimization algorithms out of the box. However, you may want to experiment with newer or more specialized optimizers that aren’t included by default. This guide will walk you through the process of adding custom optimizers to sd-scripts.

Why Use Custom Optimizers?

Different optimizers offer various benefits for LoRA training:

  • Faster convergence: Some optimizers like Prodigy or Lion can reduce training time
  • Better quality: Specialized optimizers may produce higher quality results for specific types of LoRA models
  • Reduced overfitting: Certain optimizers apply regularization techniques that help prevent overfitting
  • Adaptive learning rates: Optimizers like Adafactor dynamically adjust learning rates, reducing the need for manual tuning

Setting Up the Custom Optimizers Directory

First, create a new optimizers folder inside the library directory of your sd-scripts checkout, along with an empty __init__.py file so that Python treats the folder as a package:

Linux/Mac:

mkdir library/optimizers
touch library/optimizers/__init__.py

Windows PowerShell:

mkdir -Force library/optimizers
New-Item -Path "library/optimizers/__init__.py" -ItemType File -Force
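
To verify that the package is importable, you can run a quick check from the root of the sd-scripts repository (a minimal sanity check, assuming the commands above succeeded):

python -c "import library.optimizers; print('optimizers package found')"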

Implementing a Custom Optimizer

You can implement any optimizer you want in this folder. As an example, let's create library/optimizers/compass.py with an implementation of the Compass optimizer:

import torch
from torch.optim import Optimizer


class Compass(Optimizer):
    r"""
    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float):
            Learning rate parameter (default 0.0025)
        betas (Tuple[float, float], optional):
            coefficients used for computing running averages of
            gradient and its square (default: (0.9, 0.999)).
        amp_fac (float):
            amplification factor for the first moment filter (default: 2).
        eps (float):
            Term added to the denominator outside of the root operation to
            improve numerical stability. (default: 1e-8).
        weight_decay (float):
            Weight decay, i.e. a L2 penalty (default: 0).
        centralization (float):
            center model grad (default: 0).
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        amp_fac=2,
        eps=1e-8,
        weight_decay=0,
        centralization=0,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            amp_fac=amp_fac,
            eps=eps,
            weight_decay=weight_decay,
            centralization=centralization,
        )
        super(Compass, self).__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Compass does not support sparse gradients")

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["ema"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["ema_squared"] = torch.zeros_like(p.data)

                ema, ema_squared = state["ema"], state["ema_squared"]
                beta1, beta2 = group["betas"]
                amplification_factor = group["amp_fac"]
                lr = group["lr"]
                weight_decay = group["weight_decay"]
                centralization = group["centralization"]
                state["step"] += 1

                # center the gradient vector
                if centralization != 0:
                    grad.sub_(
                        grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True).mul_(
                            centralization
                        )
                    )

                # bias correction step size
                # soft warmup
                bias_correction = 1 - beta1 ** state["step"]
                bias_correction_sqrt = (1 - beta2 ** state["step"]) ** (1 / 2)
                step_size = lr / bias_correction

                # Decay the first and second moment running averages
                # ema = beta1 * ema + (1 - beta1) * grad
                ema.mul_(beta1).add_(grad, alpha=1 - beta1)
                # amplify the smoothed gradient: grad = grad + amplification_factor * ema
                grad.add_(ema, alpha=amplification_factor)
                # ema_squared = beta2 * ema_squared + (1 - beta2) * grad ** 2
                ema_squared.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # lr scaler + eps to prevent zero division
                # denom = exp_avg_sq.sqrt() + group['eps']
                denom = (ema_squared.sqrt() / bias_correction_sqrt).add_(group["eps"])

                if weight_decay != 0:
                    # Perform decoupled weight decay, scaled by the bias-corrected step size
                    p.data.mul_(1 - step_size * weight_decay)

                # p = p - lr * grad / denom
                p.data.addcdiv_(grad, denom, value=-step_size)

        return loss
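
Before wiring the optimizer into sd-scripts, it can be useful to sanity-check it in isolation. The following is a minimal sketch, assuming the code above was saved as library/optimizers/compass.py and is run from the sd-scripts root; it fits a tiny linear model for a few hundred steps just to confirm that step() runs and the loss decreases:

import torch
from library.optimizers.compass import Compass

# Tiny synthetic regression problem: y = 3x + noise
x = torch.randn(64, 1)
y = 3 * x + 0.1 * torch.randn(64, 1)

model = torch.nn.Linear(1, 1)
optimizer = Compass(model.parameters(), lr=1e-2)

for _ in range(500):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

print(f"final loss: {loss.item():.4f}")  # should be far below the initial loss
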

Integrating the Custom Optimizer with SD-Scripts

Next, modify library/train_util.py so that sd-scripts knows about the new optimizer. In the get_optimizer function, find the chain of elif optimizer_type == ... branches where the built-in optimizers are selected, and add a branch for your custom optimizer (the existing AdamW branch is shown here for context):

    elif optimizer_type == "AdamW".lower():
        logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
        optimizer_class = torch.optim.AdamW
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    elif optimizer_type == "LodeW".lower():
        logger.info(f"use LodeW optimizer | {optimizer_kwargs}")
        try:
            from library.optimizers.compass import Compass

            optimizer_class = Compass
        except ImportError:
            raise ImportError(
                "Importing Compass failed / インポート Compass が失敗しました。"
            )
        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

    if optimizer is None:
        # 任意のoptimizerを使う (fall back to loading an arbitrary optimizer by name)

Using Your Custom Optimizer in Training

Now you can select the new optimizer in your LoRA training command with the --optimizer_type parameter. The value must match the string you checked for in train_util.py ("LodeW" in this example); it does not have to match the optimizer class name, so you are free to rename it to something like "Compass" in both places:

--optimizer_type=LodeW
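
For context, here is roughly how the flag fits into a full training invocation. This is only an illustrative sketch: the paths are placeholders, most dataset and network arguments are omitted, and the --optimizer_args values are just examples of keyword arguments forwarded to the Compass constructor:

accelerate launch train_network.py \
  --pretrained_model_name_or_path="/path/to/base_model.safetensors" \
  --train_data_dir="/path/to/dataset" \
  --output_dir="/path/to/output" \
  --network_module=networks.lora \
  --learning_rate=1e-3 \
  --optimizer_type=LodeW \
  --optimizer_args "weight_decay=0.01" "centralization=0"

Any key=value pairs passed via --optimizer_args are collected into optimizer_kwargs and forwarded to the optimizer's __init__ (as shown by **optimizer_kwargs in the snippet above), so the parameters in Compass's signature can be tuned from the command line.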

Conclusion

Adding custom optimizers to sd-scripts lets you experiment with optimization algorithms that may improve your LoRA training results. The Compass optimizer shown in this guide is just one example: you can drop in other published optimizers, or write your own implementation tailored to your specific use case.

By taking advantage of custom optimizers, you can potentially achieve better quality LoRAs with shorter training times, making your Stable Diffusion workflow more efficient and productive.