
Commit e09e716

make sure to sync current stream before overwriting with pinned params
not doing so will lead to erroneous computations on the GPU and cause bad results
1 parent 8c63bf5 commit e09e716

2 files changed: 3 additions (+), 3 deletions (−)


src/diffusers/hooks/group_offloading.py

Lines changed: 2 additions & 1 deletion
@@ -88,6 +88,7 @@ def onload_(self):
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
         if self.stream is not None:
+            torch.cuda.current_stream().synchronize()
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]

@@ -427,7 +428,7 @@ def _apply_group_offloading_leaf_level(
     cpu_param_dict = {param: param.data for param in module.parameters()}

     # Create module groups for leaf modules and apply group offloading hooks
-    for name, submodule in module.named_modules():
+    for submodule in module.modules():
         if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
             continue
         group = ModuleGroup(
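
For context, here is a minimal sketch of why the added synchronization matters. It is not part of the commit, and names such as pinned_cpu_copy are illustrative: when offloading is driven by a separate stream, kernels already queued on the current stream may still be reading the GPU parameters, so param.data must not be repointed at the pinned CPU copy until that work has drained.

import torch

# Illustrative sketch, assuming a CUDA device is available; variable names are hypothetical.
device = torch.device("cuda")
param = torch.nn.Parameter(torch.randn(1024, 1024, device=device))
pinned_cpu_copy = param.data.cpu().pin_memory()  # analogous to cpu_param_dict[param]

# Asynchronous work queued on the current stream that still reads `param`.
out = param @ param

# Without this synchronization, the assignment below could race with the matmul
# above, and the GPU could read memory that has already been repurposed.
torch.cuda.current_stream().synchronize()

# Safe now: all queued work that uses the GPU copy of the parameter has completed.
param.data = pinned_cpu_copy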

src/diffusers/hooks/hooks.py

Lines changed: 1 addition & 2 deletions
@@ -151,8 +151,7 @@ def new_forward(module, *args, **kwargs):
            # return hook.post_forward(module, output)

        new_forward = create_new_forward(fn_ref)
-       new_forward = functools.update_wrapper(functools.partial(new_forward, self._module_ref), forward)
-       self._module_ref.forward = new_forward
+       self._module_ref.forward = functools.update_wrapper(functools.partial(new_forward, self._module_ref), forward)

        self.hooks[name] = hook
        self._hook_order.append(name)
