Skip to content

Commit 3349b65

Browse files
authored
speedup when offload_mode enable (#119)
* speedup when offload_mode enable * fix flux vae to device * offload params rename as private * fix
1 parent 714ae66 commit 3349b65

File tree

4 files changed

+17
-7
lines changed

4 files changed

+17
-7
lines changed

diffsynth_engine/models/flux/flux_vae.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.float32):
5353
def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
5454
with no_init_weights():
5555
model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
56-
model.load_state_dict(state_dict)
56+
model.load_state_dict(state_dict, assign=True)
57+
model.to(device=device, dtype=dtype, non_blocking=True)
5758
return model
5859

5960

@@ -74,5 +75,6 @@ def __init__(self, device: str, dtype: torch.dtype = torch.float32):
7475
def from_state_dict(cls, state_dict: Dict[str, torch.Tensor], device: str, dtype: torch.dtype):
7576
with no_init_weights():
7677
model = torch.nn.utils.skip_init(cls, device=device, dtype=dtype)
77-
model.load_state_dict(state_dict)
78+
model.load_state_dict(state_dict, assign=True)
79+
model.to(device=device, dtype=dtype, non_blocking=True)
7880
return model

diffsynth_engine/pipelines/base.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(
4343
self.dtype = dtype
4444
self.offload_mode = None
4545
self.model_names = []
46+
self._models_offload_params = {}
4647

4748
@classmethod
4849
def from_pretrained(
@@ -288,6 +289,10 @@ def _enable_model_cpu_offload(self):
288289
model = getattr(self, model_name)
289290
if model is not None:
290291
model.to("cpu")
292+
self._models_offload_params[model_name] = {}
293+
for name, param in model.named_parameters(recurse=True):
294+
param.data = param.data.pin_memory()
295+
self._models_offload_params[model_name][name] = param.data
291296
self.offload_mode = "cpu_offload"
292297

293298
def _enable_sequential_cpu_offload(self):
@@ -321,12 +326,14 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
321326
for model_name in self.model_names:
322327
if model_name not in load_model_names:
323328
model = getattr(self, model_name)
324-
if model is not None and (p := next(model.parameters(), None)) is not None and p.device != "cpu":
325-
model.to("cpu")
329+
if model is not None and (p := next(model.parameters(), None)) is not None and p.device != torch.device("cpu"):
330+
param_cache = self._models_offload_params[model_name]
331+
for name, param in model.named_parameters(recurse=True):
332+
param.data = param_cache[name]
326333
# load the needed models to device
327334
for model_name in load_model_names:
328335
model = getattr(self, model_name)
329-
if model is not None and (p := next(model.parameters(), None)) is not None and p.device != self.device:
336+
if model is not None and (p := next(model.parameters(), None)) is not None and p.device != torch.device(self.device):
330337
model.to(self.device)
331338
# fresh the cuda cache
332339
empty_cache()

diffsynth_engine/pipelines/flux_image.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,8 @@ def from_pretrained(
604604
device=device,
605605
dtype=model_config.dit_dtype,
606606
)
607-
pipe.enable_cpu_offload(offload_mode)
607+
if offload_mode is not None:
608+
pipe.enable_cpu_offload(offload_mode)
608609
if model_config.dit_dtype == torch.float8_e4m3fn:
609610
pipe.dtype = torch.bfloat16 # running dtype
610611
pipe.enable_fp8_autocast(

diffsynth_engine/utils/offload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def add_cpu_offload_hook(module: nn.Module, device: str = "cuda", recurse: bool
1818
def _forward_pre_hook(module: nn.Module, input):
1919
offload_params = {}
2020
for name, param in module.named_parameters(recurse=recurse):
21-
offload_params[name] = param.data
21+
offload_params[name] = param.data.pin_memory()
2222
param.data = param.data.to(device=device)
2323
setattr(module, "_offload_params", offload_params)
2424
return tuple(x.to(device=device) if isinstance(x, torch.Tensor) else x for x in input)

0 commit comments

Comments
 (0)