
Commit aba1608

fix(backend): mps should not use non_blocking (#6549)
## Summary

We can get black outputs when moving tensors from CPU to MPS. It appears MPS to CPU is fine. See:

- pytorch/pytorch#107455
- https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28

Changes:

- Add properties for each device on `TorchDevice` as a convenience.
- Add a `get_non_blocking` static method on `TorchDevice`. This utility takes a torch device and returns the flag to be used for `non_blocking` when moving a tensor to the device provided.
- Update the model patching and caching APIs to use this new utility.

## Related Issues / Discussions

Fixes: #6545

## QA Instructions

For both MPS and CUDA:

- Generate at least 5 images using LoRAs
- Generate at least 5 images using IP Adapters

## Merge Plan

We have pagination merged into `main` but aren't ready for that to be released. Once this fix is tested and merged, we will probably want to create a `v4.2.5post1` branch off the `v4.2.5` tag, cherry-pick the fix, and do a release from the hotfix branch.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_ @RyanJDick @lstein This feels testable but I'm not sure how.
- [ ] _Documentation added / updated (if applicable)_
2 parents a0a0c57 + 14775cc commit aba1608
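A minimal sketch of the pattern this PR moves the codebase toward. The tensor setup below is hypothetical; only `TorchDevice.choose_torch_device` and `TorchDevice.get_non_blocking` come from the diff:

```python
import torch

from invokeai.backend.util.devices import TorchDevice

# Hypothetical example: instead of hard-coding non_blocking=True, ask
# TorchDevice whether the destination device can safely accept a
# non-blocking copy (False for MPS, True otherwise).
device = TorchDevice.choose_torch_device()
tensor = torch.randn(4, 4)
moved = tensor.to(device, non_blocking=TorchDevice.get_non_blocking(device))
```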

File tree

4 files changed, +33 -8 lines changed


invokeai/backend/lora.py

Lines changed: 2 additions & 1 deletion
@@ -10,6 +10,7 @@
 from typing_extensions import Self

 from invokeai.backend.model_manager import BaseModelType
+from invokeai.backend.util.devices import TorchDevice

 from .raw_model import RawModel

@@ -521,7 +522,7 @@ def from_checkpoint(
             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()

-            layer.to(device=device, dtype=dtype, non_blocking=True)
+            layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
             model.layers[layer_key] = layer

         return model

invokeai/backend/model_manager/load/model_cache/model_cache_default.py

Lines changed: 4 additions & 2 deletions
@@ -285,9 +285,11 @@ def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device
             else:
                 new_dict: Dict[str, torch.Tensor] = {}
                 for k, v in cache_entry.state_dict.items():
-                    new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
+                    new_dict[k] = v.to(
+                        target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
+                    )
                 cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device, non_blocking=True)
+            cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
             cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)

invokeai/backend/model_patcher.py

Lines changed: 11 additions & 5 deletions
@@ -16,6 +16,7 @@
 from invokeai.backend.model_manager import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.util.devices import TorchDevice

 from .lora import LoRAModelRaw
 from .textual_inversion import TextualInversionManager, TextualInversionModelRaw

@@ -139,12 +140,15 @@ def apply_lora(
                     # We intentionally move to the target device first, then cast. Experimentally, this was found to
                     # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
                     # same thing in a single call to '.to(...)'.
-                    layer.to(device=device, non_blocking=True)
-                    layer.to(dtype=torch.float32, non_blocking=True)
+                    layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
+                    layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
                     # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                     # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                     layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                    layer.to(device=torch.device("cpu"), non_blocking=True)
+                    layer.to(
+                        device=TorchDevice.CPU_DEVICE,
+                        non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
+                    )

                     assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                     if module.weight.shape != layer_weight.shape:

@@ -153,15 +157,17 @@ def apply_lora(
                         layer_weight = layer_weight.reshape(module.weight.shape)

                     assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                    module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
+                    module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))

            yield  # wait for context manager exit

        finally:
            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
            with torch.no_grad():
                for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
+                    model.get_submodule(module_key).weight.copy_(
+                        weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
+                    )

    @classmethod
    @contextmanager

invokeai/backend/util/devices.py

Lines changed: 16 additions & 0 deletions
@@ -42,6 +42,10 @@ def torch_dtype(device: torch.device) -> torch.dtype:
 class TorchDevice:
     """Abstraction layer for torch devices."""

+    CPU_DEVICE = torch.device("cpu")
+    CUDA_DEVICE = torch.device("cuda")
+    MPS_DEVICE = torch.device("mps")
+
     @classmethod
     def choose_torch_device(cls) -> torch.device:
         """Return the torch.device to use for accelerated inference."""

@@ -108,3 +112,15 @@ def empty_cache(cls) -> None:
     @classmethod
     def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
         return NAME_TO_PRECISION[precision_name]
+
+    @staticmethod
+    def get_non_blocking(to_device: torch.device) -> bool:
+        """Return the non_blocking flag to be used when moving a tensor to a given device.
+        MPS may have unexpected errors with non-blocking operations - we should not use non-blocking when moving _to_ MPS.
+        When moving _from_ MPS, we can use non-blocking operations.
+
+        See:
+        - https://github.com/pytorch/pytorch/issues/107455
+        - https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28
+        """
+        return False if to_device.type == "mps" else True
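For quick reference, a hypothetical snippet showing what the new helper returns for each destination device; the assertions are illustrative and not part of this PR:

```python
from invokeai.backend.util.devices import TorchDevice

# The flag depends only on the type of the *destination* device.
assert TorchDevice.get_non_blocking(TorchDevice.MPS_DEVICE) is False  # never copy non-blocking into MPS
assert TorchDevice.get_non_blocking(TorchDevice.CUDA_DEVICE) is True  # CUDA transfers may overlap with compute
assert TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE) is True   # e.g. MPS -> CPU appears to be fine
```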
