
Commit 714ae66: support fp8 store bf16 exec (#120)

* support fp8 store bf16 exec
* fix

Parent: 9ec5377

File tree: 11 files changed (+137, -56 lines)


diffsynth_engine/models/basic/lora.py
Lines changed: 39 additions & 19 deletions
@@ -37,14 +37,23 @@ def apply_to(self, w: Union[nn.Linear, nn.Conv2d, nn.Parameter, torch.Tensor]):
         else:
             delta_w = self.scale * (self.alpha / self.rank) * (self.up.weight @ self.down.weight)
         if isinstance(w, (nn.Linear, nn.Conv2d)):
-            delta_w = delta_w.to(device=w.weight.data.device, dtype=w.weight.data.dtype)
+            delta_w = delta_w.to(device=w.weight.data.device, dtype=self.dtype)
+            w_dtype = w.weight.data.dtype
+            w.weight.data = w.weight.data.to(self.dtype)
             w.weight.data.add_(delta_w)
+            w.weight.data = w.weight.data.to(w_dtype)
         elif isinstance(w, nn.Parameter):
-            delta_w = delta_w.to(device=w.data.device, dtype=w.data.dtype)
+            delta_w = delta_w.to(device=w.data.device, dtype=self.dtype)
+            w_dtype = w.data.dtype
+            w.data = w.data.to(self.dtype)
             w.data.add_(delta_w)
+            w.data = w.data.to(w_dtype)
         elif isinstance(w, torch.Tensor):
-            delta_w = delta_w.to(device=w.device, dtype=w.dtype)
+            delta_w = delta_w.to(device=w.device, dtype=self.dtype)
+            w_dtype = w.dtype
+            w = w.to(self.dtype)
             w.add_(delta_w)
+            w = w.to(w_dtype)


 class LoRALinear(nn.Linear):
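This is the upcast-merge-downcast pattern the commit title refers to: current PyTorch has no arithmetic kernels for float8_e4m3fn, so the weight is temporarily cast to the LoRA's compute dtype, updated in place, and cast back to its storage dtype. A minimal standalone sketch of the same round trip (toy shapes; bf16 assumed as the compute dtype):

import torch

# weight stored in fp8, LoRA delta computed in bf16
weight = torch.randn(16, 16, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
delta_w = 0.5 * torch.randn(16, 16, dtype=torch.bfloat16)

w_dtype = weight.dtype               # remember the storage dtype (fp8)
weight = weight.to(torch.bfloat16)   # upcast: fp8 has no add kernels
weight.add_(delta_w)                 # merge the LoRA delta in bf16
weight = weight.to(w_dtype)          # quantize back to fp8 storage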
@@ -60,8 +69,8 @@ def __init__(
         # LoRA
         self._lora_dict = OrderedDict()
         # Frozen LoRA
-        self._frozen_lora_list = []
-        self.register_buffer("_original_weight", None)
+        self.patched_frozen_lora = False
+        self._original_weight = None

     @staticmethod
     def from_linear(linear: nn.Linear):
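Turning _original_weight from a registered buffer into a plain attribute keeps the backup out of state_dict() and out of module-wide .to() device moves; my reading (not stated in the commit) is that this matters once the backup is deliberately parked in pinned CPU memory. A small illustration of the difference:

import torch
import torch.nn as nn

m = nn.Linear(2, 2)
m.register_buffer("_backup_buf", torch.zeros(2, 2))
m._backup_attr = torch.zeros(2, 2)  # plain attribute

print("_backup_buf" in m.state_dict())   # True: buffers are serialized and follow .to()
print("_backup_attr" in m.state_dict())  # False: a plain attribute stays where you put it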
@@ -118,20 +127,27 @@ def add_frozen_lora(
         save_original_weight: bool = True,
     ):
         if save_original_weight and self._original_weight is None:
-            self._original_weight = self.weight.clone()
+            if self.weight.dtype == torch.float8_e4m3fn:
+                self._original_weight = self.weight.to(dtype=torch.bfloat16, device="cpu", copy=True).pin_memory()
+            else:
+                self._original_weight = self.weight.to(device="cpu", copy=True).pin_memory()
         lora = LoRA(scale, rank, alpha, up, down, device, dtype)
         lora.apply_to(self)
-        self._frozen_lora_list.append(lora)
+        self.patched_frozen_lora = True

-    def clear(self):
-        if self._original_weight is None and len(self._frozen_lora_list) > 0:
+    def clear(self, release_all_cpu_memory: bool = False):
+        if self.patched_frozen_lora and self._original_weight is None:
             raise RuntimeError(
                 "Current LoRALinear has patched by frozen LoRA, but original weight is not saved, so you cannot clear LoRA."
             )
         self._lora_dict.clear()
-        self._frozen_lora_list = []
         if self._original_weight is not None:
-            self.weight.data.copy_(self._original_weight)
+            self.weight.data.copy_(
+                self._original_weight.to(device=self.weight.data.device, dtype=self.weight.data.dtype)
+            )
+        if release_all_cpu_memory:
+            del self._original_weight
+        self.patched_frozen_lora = False

     def forward(self, x):
         w_x = super().forward(x)
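The backup is now taken to pinned (page-locked) CPU memory, widened to bf16 when the live weight is fp8 (every e4m3 value is exactly representable in bf16, so this is lossless), and restore re-casts it to whatever device and dtype the live weight uses. A rough standalone sketch of the backup/restore idea (assumes a CUDA device, since pin_memory() needs one):

import torch

weight = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)

# backup: keep the CPU copy in bf16, pinned so host-to-device restores are fast
backup = weight.to(dtype=torch.bfloat16, device="cpu", copy=True).pin_memory()

# ... merge one or more frozen LoRAs into `weight` in place ...

# restore: copy back and re-quantize to the live storage dtype
weight.copy_(backup.to(device=weight.device, dtype=weight.dtype))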
@@ -161,8 +177,8 @@ def __init__(
         # LoRA
         self._lora_dict = OrderedDict()
         # Frozen LoRA
-        self._frozen_lora_list = []
         self._original_weight = None
+        self.patched_frozen_lora = False

     @staticmethod
     def from_conv2d(conv2d: nn.Conv2d):
@@ -257,21 +273,25 @@ def add_frozen_lora(
         save_original_weight: bool = True,
     ):
         if save_original_weight and self._original_weight is None:
-            self._original_weight = self.weight.clone()
+            if self.weight.dtype == torch.float8_e4m3fn:
+                self._original_weight = self.weight.to(dtype=torch.bfloat16, device="cpu", copy=True).pin_memory()
+            else:
+                self._original_weight = self.weight.to(device="cpu", copy=True).pin_memory()
         lora = self._construct_lora(name, scale, rank, alpha, up, down, device, dtype)
         lora.apply_to(self)
-        self._frozen_lora_list.append(lora)
+        self.patched_frozen_lora = True

-    def clear(self):
-        if self._original_weight is None and len(self._frozen_lora_list) > 0:
+    def clear(self, release_all_cpu_memory: bool = False):
+        if self.patched_frozen_lora and self._original_weight is None:
             raise RuntimeError(
                 "Current LoRALinear has patched by frozen LoRA, but original weight is not saved, so you cannot clear LoRA."
             )
         self._lora_dict.clear()
-        self._frozen_lora_list = []
         if self._original_weight is not None:
-            self.weight.copy_(self._original_weight)
-            self._original_weight = None
+            self.weight.copy_(self._original_weight.to(device=self.weight.device, dtype=self.weight.dtype))
+        if release_all_cpu_memory:
+            del self._original_weight
+        self.patched_frozen_lora = False

     def forward(self, x):
         w_x = super().forward(x)
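The Conv2d variant follows the same pattern. A hypothetical end-to-end lifecycle, with the argument names guessed from the LoRA constructor in this diff (up/down as pre-built low-rank factors):

import torch
import torch.nn as nn

layer = LoRALinear.from_linear(nn.Linear(64, 64, dtype=torch.bfloat16))
up = nn.Linear(4, 64, bias=False, dtype=torch.bfloat16)    # rank-4 factors
down = nn.Linear(64, 4, bias=False, dtype=torch.bfloat16)

# merge a frozen LoRA; the original weight is first saved to pinned CPU memory
layer.add_frozen_lora(scale=1.0, rank=4, alpha=4.0, up=up, down=down,
                      device="cpu", dtype=torch.bfloat16)

# restore the original weight and free the pinned CPU backup
layer.clear(release_all_cpu_memory=True)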

diffsynth_engine/models/basic/transformer_helper.py
Lines changed: 0 additions & 11 deletions
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-import math


 def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
@@ -83,13 +82,3 @@ def forward(self, x):
         if self.elementwise_affine:
             return norm_result * self.weight
         return norm_result
-
-
-class NewGELUActivation(nn.Module):
-    """
-    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
-    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
-    """
-
-    def forward(self, input: "torch.Tensor") -> "torch.Tensor":
-        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
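NewGELUActivation can be deleted because torch.nn.GELU(approximate="tanh") computes exactly this tanh approximation, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A quick equivalence check one might run:

import math
import torch
import torch.nn as nn

x = torch.randn(1024)
manual = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
builtin = nn.GELU(approximate="tanh")(x)
print(torch.allclose(manual, builtin, atol=1e-6))  # expected: True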

diffsynth_engine/models/flux/flux_dit.py
Lines changed: 1 addition & 1 deletion
@@ -435,7 +435,7 @@ def forward(
         # addition of floating point numbers does not meet commutative law
         conditioning = self.time_embedder(timestep, hidden_states.dtype)
         if self.guidance_embedder is not None:
-            guidance = guidance * 1000
+            guidance = (guidance.to(torch.float32) * 1000).to(hidden_states.dtype)
             conditioning += self.guidance_embedder(guidance, hidden_states.dtype)
         conditioning += self.pooled_text_embedder(pooled_prompt_emb)
         rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
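Scaling guidance in float32 makes the multiply by 1000 exact and defers rounding to a single final cast; multiplying directly in the running dtype rounds twice (once when guidance is stored, once after the multiply) and can land on a different value. A small illustration, assuming bf16 as the running dtype:

import torch

g = torch.tensor([7.7])  # guidance scale held in float32

once = (g * 1000).to(torch.bfloat16)                       # scale in fp32, round once
twice = (g.to(torch.bfloat16) * 1000).to(torch.bfloat16)   # round, then scale in bf16

print(once.item(), twice.item())  # can differ by a bf16 step, e.g. 7712.0 vs 7680.0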

diffsynth_engine/models/text_encoder/t5.py
Lines changed: 2 additions & 4 deletions
@@ -4,7 +4,7 @@

 from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
 from diffsynth_engine.models.basic.relative_position_emb import RelativePositionEmbedding
-from diffsynth_engine.models.basic.transformer_helper import RMSNorm, NewGELUActivation
+from diffsynth_engine.models.basic.transformer_helper import RMSNorm
 from diffsynth_engine.models.basic.attention import Attention
 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.utils.gguf import gguf_inference
@@ -21,14 +21,12 @@ def __init__(self, d_model, d_ff, dropout_rate, device: str = "cuda:0", dtype: t
         self.wi_1 = nn.Linear(d_model, d_ff, bias=False, device=device, dtype=dtype)
         self.wo = nn.Linear(d_ff, d_model, bias=False, device=device, dtype=dtype)
         self.dropout = nn.Dropout(dropout_rate)
-        self.act = NewGELUActivation()
+        self.act = nn.GELU(approximate="tanh")

     def forward(self, hidden_states):
         hidden_gelu = self.act(self.wi_0(hidden_states))
         hidden_linear = self.wi_1(hidden_states)
         hidden_states = self.dropout(hidden_gelu * hidden_linear)
-
-        hidden_states = hidden_states.to(self.wo.weight.dtype)
         hidden_states = self.wo(hidden_states)
         return hidden_states

diffsynth_engine/pipelines/base.py
Lines changed: 15 additions & 2 deletions
@@ -5,6 +5,7 @@
 from PIL import Image
 from dataclasses import dataclass
 from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
+from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
 from diffsynth_engine.utils.gguf import load_gguf_checkpoint
 from diffsynth_engine.utils import logging
 from diffsynth_engine.utils.loader import load_file
@@ -100,7 +101,10 @@ def load_model_checkpoint(
         if not os.path.isfile(path):
             raise FileNotFoundError(f"{path} is not a file")
         elif path.endswith(".safetensors"):
-            state_dict.update(**load_file(path, device=device))
+            state_dict_ = load_file(path, device=device)
+            for key, value in state_dict_.items():
+                state_dict[key] = value.to(dtype)
+
         elif path.endswith(".gguf"):
             state_dict.update(**load_gguf_checkpoint(path, device=device, dtype=dtype))
         else:
@@ -154,7 +158,7 @@ def vae_output_to_image(vae_output: torch.Tensor) -> Image.Image:
     @staticmethod
     def generate_noise(shape, seed=None, device="cpu", dtype=torch.float16):
         generator = None if seed is None else torch.Generator(device).manual_seed(seed)
-        noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+        noise = torch.randn(shape, generator=generator, device=device).to(dtype)
         return noise

     def encode_image(
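Drawing the noise in the generator's default float32 and casting afterwards means a given seed yields the same underlying sample for every requested dtype; torch.randn with dtype=... samples directly in that dtype, so fp16 and bf16 runs could otherwise diverge from the very first step. A quick check of the idea:

import torch

def generate_noise(shape, seed, dtype):
    generator = torch.Generator("cpu").manual_seed(seed)
    return torch.randn(shape, generator=generator).to(dtype)

a = generate_noise((4,), seed=0, dtype=torch.float16)
b = generate_noise((4,), seed=0, dtype=torch.bfloat16)
print(torch.allclose(a.float(), b.float(), atol=3e-2))  # True: same fp32 draw underneath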
@@ -294,6 +298,15 @@ def _enable_sequential_cpu_offload(self):
         enable_sequential_cpu_offload(model, self.device)
         self.offload_mode = "sequential_cpu_offload"

+    def enable_fp8_autocast(
+        self, model_names: List[str], compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False
+    ):
+        for model_name in model_names:
+            model = getattr(self, model_name)
+            if model is not None:
+                enable_fp8_autocast(model, compute_dtype, use_fp8_linear)
+        self.fp8_autocast_enabled = True
+
     def load_models_to_device(self, load_model_names: List[str] | None = None):
         load_model_names = load_model_names if load_model_names else []
         # only load models to device if offload_mode is set
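The new pipeline method is a thin loop over the module-level enable_fp8_autocast from diffsynth_engine.utils.fp8_linear, applied to each named sub-model attribute. It gets called the way the FLUX pipeline does in the flux_image.py hunk below (pipe being an already constructed pipeline instance):

# enable store-fp8 / compute-bf16 on the DiT only
pipe.enable_fp8_autocast(model_names=["dit"], compute_dtype=torch.bfloat16, use_fp8_linear=False)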

diffsynth_engine/pipelines/flux_image.py
Lines changed: 23 additions & 14 deletions
@@ -526,29 +526,29 @@ def from_pretrained(
         model_config = (
             model_path_or_config
             if isinstance(model_path_or_config, FluxModelConfig)
-            else FluxModelConfig(dit_path=model_path_or_config, dit_dtype=dtype, t5_dtype=dtype, clip_dtype=dtype)
+            else FluxModelConfig(dit_path=model_path_or_config, dit_dtype=dtype, t5_dtype=dtype)
         )
         if model_config.vae_path is None:
-            model_config.vae_path = fetch_model("muse/flux_vae", revision="20241015120836", path="ae.safetensors")
+            model_config.vae_path = fetch_model("muse/FLUX.1-dev-fp8", path="ae-bf16.safetensors")

         if model_config.clip_path is None and load_text_encoder:
-            model_config.clip_path = fetch_model(
-                "muse/flux_clip_l", revision="20241209", path="clip_l_bf16.safetensors"
-            )
+            model_config.clip_path = fetch_model("muse/FLUX.1-dev-fp8", path="clip-bf16.safetensors")
         if model_config.t5_path is None and load_text_encoder:
             model_config.t5_path = fetch_model(
-                "muse/google_t5_v1_1_xxl", revision="20241024105236", path="t5xxl_v1_1_bf16.safetensors"
+                "muse/FLUX.1-dev-fp8", path=["t5-fp8-00001-of-00002.safetensors", "t5-fp8-00002-of-00002.safetensors"]
             )

         logger.info(f"loading state dict from {model_config.dit_path} ...")
-        dit_state_dict = cls.load_model_checkpoint(model_config.dit_path, device="cpu", dtype=dtype)
+        dit_state_dict = cls.load_model_checkpoint(model_config.dit_path, device="cpu", dtype=model_config.dit_dtype)
         logger.info(f"loading state dict from {model_config.vae_path} ...")
-        vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=dtype)
+        vae_state_dict = cls.load_model_checkpoint(model_config.vae_path, device="cpu", dtype=model_config.vae_dtype)
         if load_text_encoder:
             logger.info(f"loading state dict from {model_config.clip_path} ...")
-            clip_state_dict = cls.load_model_checkpoint(model_config.clip_path, device="cpu", dtype=dtype)
+            clip_state_dict = cls.load_model_checkpoint(
+                model_config.clip_path, device="cpu", dtype=model_config.clip_dtype
+            )
             logger.info(f"loading state dict from {model_config.t5_path} ...")
-            t5_state_dict = cls.load_model_checkpoint(model_config.t5_path, device="cpu", dtype=dtype)
+            t5_state_dict = cls.load_model_checkpoint(model_config.t5_path, device="cpu", dtype=model_config.t5_dtype)

         init_device = "cpu" if parallelism > 1 or offload_mode is not None else device
         if load_text_encoder:
@@ -602,10 +602,20 @@ def from_pretrained(
             vae_tile_stride=vae_tile_stride,
             control_type=control_type,
             device=device,
-            dtype=dtype,
+            dtype=model_config.dit_dtype,
         )
-        if offload_mode is not None:
-            pipe.enable_cpu_offload(offload_mode)
+        pipe.enable_cpu_offload(offload_mode)
+        if model_config.dit_dtype == torch.float8_e4m3fn:
+            pipe.dtype = torch.bfloat16  # running dtype
+            pipe.enable_fp8_autocast(
+                model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=model_config.use_fp8_linear
+            )
+
+        if model_config.t5_dtype == torch.float8_e4m3fn:
+            pipe.dtype = torch.bfloat16  # running dtype
+            pipe.enable_fp8_autocast(
+                model_names=["text_encoder_2"], compute_dtype=pipe.dtype, use_fp8_linear=model_config.use_fp8_linear
+            )

         if parallelism > 1:
             parallel_config = cls.init_parallel_config(parallelism, use_cfg_parallel, model_config)
@@ -803,7 +813,6 @@ def predict_noise(
             current_step=current_step,
             total_step=total_step,
         )
-
         self.load_models_to_device(["dit"])

         noise_pred = self.dit(
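With these defaults, the fp8-distributed FLUX weights are stored as float8_e4m3fn but executed in bfloat16, which is what the commit title means by "store fp8, exec bf16". A hedged usage sketch (the local checkpoint path is hypothetical, and the pipeline class name is assumed from this file):

import torch
from diffsynth_engine.pipelines.flux_image import FluxImagePipeline

pipe = FluxImagePipeline.from_pretrained("flux-dit-fp8.safetensors", dtype=torch.float8_e4m3fn)
print(pipe.dtype)  # torch.bfloat16: the running dtype once fp8 autocast is enabled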

diffsynth_engine/utils/fp8_linear.py
Lines changed: 39 additions & 0 deletions
@@ -4,6 +4,45 @@
 from contextlib import contextmanager


+def enable_fp8_autocast(module: nn.Module, compute_dtype: torch.dtype = torch.bfloat16, use_fp8_linear: bool = False):
+    if len(list(module.children())) == 0:
+        if len(list(module.parameters())) > 0:
+            add_fp8_autocast_hook(module, compute_dtype)
+        return
+    if len(list(module.parameters(recurse=False))) > 0:
+        add_fp8_autocast_hook(module, compute_dtype)
+    for submodule in module.children():
+        if isinstance(submodule, nn.Linear) and use_fp8_linear:
+            continue
+        enable_fp8_autocast(submodule, compute_dtype, use_fp8_linear)
+
+
+def add_fp8_autocast_hook(module: nn.Module, compute_dtype: torch.dtype = torch.bfloat16):
+    def _fp8_autocast_pre_hook(module: nn.Module, input_):
+        for name, param in module.named_parameters():
+            if param.dtype == torch.float8_e4m3fn:
+                param.data = param.data.to(compute_dtype)
+        new_inputs = []
+        for x in input_:
+            if isinstance(x, torch.Tensor) and x.dtype in [torch.float8_e4m3fn, torch.float16, torch.bfloat16]:
+                new_inputs.append(x.to(compute_dtype))
+            else:
+                new_inputs.append(x)
+        return tuple(new_inputs)
+
+    def _fp8_autocast_hook(module: nn.Module, input_, output_):
+        for name, param in module.named_parameters():
+            if param.dtype == compute_dtype:
+                param.data = param.data.to(torch.float8_e4m3fn)
+
+    if getattr(module, "_fp8_autocast_enabled", False):
+        return
+    module.register_forward_pre_hook(_fp8_autocast_pre_hook)
+    module.register_forward_hook(_fp8_autocast_hook)
+    setattr(module, "_fp8_autocast_enabled", True)
+
+
 def enable_fp8_linear(module: nn.Module):
     _enable_fp8_linear(module)
     setattr(module, "fp8_linear_enabled", True)
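enable_fp8_autocast walks the module tree and attaches the hook pair to leaf modules (and any module holding direct parameters): the pre-hook upcasts fp8 weights and low-precision inputs to compute_dtype just before that module runs, and the post-hook re-quantizes to fp8 right after, so typically only one layer at a time lives in the wider dtype. With use_fp8_linear=True, nn.Linear children are skipped so a native fp8 linear path can handle them instead. A toy check of the mechanism, assuming a torch build with float8_e4m3fn support:

import torch
import torch.nn as nn
from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast

lin = nn.Linear(8, 8, bias=False)
lin.weight.data = lin.weight.data.to(torch.float8_e4m3fn)  # fp8 storage
enable_fp8_autocast(lin, compute_dtype=torch.bfloat16)

x = torch.randn(2, 8, dtype=torch.bfloat16)
y = lin(x)               # pre-hook upcasts the weight to bf16 for the matmul
print(y.dtype)           # torch.bfloat16
print(lin.weight.dtype)  # torch.float8_e4m3fn again, re-quantized by the post-hook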

Binary test images updated, including tests/data/expect/flux/flux_lora.png (size changes: -33.7 KB, 936 KB, -142 KB).
