23 changes: 23 additions & 0 deletions flux_train.py
@@ -381,10 +381,27 @@ def train(args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function

if args.optimizer_type == "adafactor" and args.full_bf16:
logger.warning("Use of --blockwise_fused_optimizer with Adafactor optimizer prevents stochastic/kahan weight updates.")
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)

# Pass any Kahan summation arg to the optimizer
if args.kahan_summation:
# Self check parameter compatibility
if args.optimizer_type != "adafactor":
logger.warning("Kahan summation has been requested, but currently this is only supported by the supplied Adafactor optimizer.")
elif not args.full_bf16:
logger.warning("Kahan summation requires --full_bf16")
elif args.blockwise_fused_optimizers:
logger.warning("Kahan summation has been requested, but it is incompatible with --blockwise_fused_optimizer. "\
"Perhaps try --fused_backward_pass instead.")
else:
logger.info("Using Kahan summation")
optimizer.use_kahan_summation = args.kahan_summation

# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
@@ -815,6 +832,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
)
parser.add_argument(
"--kahan_summation",
action="store_true",
help="Offloads to CPU the float parts lost during bf16 quantization, and re-adds them to the next step / "\
"bf16 量子化中に失われた浮動小数点部分を CPU にオフロードし、次のステップに再度追加します",
)
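
As an illustration of the effect this option targets, here is a minimal standalone sketch (separate from the patch): with bf16 weights, a sufficiently small per-step update can round away entirely, while carrying the lost residual forward recovers float32-like behavior. The sketch keeps the residual in plain float arithmetic for clarity; the patch itself implements it with integer bit masks in copy_kahan_ and offloads the residual to the CPU.

import torch

update = torch.tensor(1e-3, dtype=torch.float32)      # a small per-step weight change
w_plain = torch.tensor(100.0, dtype=torch.bfloat16)   # bf16 weights, updates applied directly
w_kahan = torch.tensor(100.0, dtype=torch.bfloat16)   # bf16 weights with a carried residual
residual = torch.tensor(0.0, dtype=torch.float32)     # the "lost float parts" (the patch parks these on the CPU)
w_ref = torch.tensor(100.0, dtype=torch.float32)      # float32 reference

for _ in range(1000):
    w_plain = (w_plain.float() - update).to(torch.bfloat16)  # each tiny update rounds away entirely
    restored = w_kahan.float() + residual - update            # re-add last step's residual, then apply the update
    w_kahan = restored.to(torch.bfloat16)
    residual = restored - w_kahan.float()                     # keep what bf16 could not represent
    w_ref = w_ref - update

print(w_plain.item(), w_kahan.item(), w_ref.item())   # ~100.0 (stuck), ~99.0, ~99.0
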
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
76 changes: 74 additions & 2 deletions library/adafactor_fused.py
@@ -28,6 +28,72 @@ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
del result


# Kahan summation for bfloat16
# The implementation was provided by araleza.
# Based on paper "Revisiting BFloat16 Training": https://arxiv.org/pdf/2010.06192

kahan_residuals = []
tensor_index = 0
prev_step = 0
Owner comment:

Since step starts from 0, it would be better to set this to -1. The tensor_index of the first step starts from 1, which will cause a mismatch with the next step.
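
To make the off-by-one concrete, here is a small standalone simulation of the counter logic used in copy_kahan_ below, under the assumption that the first call arrives with step == 0:

def simulate(num_steps, tensors_per_step, prev_step_init):
    tensor_index, prev_step = 0, prev_step_init
    per_step = []
    for step in range(num_steps):
        indices = []
        for _ in range(tensors_per_step):
            tensor_index += 1              # same ordering logic as copy_kahan_
            if prev_step != step:          # starting a new step?
                prev_step = step
                tensor_index = 0
            indices.append(tensor_index)
        per_step.append(indices)
    return per_step

print(simulate(2, 3, prev_step_init=0))   # [[1, 2, 3], [0, 1, 2]]: the first step's indices are shifted by one
print(simulate(2, 3, prev_step_init=-1))  # [[0, 1, 2], [0, 1, 2]]: consistent indexing across steps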


def copy_kahan_(target: torch.Tensor, source: torch.Tensor, step, update):
"""
Copies source into target using Kahan summation.

The lower bits of the float32 weight that are lost on conversion to bfloat16
are sent to the CPU until the next step, where they are re-added onto the weights
before adding the gradient update. This produces near float32-like weight behavior,
although the copies back and forth to main memory result in slower training steps.

Args:
target: the target tensor with dtype=bfloat16
source: the target tensor with dtype=float32
Comment on lines +36 to +46
Copy link

Copilot AI Aug 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring incorrectly describes the source parameter. Both target and source are described as 'the target tensor' - the source parameter should be described as 'the source tensor with dtype=float32'.

Suggested change
"""
Copies source into target using Kahan summation.
The lower bits of the float32 weight that are lost on conversion to bfloat16
are sent to the CPU until the next step, where they are re-added onto the weights
before adding the gradient update. This produces near float32-like weight behavior,
although the copies back and forth to main memory result in slower training steps.
Args:
target: the target tensor with dtype=bfloat16
source: the target tensor with dtype=float32
source: the source tensor with dtype=float32

Copilot uses AI. Check for mistakes.
step: the global training step count
update: the change in weights due to the gradient
"""
global kahan_residuals, tensor_index, prev_step

# Calculate the group index of the current residual Tensor. Tensors
# pass through this copy function in the same order at each step.
tensor_index += 1
if prev_step != step: # Starting new step?
prev_step = step
tensor_index = 0

# Initialize residuals to 0 for first step
if len(kahan_residuals) <= tensor_index:
kahan_residuals += [torch.zeros_like(source, dtype=torch.int16)]

# Need this in 32 bit as PyTorch doesn't support mixed 32-bit and 16-bit math operations
kahan_residuals[tensor_index] = kahan_residuals[tensor_index].detach().to(source.device).to(dtype=torch.int32)

# Bring the previous step's lower bits of the weights back from the
# cpu device, and add them back to the weights of the current step.
source_i32 = source.view(dtype=torch.int32) # Can't do math on uint32
source_i32.add_(kahan_residuals[tensor_index])

# If the Kahan residual was >=0.5 then the cast to bf16 rounded up
rounded_up = kahan_residuals[tensor_index] >= 32768
source_i32[rounded_up] -= 65536

# Must add the gradient update after the bottom bits are restored in case
# the exponent is changed by the update, or the -65536 on the line above
# would drop the uint32 value below zero, which is invalid.
source.add_(-update)

# Get the lower bits into the residual
torch.bitwise_and(source_i32, 0x0000FFFF, out=kahan_residuals[tensor_index])

source_i32.add_(32768) # Add offset so clipping bits performs round-to-nearest
source_i32.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32 # Leave only upper bits in source

# Move the 16-bit Kahan bits from VRAM to main memory
kahan_residuals[tensor_index] = kahan_residuals[tensor_index].detach().to(dtype=torch.uint16).to("cpu")

# Copy the quantized floats into the target tensor
target.copy_(source)
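
As an aside, here is a minimal standalone sketch of the bit-split this function relies on: viewed as int32, a float32 weight carries its bfloat16 payload in the upper 16 bits, and the lower 16 bits are the residual that is parked on the CPU and re-added at the next step.

import torch

w = torch.tensor([0.1234567], dtype=torch.float32)
w_i32 = w.view(torch.int32)                      # reinterpret the same bits as int32

residual = torch.bitwise_and(w_i32, 0x0000FFFF)  # lower 16 bits: dropped by a bf16 truncation
upper = torch.bitwise_and(w_i32, -65536)         # upper 16 bits only (-65536 == 0xFFFF0000)

print(upper.view(torch.float32))                 # bf16-precision value (truncated; copy_kahan_ adds 32768 first to round-to-nearest)
print((upper + residual).view(torch.float32))    # exactly recovers w: the residual restores the lost bits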


@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
@@ -102,13 +168,19 @@ def adafactor_step_param(self, p, group):
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

p_data_fp32.add_(-update)
# Add on gradient update, but not if using kahan summation as the bottom
# bits must be restored first. (This update occurs in copy_kahan_() instead)
if not self.optimizer.use_kahan_summation:
p_data_fp32.add_(-update)

# if p.dtype in {torch.float16, torch.bfloat16}:
# p.copy_(p_data_fp32)

if p.dtype == torch.bfloat16:
copy_stochastic_(p, p_data_fp32)
if self.optimizer.use_kahan_summation:
copy_kahan_(p, p_data_fp32, state["step"], update)
else:
copy_stochastic_(p, p_data_fp32)
elif p.dtype == torch.float16:
p.copy_(p_data_fp32)
