Commit 10e4e55

[Accelerate] Expand get_execution_device to support models (#363)
* expand to support models
* reduce assumptions
* better model testing
* add mark for tests
* update docstring

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 55771b2 commit 10e4e55

File tree

2 files changed: +31 −10 lines changed

src/compressed_tensors/utils/offload.py

Lines changed: 12 additions & 10 deletions
@@ -171,22 +171,24 @@ def update_parameter_data(
 
 
 def get_execution_device(module: torch.nn.Module) -> torch.device:
     """
-    Get the device which inputs should be moved to before module execution
+    Get the device which inputs should be moved to before module execution.
+    Assume that modules execute in the same order as returned by `model.modules()`
 
     :param module: module to check, may be offloaded
     :return: onload device of module
     """
-    if has_offloaded_params(module):
-        return module._hf_hook.execution_device
+    for module in module.modules():
+        if has_offloaded_params(module):
+            return module._hf_hook.execution_device
 
-    first_param = next(module.parameters(), None)
-    if first_param is None:
-        warnings.warn(
-            f"Unable able to infer execution device of {module}, falling back to CPU"
-        )
-        return torch.device("cpu")
+        param = next(module.parameters(recurse=False), None)
+        if param is not None:
+            return param.device
 
-    return first_param.device
+    warnings.warn(
+        f"Unable able to get execution device of {module}, falling back to CPU"
+    )
+    return torch.device("cpu")
 
 
 def register_offload_parameter(
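With this change, `get_execution_device` walks the model in `model.modules()` order and returns the first device it can resolve, whether from an accelerate offload hook or from a module's directly held parameters. A minimal sketch of the resolution order, assuming `get_execution_device` is importable from `compressed_tensors.utils.offload`:

import torch

# Assumed import path, based on the file location in this commit.
from compressed_tensors.utils.offload import get_execution_device

model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),  # first module in model.modules() with its own parameters
    torch.nn.Linear(4, 4),
)

# The Sequential container has no offload hook and no direct parameters of its
# own, so the traversal falls through to the first Linear's weight device.
assert get_execution_device(model) == torch.device("cpu")

# Moving the first child changes the answer; later modules are never consulted.
model[0].to("cuda:0")  # requires a CUDA device
assert get_execution_device(model) == torch.device("cuda:0")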

tests/test_utils/test_offload.py

Lines changed: 19 additions & 0 deletions
@@ -102,6 +102,25 @@ def test_get_execution_device():
     assert get_execution_device(module) == torch.device("cuda:0")
 
 
+@requires_gpu
+@requires_accelerate()
+def test_get_execution_device_model():
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.a = torch.nn.Linear(1, 2)
+            self.b = torch.nn.Linear(2, 2, device="cuda:0")
+
+        def forward(self, x):
+            return self.b(self.a(x).to("cuda:0"))
+
+    model = Model()
+    assert get_execution_device(model) == torch.device("cpu")
+
+    offloaded_dispatch(model.a, torch.device("cuda:0"))
+    assert get_execution_device(model) == torch.device("cuda:0")
+
+
 @requires_accelerate()
 def test_register_offload_parameter():
     from accelerate import init_empty_weights
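The new test exercises the offloaded branch through the repo's `offloaded_dispatch` helper. A rough equivalent using only accelerate's public API, a sketch assuming `has_offloaded_params` recognizes the `AlignDevicesHook` that `cpu_offload` attaches:

import torch
from accelerate import cpu_offload

# Assumed import path, based on the file location in this commit.
from compressed_tensors.utils.offload import get_execution_device

module = torch.nn.Linear(2, 2)

# cpu_offload keeps the weights off-device and attaches an AlignDevicesHook
# that onloads them to cuda:0 only for the duration of each forward pass.
cpu_offload(module, execution_device=torch.device("cuda:0"))

# The hook's execution device, not the parameters' storage device, is reported.
assert get_execution_device(module) == torch.device("cuda:0")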

0 commit comments