Skip to content

Commit 84f5cbd

Browse files
Authored and committed by Lincoln Stein
make choose_torch_dtype() usable outside an invocation context
1 parent edac01d commit 84f5cbd

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

invokeai/app/services/model_install/model_install_default.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from invokeai.backend.model_manager.probe import ModelProbe
4444
from invokeai.backend.model_manager.search import ModelSearch
4545
from invokeai.backend.util import InvokeAILogger
46+
from invokeai.backend.util.devices import TorchDevice
4647

4748
from .model_install_base import (
4849
MODEL_SOURCE_TO_TYPE_MAP,
@@ -636,7 +637,7 @@ def _next_id(self) -> int:
636637

637638
def _guess_variant(self) -> Optional[ModelRepoVariant]:
638639
"""Guess the best HuggingFace variant type to download."""
639-
precision = torch.float16 if self._app_config.precision == "auto" else torch.dtype(self._app_config.precision)
640+
precision = TorchDevice.choose_torch_dtype()
640641
return ModelRepoVariant.FP16 if precision == torch.float16 else None
641642

642643
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:

invokeai/backend/util/devices.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ def choose_torch_device(cls) -> torch.device:
6060
"""Return the torch.device to use for accelerated inference."""
6161
if cls._model_cache:
6262
return cls._model_cache.get_execution_device()
63+
else:
64+
return cls._choose_device()
65+
66+
@classmethod
67+
def _choose_device(cls) -> torch.device:
6368
app_config = get_config()
6469
if app_config.device != "auto":
6570
device = torch.device(app_config.device)
@@ -82,8 +87,8 @@ def execution_devices(cls) -> Set[torch.device]:
8287
@classmethod
8388
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
8489
"""Return the precision to use for accelerated inference."""
85-
device = device or cls.choose_torch_device()
8690
config = get_config()
91+
device = device or cls._choose_device()
8792
if device.type == "cuda" and torch.cuda.is_available():
8893
device_name = torch.cuda.get_device_name(device)
8994
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:

0 commit comments

Comments (0)