Skip to content

Commit 8b4b0ff

Browse files
committed
Fix bug in CustomConv1d and CustomConv2d patch calculations.
1 parent 6fd9b0a commit 8b4b0ff

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
55
CustomModuleMixin,
66
)
7+
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
8+
add_nullable_tensors,
9+
)
710

811

912
class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin):
@@ -21,9 +24,10 @@ def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
2124
orig_params=orig_params,
2225
device=input.device,
2326
)
24-
return self._conv_forward(
25-
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
26-
)
27+
28+
weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
29+
bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
30+
return self._conv_forward(input, weight, bias)
2731

2832
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
2933
weight = cast_to_device(self.weight, input.device)

invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
55
CustomModuleMixin,
66
)
7+
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
8+
add_nullable_tensors,
9+
)
710

811

912
class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
@@ -21,9 +24,10 @@ def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
2124
orig_params=orig_params,
2225
device=input.device,
2326
)
24-
return self._conv_forward(
25-
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
26-
)
27+
28+
weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
29+
bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
30+
return self._conv_forward(input, weight, bias)
2731

2832
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
2933
weight = cast_to_device(self.weight, input.device)

0 commit comments

Comments
 (0)