20 changes: 13 additions & 7 deletions library/lumina_models.py
@@ -20,6 +20,7 @@
# --------------------------------------------------------

import math
import os
from typing import List, Optional, Tuple
from dataclasses import dataclass

@@ -31,6 +32,10 @@

from library import custom_offloading_utils

disable_selective_torch_compile = (
os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"
)

try:
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -553,7 +558,7 @@ def flash_attn(
f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
)


@torch.compiler.disable(reason="complex ops inside")
def apply_rope(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
@@ -629,13 +634,10 @@ def __init__(
bias=False,
)
nn.init.xavier_uniform_(self.w3.weight)

# @torch.compile
def _forward_silu_gating(self, x1, x3):
return F.silu(x1) * x3


@torch.compile(disable=disable_selective_torch_compile)
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
return self.w2(F.silu(self.w1(x))*self.w3(x))


class JointTransformerBlock(GradientCheckpointMixin):
@@ -701,6 +703,7 @@ def __init__(
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)

@torch.compile(disable=disable_selective_torch_compile)
def _forward(
self,
x: torch.Tensor,
@@ -792,6 +795,7 @@ def __init__(self, hidden_size, patch_size, out_channels):
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)

@torch.compile(disable=disable_selective_torch_compile)
def forward(self, x, c):
scale = self.adaLN_modulation(c)
x = modulate(self.norm_final(x), scale)
@@ -812,6 +816,7 @@ def __init__(
self.axes_lens = axes_lens
self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)

@torch.compiler.disable(reason="complex ops inside")
def __call__(self, ids: torch.Tensor):
device = ids.device
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
@@ -1224,6 +1229,7 @@ def forward_with_cfg(
return output

@staticmethod
@torch.compiler.disable(reason="complex ops inside")
def precompute_freqs_cis(
dim: List[int],
end: List[int],
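Taken together, the lumina_models.py changes make torch.compile opt-in per module via the SDSCRIPTS_SELECTIVE_TORCH_COMPILE environment variable, while the complex-valued RoPE paths (apply_rope, the freqs_cis computation) are explicitly kept out of the compiled graph. Below is a minimal sketch of the same pattern; the class and function names are illustrative, not the repository's own, and the RoPE body is a generic stand-in rather than the file's implementation.

```python
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

# Compilation is opt-in: unless SDSCRIPTS_SELECTIVE_TORCH_COMPILE is set to something
# other than "0", disable=True and the decorated methods run eagerly.
disable_selective_torch_compile = os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"


class FeedForwardSketch(nn.Module):
    """Illustrative SwiGLU feed-forward, not the repository's class."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    @torch.compile(disable=disable_selective_torch_compile)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU gating inlined into forward so inductor can fuse it when compilation is on.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


# Complex ops are kept eager, mirroring the diff's @torch.compiler.disable decorators.
# The reason= keyword needs a recent PyTorch; older versions accept plain @torch.compiler.disable.
@torch.compiler.disable(reason="complex ops inside")
def apply_rope_sketch(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_complex * freqs_cis).flatten(-2).type_as(x)
```

With the default environment nothing changes; exporting SDSCRIPTS_SELECTIVE_TORCH_COMPILE=1 turns the decorated forwards into compiled regions while RoPE application and freqs_cis precomputation stay eager.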
19 changes: 19 additions & 0 deletions library/train_util.py
@@ -3974,6 +3974,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
)
parser.add_argument(
"--activation_memory_budget",
type=float,
default=None,
help="activation memory budget setting for torch.compile (range: 0~1). Smaller value saves more memory at cost of speed. If set, use --torch_compile without --gradient_checkpointing is recommended. Requires PyTorch 2.4. / torch.compileのactivation memory budget設定(0~1の値)。この値を小さくするとメモリ使用量を節約できますが、処理速度は低下します。この設定を行う場合は、--gradient_checkpointing オプションを指定せずに --torch_compile を使用することをお勧めします。PyTorch 2.4以降が必要です。"
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument(
"--sdpa",
@@ -5506,6 +5512,19 @@ def prepare_accelerator(args: argparse.Namespace):
if args.torch_compile:
dynamo_backend = args.dynamo_backend

if args.activation_memory_budget is not None: # Note: 0 is a valid value.
if 0 <= args.activation_memory_budget <= 1:
logger.info(
f"set torch compile activation memory budget to {args.activation_memory_budget}"
)
torch._functorch.config.activation_memory_budget = ( # type: ignore
args.activation_memory_budget
)
else:
raise ValueError(
"activation_memory_budget must be between 0 and 1 (inclusive)"
)

kwargs_handlers = [
(
InitProcessGroupKwargs(
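On the train_util.py side, the new flag simply writes a validated value into torch._functorch.config.activation_memory_budget before the Accelerator is created, which is where it takes effect for anything later wrapped in torch.compile. The sketch below shows the knob in isolation, under stated assumptions: the model, shapes, and the 0.5 value are illustrative, PyTorch 2.4 or later is required (per the help text above), and a working torch.compile/inductor setup is assumed.

```python
import torch
import torch.nn as nn

# 1.0 (the default) keeps all activations for the backward pass; 0.0 recomputes as
# aggressively as full activation checkpointing; values in between trade speed for memory.
torch._functorch.config.activation_memory_budget = 0.5  # illustrative mid-range value

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.GELU(),
    nn.Linear(4096, 1024),
).to(device)

compiled = torch.compile(model)  # default inductor backend, as with --dynamo_backend inductor

x = torch.randn(8, 1024, device=device, requires_grad=True)
loss = compiled(x).sum()
loss.backward()  # the AOTAutograd partitioner recomputes some activations to respect the budget
```

From the command line this corresponds to something like --torch_compile --dynamo_backend inductor --activation_memory_budget 0.5, with --gradient_checkpointing left off as the help text recommends (the exact training command depends on the script being run).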