Commit f8ab414

Add CachedModelOnlyFullLoad to mirror CachedModelWithPartialLoad, for models that cannot or should not be partially loaded.
1 parent: c6795a1

File tree: 4 files changed (+247, -30 lines)

invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py (new file)

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
from typing import Any

import torch


class CachedModelOnlyFullLoad:
    """A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.

    Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
    MPS memory, etc.
    """

    def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
        """Initialize a CachedModelOnlyFullLoad.

        Args:
            model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
            compute_device (torch.device): The compute device to move the model to.
            total_bytes (int): The total size (in bytes) of all the weights in the model.
        """
        # model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
        self._model = model
        self._compute_device = compute_device
        self._offload_device = torch.device("cpu")

        # A CPU read-only copy of the model's state dict.
        self._cpu_state_dict: dict[str, torch.Tensor] | None = None
        if isinstance(model, torch.nn.Module):
            self._cpu_state_dict = model.state_dict()

        self._total_bytes = total_bytes
        self._is_in_vram = False

    @property
    def model(self) -> torch.nn.Module:
        return self._model

    def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
        """Get a read-only copy of the model's state dict in RAM."""
        # TODO(ryand): Document this better.
        return self._cpu_state_dict

    def total_bytes(self) -> int:
        """Get the total size (in bytes) of all the weights in the model."""
        return self._total_bytes

    def cur_vram_bytes(self) -> int:
        """Get the size (in bytes) of the weights that are currently in VRAM."""
        if self._is_in_vram:
            return self._total_bytes
        else:
            return 0

    def is_in_vram(self) -> bool:
        """Return true if the model is currently in VRAM."""
        return self._is_in_vram

    def full_load_to_vram(self) -> int:
        """Load all weights into VRAM (if supported by the model).

        Returns:
            The number of bytes loaded into VRAM.
        """
        if self._is_in_vram:
            # Already in VRAM.
            return 0

        if not hasattr(self._model, "to"):
            # Model doesn't support moving to a device.
            return 0

        if self._cpu_state_dict is not None:
            new_state_dict: dict[str, torch.Tensor] = {}
            for k, v in self._cpu_state_dict.items():
                new_state_dict[k] = v.to(self._compute_device, copy=True)
            self._model.load_state_dict(new_state_dict, assign=True)
        self._model.to(self._compute_device)

        self._is_in_vram = True
        return self._total_bytes

    def full_unload_from_vram(self) -> int:
        """Unload all weights from VRAM.

        Returns:
            The number of bytes unloaded from VRAM.
        """
        if not self._is_in_vram:
            # Already in RAM.
            return 0

        if self._cpu_state_dict is not None:
            self._model.load_state_dict(self._cpu_state_dict, assign=True)
        self._model.to(self._offload_device)

        self._is_in_vram = False
        return self._total_bytes
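
For context, a minimal usage sketch of the new wrapper (not part of the commit). The Linear model, the byte-count computation, and the hard-coded "cuda" device below are illustrative assumptions; any model and compute device supported by the class would work the same way.

import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
    CachedModelOnlyFullLoad,
)

# Illustrative model; any torch.nn.Module (or a non-module model) can be wrapped.
model = torch.nn.Linear(10, 32)
total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

# Assumes a CUDA device is available; an MPS device works the same way.
cached = CachedModelOnlyFullLoad(model=model, compute_device=torch.device("cuda"), total_bytes=total_bytes)

cached.full_load_to_vram()                       # moves all weights to the compute device, returns bytes moved
y = cached.model(torch.randn(1, 10).to("cuda"))  # inference now runs on the compute device
cached.full_unload_from_vram()                   # moves all weights back to the CPU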

New test file for CachedModelOnlyFullLoad

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
    CachedModelOnlyFullLoad,
)
from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda


class NonTorchModel:
    """A model that does not sub-class torch.nn.Module."""

    def __init__(self):
        self.linear = torch.nn.Linear(10, 32)

    def run_inference(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


@parameterize_mps_and_cuda
def test_cached_model_total_bytes(device: str):
    model = DummyModule()
    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
    assert cached_model.total_bytes() == 100


@parameterize_mps_and_cuda
def test_cached_model_is_in_vram(device: str):
    model = DummyModule()
    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
    assert not cached_model.is_in_vram()
    assert cached_model.cur_vram_bytes() == 0

    cached_model.full_load_to_vram()
    assert cached_model.is_in_vram()
    assert cached_model.cur_vram_bytes() == 100

    cached_model.full_unload_from_vram()
    assert not cached_model.is_in_vram()
    assert cached_model.cur_vram_bytes() == 0


@parameterize_mps_and_cuda
def test_cached_model_full_load_and_unload(device: str):
    model = DummyModule()
    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
    assert cached_model.full_load_to_vram() == 100
    assert cached_model.is_in_vram()
    assert all(p.device.type == device for p in cached_model.model.parameters())

    assert cached_model.full_unload_from_vram() == 100
    assert not cached_model.is_in_vram()
    assert all(p.device.type == "cpu" for p in cached_model.model.parameters())


@parameterize_mps_and_cuda
def test_cached_model_get_cpu_state_dict(device: str):
    model = DummyModule()
    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
    assert not cached_model.is_in_vram()

    # The CPU state dict can be accessed and has the expected properties.
    cpu_state_dict = cached_model.get_cpu_state_dict()
    assert cpu_state_dict is not None
    assert len(cpu_state_dict) == len(model.state_dict())
    assert all(p.device.type == "cpu" for p in cpu_state_dict.values())

    # Full load the model into VRAM.
    cached_model.full_load_to_vram()
    assert cached_model.is_in_vram()

    # The CPU state dict is still available, and still on the CPU.
    cpu_state_dict = cached_model.get_cpu_state_dict()
    assert cpu_state_dict is not None
    assert len(cpu_state_dict) == len(model.state_dict())
    assert all(p.device.type == "cpu" for p in cpu_state_dict.values())


@parameterize_mps_and_cuda
def test_cached_model_full_load_and_inference(device: str):
    model = DummyModule()
    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
    assert not cached_model.is_in_vram()

    # Run inference on the CPU.
    x = torch.randn(1, 10)
    output1 = model(x)
    assert output1.device.type == "cpu"

    # Full load the model into VRAM.
    cached_model.full_load_to_vram()
    assert cached_model.is_in_vram()

    # Run inference on the GPU.
    output2 = model(x.to(device))
    assert output2.device.type == device

    # The outputs should be the same for both runs.
    assert torch.allclose(output1, output2.to("cpu"))


@parameterize_mps_and_cuda
def test_non_torch_model(device: str):
    model = NonTorchModel()
    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
    assert not cached_model.is_in_vram()

    # The model does not have a CPU state dict.
    assert cached_model.get_cpu_state_dict() is None

    # Attempting to load the model into VRAM should have no effect.
    cached_model.full_load_to_vram()
    assert not cached_model.is_in_vram()
    assert cached_model.cur_vram_bytes() == 0

    # Attempting to unload the model from VRAM should have no effect.
    cached_model.full_unload_from_vram()
    assert not cached_model.is_in_vram()
    assert cached_model.cur_vram_bytes() == 0

    # Running inference on the CPU should work.
    output1 = model.run_inference(torch.randn(1, 10))
    assert output1.device.type == "cpu"

tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py

Lines changed: 1 addition & 30 deletions
@@ -1,42 +1,13 @@
 import itertools
 
-import pytest
 import torch
 
 from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
     CachedModelWithPartialLoad,
 )
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import CustomLinear
 from invokeai.backend.util.calc_tensor_size import calc_tensor_size
-
-
-class DummyModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.linear1 = torch.nn.Linear(10, 32)
-        self.linear2 = torch.nn.Linear(32, 64)
-        self.register_buffer("buffer1", torch.ones(64))
-        # Non-persistent buffers are not included in the state dict. We need to make sure that this case is handled
-        # correctly by the partial loading code.
-        self.register_buffer("buffer2", torch.ones(64), persistent=False)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = x + self.buffer1
-        x = x + self.buffer2
-        return x
-
-
-parameterize_mps_and_cuda = pytest.mark.parametrize(
-    ("device"),
-    [
-        pytest.param(
-            "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
-        ),
-        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
-    ],
-)
+from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
 
 
 @parameterize_mps_and_cuda

tests/backend/model_manager/load/model_cache/cached_model/utils.py (new file)

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
import pytest
import torch


class DummyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 32)
        self.linear2 = torch.nn.Linear(32, 64)
        self.register_buffer("buffer1", torch.ones(64))
        # Non-persistent buffers are not included in the state dict. We need to make sure that this case is handled
        # correctly by the partial loading code.
        self.register_buffer("buffer2", torch.ones(64), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = self.linear2(x)
        x = x + self.buffer1
        x = x + self.buffer2
        return x


parameterize_mps_and_cuda = pytest.mark.parametrize(
    ("device"),
    [
        pytest.param(
            "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
        ),
        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
    ],
)
