Big image quality improvement! Kahan summation for Adafactor-optimized Flux FFT #2159
base: sd3
| I've now accelerated the kahan summation function by only sending the 16 lower bits that are clipped from the f32 values to the cpu, instead of a floating point offset. With only half of the number of bytes needing to be transferred, it's much quicker. It might be higher quality too, as the bit identical f32 value that was quantized to bf16 is now restored. | 
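The bit-level idea in the comment above can be illustrated with a minimal, standard-library-only sketch. It assumes the bf16 value is simply the upper 16 bits of the f32 pattern (truncation) and that the lower 16 bits are what gets stored; the helper names (`split_f32`, `join_f32`) are hypothetical and are not taken from the PR, whose actual code may differ.

```python
import struct


def f32_to_bits(x: float) -> int:
    """32-bit IEEE-754 pattern of x (x is first rounded to float32)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(bits: int) -> float:
    """Inverse of f32_to_bits."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def split_f32(x: float) -> tuple[int, int]:
    """Split a float32 into (upper 16 bits = bf16 pattern, lower 16 bits = residual)."""
    bits = f32_to_bits(x)
    return bits >> 16, bits & 0xFFFF


def join_f32(upper: int, lower: int) -> float:
    """Recombine the bf16 pattern and the stored residual bits into the float32."""
    return bits_to_f32((upper << 16) | lower)


w = bits_to_f32(f32_to_bits(1.2345678))  # some float32 weight value
upper, lower = split_f32(w)
bf16_value = bits_to_f32(upper << 16)    # what the bf16 weight would hold
restored = join_f32(upper, lower)        # bit-identical to the original f32
```

Because the residual is 16 bits rather than a 32-bit float offset, half as many bytes need to be moved per weight, and the recombined value is exactly the original f32 rather than an approximation of it.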
| @araleza this sounds amazing i recently compared majority of optimizers and adafactor was king definitely gonna test your branch can you share your example toml / json? so i need to add --kahan_summation what other? | 
| I don't usually use toml / json, so I don't have one to share @FurkanGozukara. But yeah, switching this on is literally just adding --kahan_summation to the command line and that's that. I don't know how to add that to a toml / json file, but it's just a command line option like the rest, so however you add those, it should work for --kahan_summation too. I should really add a debug output message that says it's switched on, so you can confirm it's working. Give me a minute and I'll check that in now. | 
| 
 whatever format you use fine if you put here as a txt file i can compare what i am missing. i will definitely run a comparison training | 
| I'm running directly from the command line like this, @FurkanGozukara: 
 (I switched out my already-trained FFT model for the base flux1-dev model in that command line. If you're actually training from the base with kahan summation, you probably want a higher LR of something like 2e-6, rather than a polishing LR of 4e-7) Edit: I removed the  | 
| 
 The StableAdamW optimizer also supports Kahan summation and can be used with sd-scripts without any issues. However, I haven’t noticed any significant performance differences compared to other optimizers. It might still be worth a look. https://pytorch-optimizers.readthedocs.io/en/latest/optimizer/#pytorch_optimizer.StableAdamW optimizer_type = "pytorch_optimizer.StableAdamW" | 
| Hi @blackmagic24, thanks for your reply. I did notice that pytorch_optimizer's AdamW implementation has a feature that it refers to as kahan_sum, but I think it might be an unrelated feature that shares the name of the technique, as Kahan summation is an idea that's applicable to more than one area. When I looked at the source code for pytorch_optimizer's AdamW implementation, the only change that activating its kahan_sum feature makes is this:   That's from: I'm not exactly certain what that section of code does, but it doesn't seem to be keeping the lower bits of the weights between steps; rather, it's doing something with the exponential averages. And it doesn't seem to be sending the bits to the CPU between training steps, so it doesn't seem to have the same purpose as my Kahan summation code. But if you have any more information about this feature for AdamW, I'd be interested to know about it. | 
| Actually, looking more at that AdamW code, it does seem to be doing a floating point sub_() of the bfloat16 value against the 'real' value, and an add_() to the kahan_comp variable. The parts with the exponential moving averages are the same between the kahan version and the non-kahan version, so it's not related to that at all. I think this code is more like the first version of my pull request, where I kept an f32 offset on the CPU between training steps, rather than directly keeping the lower 16 bits that were lost. It isn't being sent to the CPU between steps though, so it wouldn't work in low-memory conditions like the Adafactor version in my pull request would. An interesting thing is that pytorch_optimizer's version of AdamW doesn't do stochastic rounding. While it's an interesting question whether Kahan summation beats stochastic rounding for Flux training or not, both of these should strongly outperform an optimizer that isn't doing either of them. So you should see a large improvement for that AdamW optimizer if you switch on kahan_sum. And the fact that you aren't is curious. I'm sure this is unlikely, but you're not actually training a LoRA instead of doing full fine-tuning, are you? Because LoRAs are f32 all the time, so kahan_sum would have no effect there. I might give that AdamW optimizer a try at some point soon, although I'll need to check whether it can even be run on my 32 GB graphics card when doing full fine-tuning. | 
| i have recently compared AdamW vs Adafactor and Adafactor yielded best realistic results | 
| 
 That might be because sd-scripts' current Adafactor does have stochastic rounding, and the AdamW implementation doesn't. AdamW should typically outperform Adafactor if both are using stochastic rounding - or both aren't - but AdamW uses much more memory, which is why projects often switch to Adafactor in low memory conditions. | 
| 
 Yep, I can't run fine tuning with pytorch_optimizer.StableAdamW, even on my 32 GB graphics card. If I try to use that optimizer, I get the error message: If I switch off  | 
| 
 this makes sense ty for info | 
| Thanks, sorry for the late response. This is very interesting. I'll take a closer look at the code. | 
I understand that having kahan_residuals on the CPU reduces the amount of GPU memory required. This is a kind of CPU offloading.
I think the code would be more organized if kahan_residuals were stored in state, but is that possible?
        
          
library/adafactor_fused.py (outdated)
          
        
      |  | ||
| kahan_residuals = [] | ||
| tensor_index = 0 | ||
| prev_step = 0 | 
Since step starts from 0, it would be better to set this to -1. The tensor_index of the first step starts from 1, which will cause a mismatch with the next step.
| 
 Yes, I've now made this change - and the code is simpler now too. I think that the step and tensor_index values that you mentioned were already correct, but perhaps the code was not clear here. However, now that I'm using the optimizer state to store the Kahan residuals, both step and tensor_index have been deleted from the copy_kahan_() function, so they are definitely not a problem now. 😊 | 
Pull Request Overview
This PR introduces Kahan summation as an alternative to stochastic rounding for bfloat16 training with the Adafactor optimizer in Flux model training. This technique preserves the precision lost during bf16 quantization by storing the residual bits and reapplying them in subsequent training steps.
Key changes:
- Implements Kahan summation algorithm that offloads quantization residuals to CPU memory
- Adds --kahan_summation command-line argument with compatibility validation
- Modifies the Adafactor optimizer step to conditionally use Kahan summation instead of stochastic rounding
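As a rough illustration of what a boolean flag with compatibility validation can look like, here is a minimal argparse sketch. The flag names `--kahan_summation` and `--fused_backward_pass` appear in this PR, but the specific rule checked below (Kahan summation requiring the fused backward pass) is an assumption made for the example, and `build_parser` and `validate` are hypothetical helpers, not functions from flux_train.py.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser with only the two flags relevant to this example."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--fused_backward_pass", action="store_true")
    parser.add_argument(
        "--kahan_summation",
        action="store_true",
        help="use Kahan summation instead of stochastic rounding",
    )
    return parser


def validate(args: argparse.Namespace) -> None:
    """Assumed compatibility rule; the real check in the PR may be different."""
    if args.kahan_summation and not args.fused_backward_pass:
        raise ValueError("--kahan_summation requires --fused_backward_pass")


args = build_parser().parse_args(["--fused_backward_pass", "--kahan_summation"])
validate(args)  # passes: both flags are set
```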
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description | 
|---|---|
| library/adafactor_fused.py | Implements the copy_kahan_ function and integrates it into the Adafactor parameter update process | 
| flux_train.py | Adds Kahan summation CLI argument, validation logic, and passes the setting to the optimizer | 
| """ | ||
| Copies source into target using Kahan summation. | ||
| The lower bits of the float32 weight that are lost on conversion to bfloat16 | ||
| are sent to the CPU until the next step, where they are re-added onto the weights | ||
| before adding the gradient update. This produces near float32-like weight behavior, | ||
| although the copies back and forth to main memory result in slower training steps. | ||
| Args: | ||
| target: the target tensor with dtype=bfloat16 | ||
| source: the target tensor with dtype=float32 | 
    
      
    
Copilot AI, Aug 24, 2025
    
  
The docstring incorrectly describes the source parameter. Both target and source are described as 'the target tensor' - the source parameter should be described as 'the source tensor with dtype=float32'.
| """ | |
| Copies source into target using Kahan summation. | |
| The lower bits of the float32 weight that are lost on conversion to bfloat16 | |
| are sent to the CPU until the next step, where they are re-added onto the weights | |
| before adding the gradient update. This produces near float32-like weight behavior, | |
| although the copies back and forth to main memory result in slower training steps. | |
| Args: | |
| target: the target tensor with dtype=bfloat16 | |
| source: the target tensor with dtype=float32 | |
| source: the source tensor with dtype=float32 | 
Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.
| I've also included the Kahan summation code in my separate (and more impactful / higher image quality) AdamW for 5090 branch here: | 
If you've been doing full fine tuning with Flux, you've likely seen the image quality improvements that the stochastic rounding feature (available when using
--fused_backward_pass with the supplied Adafactor optimizer) can offer. That feature improves the casting of the f32 'weight + gradient update' values back to bf16 during each training step by randomly moving the value either up or down to the next rounded bfloat16 value, with a probability that depends on where the f32 value lies between the two adjacent quantized bf16 values.
But although stochastic rounding offers a large improvement to bf16 training, its randomness doesn't always produce ideal results. As the "Revisiting BFloat16 Training" paper - https://arxiv.org/pdf/2010.06192 - notes:
This same paper then goes on to note a different technique - Kahan summation - that does match the performance of 32-bit training:
Kahan summation is a technique where the f32 values are rounded to the nearest bf16 value, and then the 16-bit offset between the original f32 and the quantized bf16 value is recorded. This offset is then sent back to the CPU so it doesn't use up VRAM past that point. Then, on the next training step, instead of just taking the bf16 weight value plus the gradient update for that step, the offset that was lost from the previous step is added back on as well. And then that process is repeated.
It means that a bf16 value that is updating by (e.g.) 20% of the distance between that bf16 value and the next adjacent bf16 value will go up once in every 5 training steps, rather than stochastic rounding's random approximation of once in every 5 steps. Stochastic rounded values bounce up and down fairly unpredictably, but the Kahan summed values are more stable.
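The 20% example above can be checked with a small standard-library simulation. This is only a sketch under stated assumptions: bf16 is emulated by manipulating float32 bit patterns, the residual is kept as a plain float rather than as 16 raw bits on the CPU, and every name here (`round_nearest_bf16`, `round_stochastic_bf16`, `SPACING`, and so on) is hypothetical rather than taken from the PR.

```python
import random
import struct


def _to_bits(x: float) -> int:
    """float32 bit pattern of x (x is first rounded to float32)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def _from_bits(bits: int) -> float:
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def round_nearest_bf16(x: float) -> float:
    """Round a positive float32 value to the nearest bf16 value (ties to even)."""
    bits = _to_bits(x)
    bits += 0x7FFF + ((bits >> 16) & 1)
    return _from_bits(bits & 0xFFFF0000)


def round_stochastic_bf16(x: float, rng: random.Random) -> float:
    """Round a positive value up with probability equal to its fractional position."""
    bits = _to_bits(x) + rng.randrange(0x10000)
    return _from_bits(bits & 0xFFFF0000)


SPACING = 2.0 ** -7       # gap between adjacent bf16 values in [1, 2)
UPDATE = 0.2 * SPACING    # each step moves 20% of the way to the next value
STEPS = 50

rng = random.Random(0)
kahan_w, residual = 1.0, 0.0
sr_w = 1.0
kahan_up_steps, sr_up_steps = [], []

for step in range(1, STEPS + 1):
    # Kahan: re-add what was lost last step, round, remember the new loss.
    full = kahan_w + residual + UPDATE
    new_w = round_nearest_bf16(full)
    residual = full - new_w
    if new_w > kahan_w:
        kahan_up_steps.append(step)
    kahan_w = new_w

    # Stochastic rounding: no memory between steps, only a random choice.
    new_sr = round_stochastic_bf16(sr_w + UPDATE, rng)
    if new_sr > sr_w:
        sr_up_steps.append(step)
    sr_w = new_sr

print("Kahan moved up at steps:", kahan_up_steps)
print("Stochastic rounding moved up at steps:", sr_up_steps)
```

In this setup the Kahan weight moves up at evenly spaced steps (exactly one in every five), while the stochastically rounded weight moves up at irregular steps whose long-run average rate is also one in five.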
The technique comes at a price: training steps take around 40% longer in my tests so far. (I'm training at batch size 5 on an RTX 5090 card.) The slowdown comes from copying the values from the GPU to main memory and back again on each step. It's possible to use stochastic rounding for most of the training, and then switch to Kahan summation with
--kahan_summation for the final polish phase, but doing full runs start-to-finish with Kahan summation works fine too - and I'd recommend you do that, for best quality.
Now that I've implemented Kahan summation in sd-scripts, the quality improvements that it achieves are impressive. Flux.dev training does seem to be one of the cases where Kahan summation significantly exceeds stochastic rounding in terms of image quality. Very low LRs such as 5e-7 (which suffer quantization randomness with stochastic rounding) work great with Kahan summation.
If you (e.g. @kohya-ss) want to see what this feature can do, then why not grab this branch, switch on
--kahan_summation, set the LR to 5e-7, and then try running a final polish pass on an already-trained FFT model that you have? The quality improvements are almost immediate, as it allows the weights to 'settle' into a more stable pattern.