Support Lumina-image-2.0 #1927

Open · wants to merge 69 commits into base: sd3

Commits (69)
d154e76
init
sdbds Feb 12, 2025
c0caf33
update
sdbds Feb 15, 2025
7323ee1
update lora_lumina
sdbds Feb 15, 2025
a00b06b
Lumina 2 and Gemma 2 model loading
rockerBOO Feb 15, 2025
60a76eb
Add caching gemma2, add gradient checkpointing, refactor lumina model…
rockerBOO Feb 16, 2025
1601563
Update metadata.resolution for Lumina 2
rockerBOO Feb 16, 2025
6965a01
Merge pull request #12 from rockerBOO/lumina-model-loading
sdbds Feb 16, 2025
733fdc0
update
sdbds Feb 17, 2025
3ce23b7
Merge branch 'lumina' of https://github.com/sdbds/sd-scripts into lumina
sdbds Feb 17, 2025
bb7bae5
Merge pull request #13 from rockerBOO/lumina-cache-checkpointing
sdbds Feb 17, 2025
aa36c48
update for always use gemma2 mask
sdbds Feb 17, 2025
44782dd
Fix validation epoch divergence
rockerBOO Feb 14, 2025
3365cfa
Fix sizes for validation split
rockerBOO Feb 17, 2025
3ed7606
Clear sizes for validation reg images to be consistent
rockerBOO Feb 17, 2025
1aa2f00
Fix validation epoch loss to check epoch average
rockerBOO Feb 16, 2025
98efbc3
Add documentation to model, use SDPA attention, sample images
rockerBOO Feb 18, 2025
bd16bd1
Remove unused attention, fix typo
rockerBOO Feb 18, 2025
6597631
Merge pull request #14 from rockerBOO/samples-attention
sdbds Feb 19, 2025
025cca6
Fix samples, LoRA training. Add system prompt, use_flash_attn
rockerBOO Feb 23, 2025
6d7bec8
Remove non-used code
rockerBOO Feb 23, 2025
42a8015
Fix system prompt in datasets
rockerBOO Feb 23, 2025
ba725a8
Set default discrete_flow_shift to 6.0. Remove default system prompt.
rockerBOO Feb 23, 2025
48e7da2
Add sample batch size for Lumina
rockerBOO Feb 24, 2025
2c94d17
Fix typo
rockerBOO Feb 24, 2025
653621d
Merge pull request #15 from rockerBOO/samples-training
sdbds Feb 24, 2025
fc772af
1. Implement cfg_trunc calculation directly using timesteps, without i…
sdbds Feb 24, 2025
5f9047c
add truncation when > max_length
sdbds Feb 25, 2025
ce37c08
clean code and add finetune code
sdbds Feb 26, 2025
a1a5627
fix shift
sdbds Feb 26, 2025
7b83d50
Merge branch 'sd3' into lumina
rockerBOO Feb 27, 2025
70403f6
fix cache text encoder outputs if not using disk. small cleanup/align…
rockerBOO Feb 27, 2025
542f980
Fix sample norms in batches
rockerBOO Feb 27, 2025
0886d97
Add block swap
rockerBOO Feb 27, 2025
ce2610d
Change system prompt to inject Prompt Start special token
rockerBOO Feb 27, 2025
42fe22f
Enable block swap for Lumina
rockerBOO Feb 27, 2025
9647f1e
Fix validation block swap. Add custom offloading tests
rockerBOO Feb 28, 2025
d6f7e2e
Fix block swap for sample images
rockerBOO Feb 28, 2025
1bba7ac
Add block swap in sample image timestep loop
rockerBOO Feb 28, 2025
a2daa87
Add block swap for uncond (neg) for sample images
rockerBOO Feb 28, 2025
cad182d
fix torch compile/dynamo for Gemma2
rockerBOO Feb 28, 2025
a69884a
Add Sage Attention for Lumina
rockerBOO Mar 2, 2025
3817b65
Merge pull request #16 from rockerBOO/lumina
sdbds Mar 2, 2025
800d068
Merge pull request #17 from rockerBOO/lumina-cache-text-encoder-outputs
sdbds Mar 2, 2025
d6c3e63
Merge pull request #18 from rockerBOO/fix-sample-batch-norms
sdbds Mar 2, 2025
b5d1f1c
Merge pull request #19 from rockerBOO/lumina-block-swap
sdbds Mar 2, 2025
b6e4194
Merge pull request #20 from rockerBOO/lumina-system-prompt-special-token
sdbds Mar 2, 2025
dfe1ab6
Merge pull request #21 from rockerBOO/lumina-torch-dynamo-gemma2
sdbds Mar 2, 2025
09c4710
Merge pull request #22 from rockerBOO/sage_attn
sdbds Mar 3, 2025
5e45df7
update gemma2 train attention layer
sdbds Mar 4, 2025
1f22a94
Update embedder_dims, add more flexible caption extension
rockerBOO Mar 4, 2025
9fe8a47
Undo dropout after up
rockerBOO Mar 4, 2025
e8c15c7
Remove log
rockerBOO Mar 4, 2025
7482784
Merge pull request #23 from rockerBOO/lumina-lora
sdbds Mar 9, 2025
2ba1cc7
Fix max norms not applying to noise
rockerBOO Mar 22, 2025
61f7283
Fix non-cache vae encode
rockerBOO Mar 22, 2025
1481217
Merge pull request #25 from rockerBOO/lumina-fix-non-cache-image-vae-…
sdbds Mar 22, 2025
3000816
Merge pull request #24 from rockerBOO/lumina-fix-max-norms
sdbds Mar 22, 2025
00e12ee
update for lost change
sdbds Apr 6, 2025
1a4f1ff
Merge branch 'lumina' of https://github.com/sdbds/sd-scripts into lumina
sdbds Apr 6, 2025
9f1892c
Merge branch 'sd3' into lumina
sdbds Apr 6, 2025
7f93e21
fix typo
sdbds Apr 6, 2025
899f345
update for init problem
sdbds Apr 23, 2025
4fc9178
fix bugs
sdbds Apr 23, 2025
0145efc
Merge branch 'sd3' into lumina
rockerBOO Jun 9, 2025
d94bed6
Add lumina tests and fix image masks
rockerBOO Jun 10, 2025
77dbabe
Merge pull request #26 from rockerBOO/lumina-test-fix-mask
sdbds Jun 10, 2025
1db7855
Merge branch 'sd3' into update-sd3
rockerBOO Jun 16, 2025
0e929f9
Revert system_prompt for dataset config
rockerBOO Jun 16, 2025
8e4dc1f
Merge pull request #28 from rockerBOO/lumina-train_util
sdbds Jun 17, 2025
6 changes: 6 additions & 0 deletions library/config_util.py
@@ -75,6 +75,7 @@ class BaseSubsetParams:
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0
system_prompt: Optional[str] = None
resize_interpolation: Optional[str] = None


@@ -107,6 +108,7 @@ class BaseDatasetParams:
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0
system_prompt: Optional[str] = None
resize_interpolation: Optional[str] = None

@dataclass
@@ -197,6 +199,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"caption_prefix": str,
"caption_suffix": str,
"custom_attributes": dict,
"system_prompt": str,
"resize_interpolation": str,
}
# DO means DropOut
@@ -243,6 +246,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"system_prompt": str,
"resize_interpolation": str,
}

@@ -530,6 +534,7 @@ def print_info(_datasets, dataset_type: str):
resolution: {(dataset.width, dataset.height)}
resize_interpolation: {dataset.resize_interpolation}
enable_bucket: {dataset.enable_bucket}
system_prompt: {dataset.system_prompt}
""")

if dataset.enable_bucket:
@@ -564,6 +569,7 @@ def print_info(_datasets, dataset_type: str):
alpha_mask: {subset.alpha_mask}
resize_interpolation: {subset.resize_interpolation}
custom_attributes: {subset.custom_attributes}
system_prompt: {subset.system_prompt}
"""), " ")

if is_dreambooth:
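
Note: `system_prompt` is added at both the dataset level and the subset level (dataclasses, config schemas, and the `print_info` debug output above). Below is a minimal sketch of the fallback one would expect between the two levels; the helper is illustrative only and assumes subset-level values take precedence, which this diff does not spell out.

```python
from typing import Optional


def resolve_system_prompt(
    dataset_system_prompt: Optional[str],
    subset_system_prompt: Optional[str],
) -> Optional[str]:
    # Hypothetical helper, not from the PR: prefer the more specific
    # (subset-level) prompt, fall back to the dataset-level one, and
    # treat None as "no system prompt".
    if subset_system_prompt is not None:
        return subset_system_prompt
    return dataset_system_prompt


# Example: the subset value wins when both are set.
print(resolve_system_prompt("dataset-level prompt", "subset-level prompt"))
```
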
30 changes: 17 additions & 13 deletions library/custom_offloading_utils.py
@@ -1,6 +1,6 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
from typing import Optional, Union, Callable, Tuple
import torch
import torch.nn as nn

@@ -19,7 +19,7 @@ def synchronize_device(device: torch.device):
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__

weight_swap_jobs = []
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []

# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
@@ -42,7 +42,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye

torch.cuda.current_stream().synchronize() # this prevents the illegal loss value

stream = torch.cuda.Stream()
stream = torch.Stream(device="cuda")
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
@@ -66,23 +66,24 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
"""
assert layer_to_cpu.__class__ == layer_to_cuda.__class__

weight_swap_jobs = []
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))


# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

synchronize_device()
synchronize_device(device)

# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view

synchronize_device()
synchronize_device(device)


def weighs_to_device(layer: nn.Module, device: torch.device):
@@ -148,13 +149,16 @@ def _wait_blocks_move(self, block_idx):
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")


# Gradient tensors
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]

class ModelOffloader(Offloader):
"""
supports forward offloading
"""

def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(len(blocks), blocks_to_swap, device, debug)

# register backward hooks
self.remove_handles = []
@@ -168,7 +172,7 @@ def __del__(self):
for handle in self.remove_handles:
handle.remove()

def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -182,7 +186,7 @@ def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Opt
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1

def backward_hook(module, grad_input, grad_output):
def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
if self.debug:
print(f"Backward hook for block {block_index}")

@@ -194,7 +198,7 @@ def backward_hook(module, grad_input, grad_output):

return backward_hook

def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return

@@ -207,7 +211,7 @@ def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):

for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
weighs_to_device(b, "cpu") # make sure weights are on cpu
weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu

synchronize_device(self.device)
clean_memory_on_device(self.device)
@@ -217,7 +221,7 @@ def wait_for_block(self, block_idx: int):
return
self._wait_blocks_move(block_idx)

def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:
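
With the constructor change above, callers hand `ModelOffloader` only the block list; the block count is now derived via `len(blocks)`. The sketch below shows how a forward pass might drive the offloader using the method signatures from this diff. The `TinyBlocks` module, the import path, and the exact wait/run/submit loop order are assumptions rather than code from the repository.

```python
import torch
import torch.nn as nn

# Import path assumes the repository root is on sys.path.
from library.custom_offloading_utils import ModelOffloader


class TinyBlocks(nn.Module):
    def __init__(self, num_blocks: int = 8, dim: int = 16):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))

    def forward(self, x: torch.Tensor, offloader: ModelOffloader) -> torch.Tensor:
        for idx, block in enumerate(self.blocks):
            offloader.wait_for_block(idx)                   # wait until this block's weights are on the device
            x = block(x)
            offloader.submit_move_blocks(self.blocks, idx)  # queue swapping a finished block back to CPU
        return x


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyBlocks()
# New signature: blocks, blocks_to_swap, device (num_blocks is inferred from len(blocks)).
offloader = ModelOffloader(model.blocks, blocks_to_swap=4, device=device)
offloader.prepare_block_devices_before_forward(model.blocks)
out = model(torch.randn(2, 16, device=device), offloader)
```

The backward hooks registered in `__init__` appear to cover the reverse direction during backpropagation, so a forward loop like this would only need the three calls shown; treat the snippet as a call-pattern sketch, not a drop-in.
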
12 changes: 6 additions & 6 deletions library/flux_models.py
@@ -977,10 +977,10 @@ def enable_block_swap(self, num_blocks: int, device: torch.device):
)

self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
self.double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
self.single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1219,10 +1219,10 @@ def enable_block_swap(self, num_blocks: int, device: torch.device):
)

self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
self.double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
self.single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1233,8 +1233,8 @@ def move_to_device_except_swap_blocks(self, device: torch.device):
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
self.double_blocks = None
self.single_blocks = None
self.double_blocks = nn.ModuleList()
self.single_blocks = nn.ModuleList()

self.to(device)

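
The second change in this file replaces `None` with empty `nn.ModuleList()` placeholders while `self.to(device)` runs, so the attributes remain valid modules and the swappable blocks are left where the offloader put them. Below is a standalone sketch of that save/replace/restore pattern; `BlockSwapModel` and `move_except_blocks` are illustrative names, and the restore step is implied by the saved references rather than shown in the visible hunk.

```python
import torch
import torch.nn as nn


class BlockSwapModel(nn.Module):
    def __init__(self, dim: int = 16, num_blocks: int = 4):
        super().__init__()
        self.in_proj = nn.Linear(dim, dim)
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))

    def move_except_blocks(self, device: torch.device) -> None:
        saved_blocks = self.blocks
        self.blocks = nn.ModuleList()  # empty placeholder, as in the diff above, rather than None
        self.to(device)                # moves in_proj (and any other submodules) but not the saved blocks
        self.blocks = saved_blocks     # restore the references; the blocks stay on their current devices


model = BlockSwapModel()
model.move_except_blocks(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
```
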