
Commit 14775cc

ruff format

1 parent c7562dd
2 files changed: 10 additions, 3 deletions


invokeai/backend/model_manager/load/model_cache/model_cache_default.py

Lines changed: 3 additions & 1 deletion
@@ -285,7 +285,9 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
             else:
                 new_dict: Dict[str, torch.Tensor] = {}
                 for k, v in cache_entry.state_dict.items():
-                    new_dict[k] = v.to(target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device))
+                    new_dict[k] = v.to(
+                        target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
+                    )
                 cache_entry.model.load_state_dict(new_dict, assign=True)
         cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
         cache_entry.device = target_device
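For context, the wrapped call is the standard PyTorch idiom for copying a state dict onto a new device. TorchDevice.get_non_blocking is InvokeAI's own helper; a reasonable assumption is that it enables asynchronous transfer only where that is safe (for example, CUDA targets). A minimal standalone sketch of the same pattern, with a hypothetical stand-in for that helper:

    import torch
    from typing import Dict

    def get_non_blocking(device: torch.device) -> bool:
        # Hypothetical stand-in for TorchDevice.get_non_blocking: asynchronous
        # copies are generally only worthwhile for CUDA targets.
        return device.type == "cuda"

    def copy_state_dict_to(
        state_dict: Dict[str, torch.Tensor], device: torch.device
    ) -> Dict[str, torch.Tensor]:
        # copy=True forces a fresh tensor even when source and target devices
        # match, so the cached state dict is never aliased by the moved copy.
        return {
            k: v.to(device, copy=True, non_blocking=get_non_blocking(device))
            for k, v in state_dict.items()
        }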

invokeai/backend/model_patcher.py

Lines changed: 7 additions & 2 deletions
@@ -145,7 +145,10 @@ def apply_lora(
                     # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                     # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                     layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                    layer.to(device=TorchDevice.CPU_DEVICE, non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE))
+                    layer.to(
+                        device=TorchDevice.CPU_DEVICE,
+                        non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
+                    )

                     assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                     if module.weight.shape != layer_weight.shape:
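The TODO carried through this hunk contrasts explicit casting with torch.autocast. Purely as an illustration (not the project's code), the difference looks roughly like this; under autocast, eligible ops such as matmul run in the context's dtype automatically:

    import torch

    w = torch.randn(64, 64)  # float32

    # Explicit casting keeps full control of the compute dtype.
    out_explicit = w @ w  # computed and returned in float32

    # Under torch.autocast, eligible ops (matmul among them) run in the
    # context's dtype; the TODO notes this may help on CUDA but was found
    # to be very slow on CPU.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        out_autocast = w @ w  # computed in bfloat16, result is bfloat16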
@@ -162,7 +165,9 @@ def apply_lora(
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight, non_blocking=TorchDevice.get_non_blocking(weight.device))
+                    model.get_submodule(module_key).weight.copy_(
+                        weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
+                    )

     @classmethod
     @contextmanager
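This hunk is the restore half of a patch-and-restore pattern: original_weights is populated before patching, and each saved tensor is copied back in place once the context exits. A minimal sketch of that pattern, assuming a plain dict of additive weight patches (the names here are illustrative, not InvokeAI's API):

    import torch
    from contextlib import contextmanager
    from typing import Dict, Iterator

    @contextmanager
    def patched_weights(
        model: torch.nn.Module, patches: Dict[str, torch.Tensor]
    ) -> Iterator[torch.nn.Module]:
        # Stash originals, apply additive patches, and restore in place on exit.
        original: Dict[str, torch.Tensor] = {}
        try:
            with torch.no_grad():
                for key, delta in patches.items():
                    weight = model.get_submodule(key).weight
                    original[key] = weight.detach().clone()
                    weight += delta.to(device=weight.device, dtype=weight.dtype)
            yield model
        finally:
            with torch.no_grad():
                for key, saved in original.items():
                    # In-place copy_ preserves existing references to the
                    # Parameter, as in the diff above.
                    model.get_submodule(key).weight.copy_(saved)

Restoring via in-place copy_ rather than reassigning the attribute keeps any outside references to the Parameter valid, which is presumably why the diff uses copy_ as well.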
