32 changes: 16 additions & 16 deletions flux_train_network.py
@@ -232,21 +232,21 @@ def cache_text_encoder_outputs_if_needed(
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
vae = vae.to("cpu")
unet = unet.to("cpu")
clean_memory_on_device(accelerator.device)

# When the TE is not trained, it will not be prepared by accelerate, so we need to use an explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[1].to(accelerator.device)
text_encoders[0] = text_encoders[0].to(accelerator.device, dtype=weight_dtype, non_blocking=True) # always not fp8
text_encoders[1] = text_encoders[1].to(accelerator.device, non_blocking=True)

if text_encoders[1].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[1].to(weight_dtype)
text_encoders[1] = text_encoders[1].to(weight_dtype, non_blocking=True)

with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
@@ -276,19 +276,19 @@ def cache_text_encoder_outputs_if_needed(
# move back to cpu
if not self.is_train_text_encoder(args):
logger.info("move CLIP-L back to cpu")
text_encoders[0].to("cpu")
text_encoders[0] = text_encoders[0].to("cpu", non_blocking=True)
logger.info("move t5XXL back to cpu")
text_encoders[1].to("cpu")
text_encoders[1] = text_encoders[1].to("cpu", non_blocking=True)
clean_memory_on_device(accelerator.device)

if not args.lowram:
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
vae = vae.to(org_vae_device, non_blocking=True)
unet = unet.to(org_unet_device, non_blocking=True)
else:
# outputs are fetched from the Text Encoders on every step, so keep them on the GPU
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device)
text_encoders[0] = text_encoders[0].to(accelerator.device, dtype=weight_dtype, non_blocking=True)
text_encoders[1] = text_encoders[1].to(accelerator.device, non_blocking=True)

def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
text_encoders = text_encoder # for compatibility
@@ -429,7 +429,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
noisy_model_input[diff_output_pr_indices],
sigmas[diff_output_pr_indices] if sigmas is not None else None,
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype, non_blocking=True)

return model_pred, target, timesteps, weighting

@@ -468,8 +468,8 @@ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
if index == 0: # CLIP-L
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype) # fp8
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
text_encoder = text_encoder.to(te_weight_dtype, non_blocking=True) # fp8
text_encoder.text_model.embeddings = text_encoder.text_model.embeddings.to(dtype=weight_dtype)
else: # T5XXL

def prepare_fp8(text_encoder, target_dtype):
@@ -488,7 +488,7 @@ def forward(hidden_states):
for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
module = module.to(target_dtype, non_blocking=True)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)
@@ -497,7 +497,7 @@ def forward(hidden_states):
logger.info(f"T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
text_encoder = text_encoder.to(te_weight_dtype, non_blocking=True) # fp8
prepare_fp8(text_encoder, weight_dtype)

def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
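The flux_train_network.py changes above reassign the result of every `.to()` call and pass `non_blocking=True`. As a point of reference (not part of the diff, and with illustrative names), the sketch below shows the PyTorch behaviour these edits rely on: `nn.Module.to()` moves parameters in place and returns `self`, `Tensor.to()` returns a new tensor that must be reassigned, and `non_blocking=True` only gives a truly asynchronous host-to-device copy when the source lives in pinned memory.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# nn.Module.to() is in place and returns self, so reassignment is a style choice for modules.
model = nn.Linear(8, 8)
assert model.to(device, non_blocking=True) is model

# torch.Tensor.to() returns a new tensor, so the result must be reassigned.
x = torch.randn(4, 8)
if device.type == "cuda":
    x = x.pin_memory()                   # async H2D copies need pinned host memory
x = x.to(device, non_blocking=True)      # otherwise this degrades to a blocking copy
y = model(x)                             # ops queued on the same stream stay ordered
if device.type == "cuda":
    torch.cuda.synchronize()             # host-side barrier before inspecting results
print(y.sum().item())
```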
2 changes: 1 addition & 1 deletion library/custom_offloading_utils.py
@@ -53,7 +53,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
# print(
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
# )
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device, non_blocking=True)

torch.cuda.current_stream().synchronize() # this prevents the illegal loss value

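The custom_offloading_utils.py hunk only makes the CPU-to-GPU fallback copy non-blocking; the `torch.cuda.current_stream().synchronize()` shown below it is what keeps that safe. A rough sketch of the same pattern, with illustrative names rather than the module's real ones:

```python
import torch
import torch.nn as nn

def move_params_to_device(module: nn.Module, device: torch.device) -> None:
    """Queue every copy without blocking the host, then wait once."""
    for p in module.parameters():
        p.data = p.data.to(device, non_blocking=True)
    if device.type == "cuda":
        # One synchronize after the batch of copies; without it, host code or
        # work on another stream could observe weights that have not finished copying.
        torch.cuda.current_stream(device).synchronize()

block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
move_params_to_device(block, torch.device("cuda" if torch.cuda.is_available() else "cpu"))
```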
14 changes: 7 additions & 7 deletions library/flux_models.py
@@ -307,7 +307,7 @@ def forward(self, z: Tensor) -> Tensor:
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
return mean + std * torch.randn_like(mean, pin_memory=True)
else:
return mean

@@ -532,7 +532,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, pin_memory=True) / half).to(t.device, non_blocking=True)

args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -600,7 +600,7 @@ def __init__(self, dim: int):
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
return q.to(v, non_blocking=True), k.to(v, non_blocking=True)


class SelfAttention(nn.Module):
@@ -997,7 +997,7 @@ def move_to_device_except_swap_blocks(self, device: torch.device):
self.double_blocks = None
self.single_blocks = None

self.to(device)
self = self.to(device, non_blocking=True)

if self.blocks_to_swap:
self.double_blocks = save_double_blocks
@@ -1081,8 +1081,8 @@ def forward(
img = img[:, txt.shape[1] :, ...]

if self.training and self.cpu_offload_checkpointing:
img = img.to(self.device)
vec = vec.to(self.device)
img = img.to(self.device, non_blocking=True)
vec = vec.to(self.device, non_blocking=True)

img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)

@@ -1243,7 +1243,7 @@ def move_to_device_except_swap_blocks(self, device: torch.device):
self.double_blocks = nn.ModuleList()
self.single_blocks = nn.ModuleList()

self.to(device)
self = self.to(device, non_blocking=True)

if self.blocks_to_swap:
self.double_blocks = save_double_blocks
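In flux_models.py the tensor factory calls gain `pin_memory=True` and the transfers gain `non_blocking=True`. Pinning only applies to CPU tensors: it allocates page-locked host memory so that a subsequent host-to-device copy can run asynchronously; the `q.to(v, non_blocking=True)` lines additionally rely on `Tensor.to(other)` matching both the dtype and device of `other`. A standalone sketch of the pinned-allocation idea, loosely modeled on the timestep_embedding change but using an explicit `.pin_memory()` call and illustrative names:

```python
import math
import torch

def make_freqs(dim: int, max_period: int = 10000) -> torch.Tensor:
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(0, half, dtype=torch.float32) / half
    )
    if torch.cuda.is_available():
        freqs = freqs.pin_memory()                   # page-locked host memory
        freqs = freqs.to("cuda", non_blocking=True)  # copy can overlap other GPU work
    return freqs

print(make_freqs(8))
```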
2 changes: 1 addition & 1 deletion library/strategy_sd.py
@@ -40,7 +40,7 @@ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]

def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens_list = []
weights_list = []
33 changes: 21 additions & 12 deletions library/train_util.py
@@ -4,6 +4,7 @@
import ast
import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import nullcontext
import datetime
import importlib
import json
@@ -26,6 +27,7 @@

# from concurrent.futures import ThreadPoolExecutor, as_completed

from torch.cuda import Stream
from tqdm import tqdm
from packaging.version import Version

@@ -1415,10 +1417,11 @@ def cache_text_encoder_outputs_common(
return

# prepare tokenizers and text encoders
for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes):
text_encoder.to(device)
for i, (text_encoder, device, te_dtype) in enumerate(zip(text_encoders, devices, te_dtypes)):
te_kwargs = {}
if te_dtype is not None:
text_encoder.to(dtype=te_dtype)
te_kwargs['dtype'] = te_dtype
text_encoders[i] = text_encoder.to(device, non_blocking=True, **te_kwargs)

# create batch
is_sd3 = len(tokenizers) == 1
@@ -1440,6 +1443,8 @@ def cache_text_encoder_outputs_common(
if len(batch) > 0:
batches.append(batch)

torch.cuda.synchronize()

# iterate batches: call text encoder and cache outputs for memory or disk
logger.info("caching text encoder outputs...")
if not is_sd3:
@@ -3120,7 +3125,10 @@ def cache_batch_latents(
images.append(image)

img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)

s = Stream()

img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype, non_blocking=True)

with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
@@ -3156,12 +3164,13 @@ def cache_batch_latents(
if not HIGH_VRAM:
clean_memory_on_device(vae.device)

torch.cuda.synchronize()

def cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
):
input_ids1 = input_ids1.to(text_encoders[0].device)
input_ids2 = input_ids2.to(text_encoders[1].device)
input_ids1 = input_ids1.to(text_encoders[0].device, non_blocking=True)
input_ids2 = input_ids2.to(text_encoders[1].device, non_blocking=True)

with torch.no_grad():
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
@@ -5619,9 +5628,9 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)
text_encoder = text_encoder.to(accelerator.device, non_blocking=True)
unet = unet.to(accelerator.device, non_blocking=True)
vae = vae.to(accelerator.device, non_blocking=True)

clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -6435,7 +6444,7 @@ def sample_images_common(
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here

org_vae_device = vae.device # should be on the CPU
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
vae = vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device

# unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet_wrapped)
@@ -6470,7 +6479,7 @@ def sample_images_common(
requires_safety_checker=False,
clip_skip=args.clip_skip,
)
pipeline.to(distributed_state.device)
pipeline = pipeline.to(distributed_state.device)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)

@@ -6521,7 +6530,7 @@ def sample_images_common(
torch.set_rng_state(rng_state)
if torch.cuda.is_available() and cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
vae = vae.to(org_vae_device)

clean_memory_on_device(accelerator.device)

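The train_util.py changes batch the encoder moves and then issue a single `torch.cuda.synchronize()` before the cached forward passes, instead of paying a host sync per transfer. A self-contained sketch of that pattern, mirroring the per-encoder dtype handling in the loop above (variable names here are illustrative, not the library's):

```python
import torch
import torch.nn as nn

def move_encoders(text_encoders, devices, te_dtypes):
    for i, (te, device, te_dtype) in enumerate(zip(text_encoders, devices, te_dtypes)):
        te_kwargs = {}
        if te_dtype is not None:
            te_kwargs["dtype"] = te_dtype
        # Module.to() is in place; keeping the assignment mirrors the diff's style.
        text_encoders[i] = te.to(device, non_blocking=True, **te_kwargs)
    if torch.cuda.is_available():
        # One host-side barrier after all queued copies, instead of one per encoder.
        torch.cuda.synchronize()

encoders = [nn.Linear(4, 4), nn.Linear(4, 4)]
device = "cuda" if torch.cuda.is_available() else "cpu"
move_encoders(encoders, [device, device], [torch.float16, None])
print([p.dtype for p in encoders[0].parameters()])
```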
2 changes: 1 addition & 1 deletion library/utils.py
@@ -110,7 +110,7 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu")

stream.synchronize()

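The library/utils.py hunk goes the other direction and drops `non_blocking=True` from a GPU-to-CPU copy. A plausible reading (the diff itself gives no rationale): a device-to-host copy into ordinary pageable memory is effectively synchronous anyway, while a pinned destination would require a stream synchronize before the host may safely read it. A minimal, illustrative sketch of the two variants (not the PR's code):

```python
import torch

if torch.cuda.is_available():
    gpu = torch.randn(1024, device="cuda")

    # Blocking copy into pageable memory: returns only after the data
    # has actually arrived on the host.
    cpu_blocking = gpu.to("cpu")

    # Pinned destination with non_blocking=True: the copy is queued on the
    # stream, so the host must synchronize before reading the values.
    cpu_pinned = torch.empty(1024, pin_memory=True)
    cpu_pinned.copy_(gpu, non_blocking=True)
    torch.cuda.current_stream().synchronize()  # without this, cpu_pinned may still be stale
    assert torch.equal(cpu_blocking, cpu_pinned)
```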
6 changes: 3 additions & 3 deletions networks/oft.py
@@ -49,11 +49,11 @@ def __init__(

if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()

# constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility
# original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha
self.constraint = alpha * out_dim
self.constraint = alpha * out_dim

self.register_buffer("alpha", torch.tensor(alpha))

self.block_size = out_dim // self.num_blocks
5 changes: 3 additions & 2 deletions train_db.py
@@ -239,8 +239,8 @@ def train(args):
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
unet = unet.to(weight_dtype)
text_encoder = text_encoder.to(weight_dtype)

# accelerator supposedly takes care of this for us
if args.deepspeed:
@@ -335,6 +335,7 @@ def train(args):
text_encoder.train()

for step, batch in enumerate(train_dataloader):
optimizer.train()
current_step.value = global_step
# stop training the Text Encoder after the specified number of steps
if global_step == args.stop_text_encoder_training:
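The only non-dtype change in train_db.py is the `optimizer.train()` call at the top of each step. That call matters for optimizers that expose a `train()`/`eval()` pair, such as the schedule-free optimizers, which keep separate training and evaluation parameter states; a plain `torch.optim` optimizer has no such method. A hedged sketch of how such a call is typically guarded (the guard below is illustrative, not sd-scripts code):

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for step in range(3):
    # Only call train()/eval() if the optimizer actually implements them
    # (e.g. schedulefree.AdamWScheduleFree); torch.optim optimizers do not.
    if hasattr(optimizer, "train"):
        optimizer.train()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

if hasattr(optimizer, "eval"):
    optimizer.eval()  # switch to the evaluation weights before validation or sampling
```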