-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Add wavelet loss for networks #2037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: sd3
Are you sure you want to change the base?
Conversation
train_network.py
Outdated
|
||
wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float()) | ||
# Weight the losses as needed | ||
loss = loss + args.wavelet_loss_alpha * wav_loss |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe do loss = (1.0 - args.wavelet_loss_alpha) * loss + args.wavelet_loss_alpha * wav_loss
so its a proper interpolation?
You should consider including a weighting/masking scheme for the different levels, I am getting amazing results from it. I've been playing with a prototype of this myself, see #294 (reply in thread) For example, masking the lowpass elements makes it easier to learn subjects and objects without transferring the overall image aesthetic bias. Here is my hacky training code as an example:
Maybe we could have a parameter to pass an array of loss weights, one for each level of detail? For example a 1024px image can be decomposed up to 8 levels. My testing gives me interesting results when masking or weighting certain levels differently. |
@recris I went through and added inputs for band weighting, allowing a good amount of customization. It was in there previously for SWT but now it is applied to both. The default weights are low so as to allow one to customize how much impact each band has. The alpha is now more like a multiplier instead but should control the weighting via the bands ideally. I added wavelet loss to the logging, which can help when adjusting the band weights to see the wavelet impacts.
I am still working through some trainings of this to get some more examples and comparisons. |
Traceback (most recent call last): getting this error when trying it |
@EClipXAi Apologies, some newer functionality leaked in there. I have fixed this. Give it another shot and let me know if you have any issues. |
- add full conditional_loss functionality to wavelet loss - Transforms are separate and abstracted - Loss now doesn't include LL except the lowest level - ll_level_threshold allows you to control the level the ll is used in the loss - band weights can now be passed in - rectified flow calculations can be bypassed for experimentation - Fixed alpha to 1.0 with new weighted bands producing lower loss
cea1930
to
3b949b9
Compare
Added QuaterionWaveletTransform derived from https://arxiv.org/abs/2505.00334.
It should work similarly to DWT but with more components (4 vs 1 with DWT) using hilbert filters. It's probably experimental so true recommended values I'm not quite sure yet.
Additionally reworked SWT to work better, and SWT should be more performant now using 1d convolutions.
Added tests for the Wavelet tranforms and Wavelet loss to make sure it's working as expected. |
Sharing a good resource for comparing wavelets: https://www.mathworks.com/help/wavelet/gs/introduction-to-the-wavelet-families.html |
Some papers about wavelets:
https://arxiv.org/abs/2306.00306
https://arxiv.org/abs/2402.19215
https://arxiv.org/abs/2211.16152
https://arxiv.org/abs/2407.12538
https://arxiv.org/abs/2404.11273
Video about Wavelets
https://www.youtube.com/watch?v=jnxqHcObNK4
Example from Training Generative Image Super-Resolution Models by Wavelet-Domain Losses Enables Better Control of Artifacts
An example to showcase what this is suppose to help with.
Install
pip install PyWavelets
Usage
Activate Wavelet Loss:
Configure
Configure Wavelet loss:
Recommended starting point:
Huber or l1 are recommended to get the right loss for the signals. Level 2 gets more detail so will be best to capture those fine details. Need wavelet_loss=true to enable wavelet_loss.
CLI
Wavelet families:
We use a custom loss implementation and I just learned about Wavelets like yesterday so we may consider how the approach works and maybe utilize other libraries as well.
I am using flow matching/rectified flow attempts to predict denoised latents to create a better result for wavelets, so might not work as well for some models like SD1.5/SDXL but I am not sure.
Related #2016