Add wavelet loss for networks #2037


Draft · wants to merge 22 commits into base: sd3

Conversation

@rockerBOO (Contributor) commented Apr 8, 2025

Some papers about wavelets:
https://arxiv.org/abs/2306.00306
https://arxiv.org/abs/2402.19215
https://arxiv.org/abs/2211.16152
https://arxiv.org/abs/2407.12538
https://arxiv.org/abs/2404.11273

Video about Wavelets
https://www.youtube.com/watch?v=jnxqHcObNK4

Example from Training Generative Image Super-Resolution Models by Wavelet-Domain Losses Enables Better Control of Artifacts

An example to showcase what this is supposed to help with.

[Screenshots: WGSR artifact comparison examples from arXiv:2402.19215]

[Screenshot: wavelet examples from the "Wavelets: a mathematical microscope" video]

Wavelet examples from: https://www.youtube.com/watch?v=jnxqHcObNK4

Install

```
pip install PyWavelets
```

Usage

Activate Wavelet Loss:

```toml
wavelet_loss = true
```

Configure

Configure Wavelet loss:

```toml
wavelet_loss = true
wavelet_loss_type = "huber" # l2, l1, smooth_l2, huber; l1 and huber are recommended
wavelet_loss_transform = "swt" # dwt, swt; swt keeps spatial detail but dwt can be a little more efficient
wavelet_loss_wavelet = "sym7" # over 100 wavelets, but try db4 or sym7 (AI Toolkit uses haar)
wavelet_loss_level = 2 # how many decomposition levels to process; DWT maybe 1-3, SWT 1-2 (SWT level 2 focuses more on detail)
wavelet_loss_alpha = 1.0 # how much impact the wavelet loss has on the latent loss value
wavelet_loss_band_level_weights = { "ll1" = 0.1, "lh1" = 0.01, "hl1" = 0.01, "hh1" = 0.05, "ll2" = 0.1, "lh2" = 0.01, "hl2" = 0.01, "hh2" = 0.05 } # per-level band weights, starting at level 1
wavelet_loss_band_weights = { "ll" = 0.1, "lh" = 0.01, "hl" = 0.01, "hh" = 0.05 } # default weights per band; currently all four values must be set or errors may occur
wavelet_loss_ll_level_threshold = -1 # level to process the LL band at; low frequencies resemble the original latent, so only the last levels are needed for that detail
wavelet_loss_rectified_flow = true # experimental; toggles rectified flow to get clean latents (not recommended to change)
```
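For intuition, here is a minimal single-level sketch of what a band-weighted wavelet loss computes. The function name and signature are hypothetical; the PR's actual implementation is multi-level, torch-based, and supports the dwt/swt/qwt transforms configured above.

```python
import numpy as np
import pywt

def wavelet_band_loss(pred, target, wavelet="sym7", band_weights=None):
    """Hypothetical sketch: compare pred/target per wavelet band, weighted."""
    if band_weights is None:
        band_weights = {"ll": 0.1, "lh": 0.01, "hl": 0.01, "hh": 0.05}
    # pywt.dwt2 returns (LL, (LH, HL, HH)) for one decomposition level
    pred_ll, (pred_lh, pred_hl, pred_hh) = pywt.dwt2(pred, wavelet)
    tgt_ll, (tgt_lh, tgt_hl, tgt_hh) = pywt.dwt2(target, wavelet)
    bands = {
        "ll": (pred_ll, tgt_ll),  # low-frequency approximation
        "lh": (pred_lh, tgt_lh),  # horizontal detail
        "hl": (pred_hl, tgt_hl),  # vertical detail
        "hh": (pred_hh, tgt_hh),  # diagonal detail
    }
    # l1 per band, weighted, then summed
    return sum(band_weights[k] * np.abs(p - t).mean() for k, (p, t) in bands.items())

x = np.random.rand(64, 64)
print(wavelet_band_loss(x, x))  # identical inputs -> 0.0
```

Downweighting `ll` relative to the detail bands steers the loss toward fine structure rather than overall tone.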

Recommended starting point:

```toml
wavelet_loss = true
wavelet_loss_type = "huber" # l2, l1, smooth_l2, huber; l1 and huber are recommended
wavelet_loss_level = 2 # DWT maybe 1-3, SWT 1-2; level 2 focuses more on detail
```

Huber or l1 is recommended to get the right loss for these signals. Level 2 captures more detail, so it is best for fine textures. `wavelet_loss = true` is required to enable the feature.

CLI

```
--wavelet_loss --wavelet_loss_type huber --wavelet_loss_level 2
```

Wavelet families:

    haar family: haar
    db family: db1, db2, db3, db4, db5, db6, db7, db8, db9, db10, db11, db12, db13, db14, db15, db16, db17, db18, db19, db20, db21, db22, db23, db24, db25, db26, db27, db28, db29, db30, db31, db32, db33, db34, db35, db36, db37, db38
    sym family: sym2, sym3, sym4, sym5, sym6, sym7, sym8, sym9, sym10, sym11, sym12, sym13, sym14, sym15, sym16, sym17, sym18, sym19, sym20
    coif family: coif1, coif2, coif3, coif4, coif5, coif6, coif7, coif8, coif9, coif10, coif11, coif12, coif13, coif14, coif15, coif16, coif17
    bior family: bior1.1, bior1.3, bior1.5, bior2.2, bior2.4, bior2.6, bior2.8, bior3.1, bior3.3, bior3.5, bior3.7, bior3.9, bior4.4, bior5.5, bior6.8
    rbio family: rbio1.1, rbio1.3, rbio1.5, rbio2.2, rbio2.4, rbio2.6, rbio2.8, rbio3.1, rbio3.3, rbio3.5, rbio3.7, rbio3.9, rbio4.4, rbio5.5, rbio6.8
    dmey family: dmey
    gaus family: gaus1, gaus2, gaus3, gaus4, gaus5, gaus6, gaus7, gaus8
    mexh family: mexh
    morl family: morl
    cgau family: cgau1, cgau2, cgau3, cgau4, cgau5, cgau6, cgau7, cgau8
    shan family: shan
    fbsp family: fbsp
    cmor family: cmor
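Worth noting: in PyWavelets, the last several families above (gaus, mexh, morl, cgau, shan, fbsp, cmor) are continuous wavelets, which generally cannot be used with discrete transforms like DWT/SWT. You can check which kind a name belongs to:

```python
import pywt

# Discrete wavelets (usable with dwt/swt): haar, db*, sym*, coif*, bior*, rbio*, dmey
discrete = pywt.wavelist(kind="discrete")
# Continuous wavelets (CWT only): gaus*, mexh, morl, cgau*, shan, fbsp, cmor
continuous = pywt.wavelist(kind="continuous")

print("sym7" in discrete)    # True
print("morl" in continuous)  # True
```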

We use a custom loss implementation, and since I only just learned about wavelets, we may want to reconsider the approach and possibly utilize other libraries as well.

I am using flow matching/rectified flow to predict denoised latents, which creates a better target for the wavelets; this might not work as well for some models like SD1.5/SDXL, but I am not sure.
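For context, under the common rectified-flow convention `x_t = (1 - t) * x0 + t * noise` with velocity target `v = noise - x0`, a clean-latent estimate can be recovered as `x0 = x_t - t * v`. A sketch (conventions and timestep scaling vary between trainers, so this is not necessarily the exact form used in this PR):

```python
import torch

def predict_clean_latent(noisy_latent, velocity_pred, t):
    """Recover an x0 estimate from a rectified-flow velocity prediction.

    Assumes x_t = (1 - t) * x0 + t * noise and v = noise - x0,
    so x0 = x_t - t * v. Sign/timestep conventions differ by codebase.
    """
    return noisy_latent - t * velocity_pred

# sanity check: with the true velocity, x0 is recovered exactly
x0 = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(x0)
t = 0.3
x_t = (1 - t) * x0 + t * noise
v = noise - x0
recovered = predict_clean_latent(x_t, v, t)
```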

Related #2016

@rockerBOO rockerBOO changed the title Add wavelet loss Add wavelet loss for networks Apr 8, 2025
train_network.py Outdated

```python
wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
# Weight the losses as needed
loss = loss + args.wavelet_loss_alpha * wav_loss
```

Maybe do `loss = (1.0 - args.wavelet_loss_alpha) * loss + args.wavelet_loss_alpha * wav_loss` so it's a proper interpolation?
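On toy numbers, the difference between the current additive weighting and the suggested interpolation (illustrative values only):

```python
loss, wav_loss, alpha = 1.0, 2.0, 0.5

# current: additive weighting; total magnitude grows with alpha
additive = loss + alpha * wav_loss                       # 1.0 + 0.5 * 2.0 = 2.0

# suggested: convex interpolation; result stays between the two losses
interpolated = (1.0 - alpha) * loss + alpha * wav_loss   # 0.5 * 1.0 + 0.5 * 2.0 = 1.5

print(additive, interpolated)
```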

@recris commented Apr 11, 2025

You should consider including a weighting/masking scheme for the different levels, I am getting amazing results from it.

I've been playing with a prototype of this myself, see #294 (reply in thread)

For example, masking the lowpass elements makes it easier to learn subjects and objects without transferring the overall image aesthetic bias.

Here is my hacky training code as an example:

```python
from pytorch_wavelets import DWTForward

use_wavelet_loss = True
wavelet_loss_ratio = 0.98

huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)

if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
    # not applying the mask by multiplying the loss; instead this scales the gradients
    # directly, so the loss functions below can be used without worrying about
    # mask compatibility
    loss, mask = apply_masked_loss(loss, batch)
    mask = torch.nn.functional.interpolate(mask, size=noise_pred.shape[2:], mode="area")
    noise_pred.register_hook(lambda grad: grad * mask.to(grad.dtype))

loss = loss.mean([1, 2, 3])

if use_wavelet_loss:
    # custom weighting scheme, this should be a parameter
    num_levels = 3
    level_weights = [1.0] * num_levels
    lowpass_weight = 0.0
    level_weights[1] = 0.0
    level_weights[2] = 0.0

    # this requires pytorch-wavelets to be installed
    dwt = DWTForward(J=num_levels, mode="zero", wave="haar").to(device=accelerator.device, dtype=vae_dtype)

    model_pred_xl, model_pred_xh = dwt(noise_pred)
    target_xl, target_xh = dwt(target)

    # compute lowpass loss
    wt_loss = train_util.conditional_loss(model_pred_xl.float(), target_xl.float(), args.loss_type, "none", huber_c)
    wt_loss = wt_loss * lowpass_weight

    # compute loss for each detail band, coarsest level first
    for lvl, (p, t) in enumerate(zip(reversed(model_pred_xh), reversed(target_xh))):
        l = train_util.conditional_loss(p.float(), t.float(), args.loss_type, "none", huber_c)
        l = l * level_weights[lvl]

        l_xlh, l_xhl, l_xhh = torch.unbind(l, dim=2)

        # tile the band losses into the classic wavelet-layout quadrants
        wt_loss = torch.cat((
            torch.cat((wt_loss, l_xlh), dim=3),
            torch.cat((l_xhl, l_xhh), dim=3)),
            dim=2)

    wt_loss = wt_loss.mean([1, 2, 3])

    loss = wavelet_loss_ratio * wt_loss + (1 - wavelet_loss_ratio) * loss
```

Maybe we could have a parameter to pass an array of loss weights, one for each level of detail? For example, a 1024px image can be decomposed into up to 8 levels. My testing gives interesting results when masking or weighting certain levels differently.
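A sketch of that suggestion (the function name and parameter shape are hypothetical, and it uses PyWavelets on numpy arrays rather than the pytorch-wavelets code above):

```python
import numpy as np
import pywt

def leveled_wavelet_loss(pred, target, level_weights, lowpass_weight=0.0, wavelet="haar"):
    """l1 loss per decomposition level; level_weights[0] weights the coarsest level."""
    levels = len(level_weights)
    # wavedec2 returns [cA_n, (cH_n, cV_n, cD_n), ..., (cH_1, cV_1, cD_1)],
    # i.e. the approximation first, then detail levels from coarsest to finest
    pred_coeffs = pywt.wavedec2(pred, wavelet, level=levels)
    tgt_coeffs = pywt.wavedec2(target, wavelet, level=levels)
    loss = lowpass_weight * np.abs(pred_coeffs[0] - tgt_coeffs[0]).mean()
    for w, p_bands, t_bands in zip(level_weights, pred_coeffs[1:], tgt_coeffs[1:]):
        for p, t in zip(p_bands, t_bands):
            loss += w * np.abs(p - t).mean()
    return loss

x, y = np.random.rand(64, 64), np.random.rand(64, 64)
# weight only the finest detail level: isolates high-frequency differences
fine_only = leveled_wavelet_loss(x, y, [0.0, 0.0, 1.0])
```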

@rockerBOO (Contributor, Author)

@recris I went through and added inputs for band weighting, allowing a good amount of customization. It was previously in there only for SWT, but now it is applied to both transforms. The default weights are low so that you can customize how much impact each band has. Alpha now acts more like a multiplier; ideally the weighting is controlled via the band weights.

I added wavelet loss to the logging, which can help when adjusting the band weights to see the wavelet impacts.
[Screenshot: Weights & Biases run showing the wavelet loss metric]

[Example image pairs from training runs]

I am still working through some trainings of this to get some more examples and comparisons.

@EClipXAi

```
Traceback (most recent call last):
  File "/workspace/kohya_ss/sd-scripts/flux_train_network.py", line 559, in <module>
    trainer.train(args)
  File "/workspace/kohya_ss/sd-scripts/train_network.py", line 1473, in train
    loss, wav_loss = self.process_batch(
  File "/workspace/kohya_ss/sd-scripts/train_network.py", line 496, in process_batch
    wav_loss, pred_combined_hf, target_combined_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())
  File "/workspace/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workspace/kohya_ss/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/kohya_ss/sd-scripts/library/custom_train_functions.py", line 804, in forward
    band_loss = self.band_level_weights.get(weight_key, self.band_weights[band]) * self.loss_fn(pred_stack, target_stack)
  File "/workspace/kohya_ss/sd-scripts/train_network.py", line 488, in loss_fn
    huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
TypeError: get_huber_threshold_if_needed() takes 3 positional arguments but 4 were given
```

getting this error when trying it

@rockerBOO (Contributor, Author)

@EClipXAi Apologies, some newer functionality leaked in there. I have fixed this. Give it another shot and let me know if you have any issues.

rockerBOO added 13 commits May 4, 2025 18:38
- add full conditional_loss functionality to wavelet loss
- Transforms are separate and abstracted
- Loss now doesn't include LL except the lowest level
  - ll_level_threshold allows you to control the level the ll is
    used in the loss
- band weights can now be passed in
- rectified flow calculations can be bypassed for experimentation
- Fixed alpha to 1.0 with new weighted bands producing lower loss
@rockerBOO force-pushed the network-wavelet-loss branch from cea1930 to 3b949b9 on May 4, 2025 22:58
@rockerBOO (Contributor, Author) commented May 5, 2025

Added QuaternionWaveletTransform, derived from https://arxiv.org/abs/2505.00334.

```toml
wavelet_loss_transform = "qwt"
```

Full example:

```toml
wavelet_loss = true
wavelet_loss_type = "huber"
wavelet_loss_transform = "qwt"
wavelet_loss_wavelet = "sym7"
wavelet_loss_level = 3
wavelet_loss_alpha = 1
wavelet_loss_band_weights = { "ll" = 0.25, "lh" = 1.0, "hl" = 1.0, "hh" = 1.0 }
wavelet_loss_quaternion_component_weights = { "r" = 0.25, "i" = 0.5, "j" = 0.5, "k" = 0.5 }
wavelet_loss_ll_level_threshold = -1
```

It should work similarly to DWT but with more components (4, versus 1 for DWT), using Hilbert filters. It should be considered experimental; I am not yet sure of recommended values.

wavelet_loss_quaternion_component_weights keys might change to r, x, y, xy, but I will mention it when/if that changes.

Additionally, I reworked SWT to work better; it should be more performant now, using 1D convolutions.
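The practical difference between the two transforms: DWT downsamples at every level, while SWT (stationary/undecimated) keeps every coefficient map at the input resolution, which is why it preserves spatial detail at a higher memory cost. A quick illustration with PyWavelets:

```python
import numpy as np
import pywt

x = np.random.rand(64, 64)

# DWT halves the spatial resolution at each level
ll, _ = pywt.dwt2(x, "haar")
print(ll.shape)  # (32, 32)

# SWT keeps every level at the full input resolution
coeffs = pywt.swt2(x, "haar", level=2)
ll0, (lh0, hl0, hh0) = coeffs[0]
print(ll0.shape, hh0.shape)  # (64, 64) (64, 64)
```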

wavelet_loss_transform wasn't being applied properly before; now swt and qwt should work, where previously only DWT was used.

Added tests for the wavelet transforms and wavelet loss to make sure they work as expected.

67372a added a commit to 67372a/sd-scripts that referenced this pull request May 13, 2025
@67372a commented May 25, 2025

Sharing a good resource for comparing wavelets:

https://www.mathworks.com/help/wavelet/gs/introduction-to-the-wavelet-families.html
