62 changes: 59 additions & 3 deletions flux_train.py
@@ -330,6 +330,8 @@ def train(args):
# prepare the classes needed for training
accelerator.print("prepare optimizer, data loader etc.")

fused_optimizers_supported = ['adafactor', 'adamoffload', 'nadamoffload', 'adamwoffload', 'nadamwoffload', 'adanoffload']

if args.blockwise_fused_optimizers:
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
@@ -381,10 +383,25 @@ def train(args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function

if (args.optimizer_type.lower() in fused_optimizers_supported) and args.full_bf16:
logger.warning("Use of --blockwise_fused_optimizers prevents stochastic/Kahan weight updates.")
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)

# Pass any Kahan summation arg to the optimizer
if args.kahan_summation:
# Self check parameter compatibility
if args.optimizer_type.lower() not in fused_optimizers_supported:
logger.warning("Kahan summation has been requested, but this is not supported by the selected optimizer.")
if not args.full_bf16:
logger.warning("Kahan summation requires --full_bf16")
if args.blockwise_fused_optimizers:
logger.warning("Kahan summation has been requested, but these are not compatible with --blockwise_fused_optimizer. "\
"Perhaps try --fused_backward_pass instead.")
optimizer.use_kahan_summation = args.kahan_summation

# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
@@ -437,6 +454,28 @@ def train(args):
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
flux.to(weight_dtype)

# Experimental: some layers have very few weights, and training quality seems
# to increase significantly if these are left in f32 format while training.
if args.fused_backward_pass:

from library.flux_models import MixedLinear
from library.flux_models import RMSNorm

flux.final_layer.linear.to(dtype=torch.float32)
flux.img_in .to(dtype=torch.float32)

for m in flux.modules():
num_params = sum(p.numel() for p in m.parameters())

if isinstance(m, MixedLinear) and m.bias is not None:
m.bias.data = m.bias.data.to(torch.float32)
if m.weight.data.numel() < 20000000: # Includes first Linear stage with 18m weights
m.weight.data = m.weight.data.to(torch.float32)

if isinstance(m, RMSNorm):
m.scale.data = m.scale.data.to(torch.float32)

if clip_l is not None:
clip_l.to(weight_dtype)
t5xxl.to(weight_dtype)
@@ -474,10 +513,21 @@ def train(args):
train_util.resume_from_local_or_hf_if_specified(accelerator, args)

if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
# use fused optimizer for backward pass. Only some specific optimizers are supported.
import library.adafactor_fused

library.adafactor_fused.patch_adafactor_fused(optimizer)
import library.adamw_fused
import library.adan_fused

if args.optimizer_type.lower() == "adafactor":
library.adafactor_fused.patch_adafactor_fused(optimizer)
elif args.optimizer_type.lower() == "adamoffload" or args.optimizer_type.lower() == "adamwoffload":
library.adamw_fused.patch_adamw_offload_fused(optimizer, False)
elif args.optimizer_type.lower() == "nadamoffload" or args.optimizer_type.lower() == "nadamwoffload":
library.adamw_fused.patch_adamw_offload_fused(optimizer, True) # Nesterov
elif args.optimizer_type.lower() == "adanoffload":
library.adan_fused.patch_adan_offload_fused(optimizer, False) # Adan
else:
logger.error(f"Optimizer '{args.optimizer_type}' does not have a --fused_backward_pass implementation available")

for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
@@ -816,6 +866,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
)
parser.add_argument(
"--kahan_summation",
action="store_true",
help="Offloads to CPU the float part lost during bf16 quantization, and re-adds it to the next step / "\
"bf16 量子化中に失われた浮動小数点部分を CPU にオフロードし、次のステップに再度追加します",
)
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
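The per-parameter grad-hook registration that drives these patched step_param() calls continues past the end of the visible hunk. As a minimal sketch of that pattern, with illustrative names rather than code taken from this diff:

import torch

def register_fused_hooks(optimizer):
    # Hypothetical sketch: run the patched per-parameter step as soon as each
    # gradient is accumulated, then drop the gradient to keep VRAM usage flat.
    for group in optimizer.param_groups:
        for param in group["params"]:
            if not param.requires_grad:
                continue

            def make_hook(group):
                def hook(tensor: torch.Tensor):
                    optimizer.step_param(tensor, group)  # installed by patch_*_fused()
                    tensor.grad = None
                return hook

            param.register_post_accumulate_grad_hook(make_hook(group))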
66 changes: 64 additions & 2 deletions library/adafactor_fused.py
@@ -28,6 +28,62 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
del result


# Kahan summation for bfloat16
# The implementation was provided by araleza.
# Based on the paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192

def copy_kahan_(target: torch.Tensor, source: torch.Tensor, state, update):
"""
Copies source into target using Kahan summation.

The lower bits of the float32 weight that are lost on conversion to bfloat16
are sent to the CPU until the next step, where they are re-added onto the weights
before adding the gradient update. This produces near-float32 weight behavior,
although the copies back and forth to main memory result in slower training steps.

Args:
target: the target tensor with dtype=bfloat16
source: the source tensor with dtype=float32
state: the optimizer state, used to store the Kahan residuals
update: the change in weights due to the gradient
"""

# Initialize residuals to 0 for first step
if state.get('kahan_residuals') is None:
state['kahan_residuals'] = torch.zeros_like(source, dtype=torch.int16)

# Need this in 32 bit as PyTorch doesn't support mixed 32-bit and 16-bit math operations
state['kahan_residuals'] = state['kahan_residuals'].to(source.device).to(dtype=torch.int32)

# Bring the previous step's lower bits of the weights back from the
# cpu device, and add them back to the weights of the current step.
source_i32 = source.view(dtype=torch.int32) # Can't do math on uint32
source_i32.add_(state['kahan_residuals'])

# Reverse any rounding up during the cast to bf16 on the previous step
rounded_up = state['kahan_residuals'] >= 32768
source_i32[rounded_up] -= 65536

# Must add the gradient update after the bottom bits are restored in case
# the exponent is changed by the update, or the -65536 on the line above
# would drop the uint32 value below zero, which is invalid.
source.add_(-update)

# Get the lower bits into the residual
torch.bitwise_and(source_i32, 0x0000FFFF, out=state['kahan_residuals'])

# Ensure rounding to bfloat16 matches expectations. These lines may not be
# necessary as target.copy_ should do this rounding anyway.
source_i32.add_(32768) # Add offset so clipping bits performs round-to-nearest
source_i32.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32. Leaves only upper bits in source

# Move the 16-bit Kahan bits from VRAM to main memory
state['kahan_residuals'] = state['kahan_residuals'].to(dtype=torch.uint16).to("cpu")

# Copy the quantized floats into the target tensor
target.copy_(source)


@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
@@ -102,13 +158,19 @@ def adafactor_step_param(self, p, group):
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

p_data_fp32.add_(-update)
# Add on gradient update, but not if using kahan summation as the bottom
# bits must be restored first. (This update occurs in copy_kahan_() instead)
if not self.optimizer.use_kahan_summation:
p_data_fp32.add_(-update)

# if p.dtype in {torch.float16, torch.bfloat16}:
# p.copy_(p_data_fp32)

if p.dtype == torch.bfloat16:
copy_stochastic_(p, p_data_fp32)
if self.optimizer.use_kahan_summation:
copy_kahan_(p, p_data_fp32, state, update)
else:
copy_stochastic_(p, p_data_fp32)
elif p.dtype == torch.float16:
p.copy_(p_data_fp32)

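A small illustration of the bit split that copy_kahan_() relies on (a sketch for clarity, not code from the patch): viewed as int32, the upper 16 bits of a float32 are what bfloat16 keeps, and the lower 16 bits are the residual that gets offloaded to the CPU and re-added on the next step.

import torch

w = torch.tensor([0.1234567], dtype=torch.float32)
w_i32 = w.view(torch.int32)

residual = w_i32 & 0x0000FFFF                     # bits a plain bf16 cast would lose
truncated = (w_i32 & -65536).view(torch.float32)  # lower 16 bits zeroed

# truncated converts to bfloat16 exactly; residual is what copy_kahan_ restores next step
print(truncated.to(torch.bfloat16), residual)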
198 changes: 198 additions & 0 deletions library/adamw_fused.py
@@ -0,0 +1,198 @@
import math
import torch

from library.adafactor_fused import copy_stochastic_
from library.adafactor_fused import copy_kahan_


def to_float24_bytes(tensor_f32: torch.Tensor) -> torch.Tensor:
"""
Converts a float32 tensor to a 'float24' representation for storage.

This is done by taking the 3 most significant bytes of each float32 element.
On a little-endian system, these are the last 3 bytes.
# TODO - Verify this on big-endian systems (the byte slicing below assumes little-endian)

Args:
tensor_f32: The input tensor with dtype torch.float32.

Returns:
A 1D tensor of dtype torch.uint8 containing the packed 'float24' data.
"""
if tensor_f32.dtype != torch.float32:
raise TypeError("Input tensor must be of dtype torch.float32")

tensor_u8 = tensor_f32.view(torch.uint8)
tensor_u8_reshaped = tensor_u8.view(-1, 4)
tensor_f24_bytes = tensor_u8_reshaped[:, 1:]
return tensor_f24_bytes.flatten()


def from_float24_bytes(tensor_f24_u8: torch.Tensor, original_shape: torch.Size) -> torch.Tensor:
"""
Restores a 'float24' byte tensor back to a float32 tensor.

Args:
tensor_f24_u8: A 1D tensor of dtype torch.uint8 from to_float24_bytes.
original_shape: The shape of the original float32 tensor.

Returns:
The restored tensor with dtype torch.float32 and the original shape.
"""
if tensor_f24_u8.dtype != torch.uint8:
raise TypeError("Input byte tensor must be of dtype torch.uint8")
if tensor_f24_u8.numel() % 3 != 0:
raise ValueError("Input byte tensor size must be a multiple of 3")

tensor_u8_3bytes = tensor_f24_u8.view(-1, 3)
padding = torch.zeros(tensor_u8_3bytes.shape[0], 1, dtype=torch.uint8, device=tensor_u8_3bytes.device)
tensor_u8_4bytes = torch.cat([padding, tensor_u8_3bytes], dim=1)
tensor_f32_flat = tensor_u8_4bytes.flatten().view(torch.float32)
return tensor_f32_flat.view(original_shape)


@torch.no_grad()
def adamw_offload_step_param(self, p, group):

if p.grad is None:
return
grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError("This (N)AdamW implementation does not support sparse gradients.")

state = self.state[p]
grad_shape = grad.shape

p_data_fp32 = p
if p.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()

# Tensors with few elements may be more sensitive to quantization
# errors, so keep them in float32
high_quality = torch.numel(p) <= 4096

# State Initialization
if len(state) == 0:
state["step"] = 0

data_type = torch.float32 if high_quality else torch.uint16

state['exp_avg'] = torch.zeros_like(p, dtype=data_type)
state['exp_avg_sq'] = torch.zeros_like(p, dtype=data_type)

state["step"] += 1

# NAdam

beta1, beta2 = group['betas']
eps = group['eps'] # 1e-8
weight_decay = group.get('weight_decay', 0.0)

# Bias correction terms
bias_correction1 = 1.0 - math.pow(beta1, state['step'])
bias_correction2 = 1.0 - math.pow(beta2, state['step'])

eps_p2: float = math.pow(eps, 2)

# Bring state back (from CPU, if necessary)

# Recover the exp avg states from however they're stored
def unpack_tensor(state, key, target_device):

# Stored as f24 format?
if state[f'{key}'].dtype == torch.uint8:
return from_float24_bytes(state[f'{key}'].to(target_device), state[f'{key}_shape'])

# bf16 / u16 / f32
return state[f'{key}'].to(target_device).to(dtype=torch.float32)

state['exp_avg'] = unpack_tensor(state, 'exp_avg', p.device)
state['exp_avg_sq'] = unpack_tensor(state, 'exp_avg_sq', p.device)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

# Update biased first and second moment estimates
exp_avg .mul_(beta1).add_ (grad, alpha=1.0 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)

# Compute bias-corrected second moment for denominator
exp_avg_sq_corrected = exp_avg_sq / bias_correction2

# Compute update based on whether Nesterov momentum (NAdam) is being used
if self.use_nesterov:
# The next step's bias correction for momentum is needed
bias_correction1_next = 1.0 - math.pow(beta1, state['step'] + 1)

# NAdam update: combines current gradient with momentum look-ahead
momentum_cache = exp_avg / bias_correction1_next
update = (beta1 * momentum_cache + (1.0 - beta1) * grad / bias_correction1) / (exp_avg_sq_corrected.sqrt() + eps)
else:
# Standard Adam update: use bias-corrected first moment directly
exp_avg_corrected = exp_avg / bias_correction1
update = exp_avg_corrected / (exp_avg_sq_corrected.sqrt() + eps)

lr: float = group['lr']

# Implement 'cautious optimizer' from https://arxiv.org/pdf/2411.16085
# The scaling factor - dividing by m.mean() - does not seem to work with parameter
# groups, but it also appears to be an optional step, so it has been removed.
m = (update * grad >= 0).to(grad.dtype)
update = update * m #/ (m.mean() + eps)

# Apply learning rate
update.mul_(lr)

# Apply weight decay
if weight_decay != 0:
p_data_fp32.mul_(1 - lr * weight_decay)

# Reduce the size of large exp_avg and exp_avg_sq tensors to 24-bit,
# and then move them to cpu memory
if not high_quality:
state[f'exp_avg_shape'] = state[f'exp_avg'].shape
state[f'exp_avg'] = to_float24_bytes(state[f'exp_avg']).to('cpu')

state[f'exp_avg_sq_shape'] = state[f'exp_avg_sq'].shape
state[f'exp_avg_sq'] = to_float24_bytes(state[f'exp_avg_sq']).to('cpu')

# Add on gradient update, but not if using kahan summation as the bottom
# bits must be restored first. (This update occurs in copy_kahan_() instead)
if not self.optimizer.use_kahan_summation:
p_data_fp32.add_(-update)

if p.dtype == torch.bfloat16:
if self.optimizer.use_kahan_summation:
copy_kahan_(p, p_data_fp32, state, update)
else:
copy_stochastic_(p, p_data_fp32)
elif p.dtype == torch.float16:
p.copy_(p_data_fp32)


@torch.no_grad()
def adamw_offload_step(self, closure=None):
"""
Performs a single optimization step

Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group["params"]:
adamw_offload_step_param(self, p, group)

return loss


def patch_adamw_offload_fused(optimizer, use_nesterov):
optimizer.use_nesterov = use_nesterov

optimizer.step_param = adamw_offload_step_param.__get__(optimizer)
optimizer.step = adamw_offload_step.__get__(optimizer)
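A quick round-trip check of the float24 packing used to shrink the offloaded optimizer state (a usage sketch assuming a little-endian host and the library.adamw_fused module path introduced by this PR):

import torch
from library.adamw_fused import to_float24_bytes, from_float24_bytes

x = torch.randn(8, 4)
packed = to_float24_bytes(x)                    # uint8 tensor, 3 bytes per element
restored = from_float24_bytes(packed, x.shape)  # least significant mantissa byte comes back as zero

print(packed.numel())               # 96, i.e. 3 * x.numel()
print((x - restored).abs().max())   # only a small truncation error remains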