[Accelerate] Support get_offloaded_device for models #364

Draft · wants to merge 4 commits into base: main
24 changes: 18 additions & 6 deletions src/compressed_tensors/utils/offload.py
@@ -123,14 +123,26 @@ def is_module_offloaded(module: torch.nn.Module) -> bool:

 def get_offloaded_device(module: torch.nn.Module) -> torch.device:
     """
+    Get the offload device of this module. This is determined through analysis of the
+    first module with parameters returned by `module.modules()`.
+
     If this module is offloaded, return its offload device. If this module is not
     offloaded, return the parameter's device.

     :param module: module to check
-    :return: device module is offloaded onto after forward pass
+    :return: offload device of module
     """
-    if has_offloaded_params(module):
-        first_key = list(module._hf_hook.weights_map.keys())[0]
-        prefix_dataset = module._hf_hook.weights_map.dataset
-        return prefix_dataset[first_key].device
-    return next(module.parameters()).device
+    for submodule in module.modules():
+        name, param = next(submodule.named_parameters(recurse=False), (None, None))
+        if has_offloaded_params(submodule):
+            assert name is not None
+            return submodule._hf_hook.weights_map[name].device
+
+        elif param is not None:
+            return param.device
+
+    warnings.warn(f"Cannot get offload device of {module}, falling back to CPU")
+    return torch.device("cpu")


@check_accelerate(fallback=None)
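For reference, a minimal usage sketch of the new helper (editor's illustration, not part of the diff). It mirrors the tests below, assumes a CUDA device is available, and uses the library's own offloaded_dispatch utility as the tests do:

import torch
from compressed_tensors.utils import get_offloaded_device, offloaded_dispatch

module = torch.nn.Linear(4, 4)

# Not offloaded: the helper falls back to the device of the first parameter.
assert get_offloaded_device(module) == torch.device("cpu")

# After dispatch, weights rest on the offload device between forward passes,
# and the helper reports that device instead.
offloaded_dispatch(
    module,
    execution_device=torch.device("cuda:0"),
    offload_device=torch.device("cpu"),
)
assert get_offloaded_device(module) == torch.device("cpu")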
72 changes: 72 additions & 0 deletions tests/test_utils/test_offload.py
@@ -21,6 +21,7 @@
     disable_hf_hook,
     disable_offloading,
     get_execution_device,
+    get_offloaded_device,
     has_offloaded_params,
     offloaded_dispatch,
     register_offload_module,
@@ -121,6 +122,77 @@ def forward(self, x):
    assert get_execution_device(model) == torch.device("cuda:0")


@requires_gpu
@requires_accelerate()
def test_get_offloaded_device():
    from accelerate import init_empty_weights

    # no offloading
    module = ExampleModule()
    assert get_offloaded_device(module) == torch.device("cpu")

    # with offloading
    offloaded_dispatch(
        module,
        execution_device=torch.device("cpu"),
        offload_device=torch.device("cuda:0"),
    )
    assert get_offloaded_device(module) == torch.device("cuda:0")

    # in meta context
    with torch.device("meta"):
        module = ExampleModule()
    assert get_offloaded_device(module) == torch.device("meta")

    # offloaded in meta context
    module = ExampleModule()
    offloaded_dispatch(
        module,
        execution_device=torch.device("cpu"),
        offload_device=torch.device("cuda:0"),
    )
    with torch.device("meta"):
        assert get_offloaded_device(module) == torch.device("cuda:0")

    # in empty weights context
    with init_empty_weights():
        module = ExampleModule()
    assert get_offloaded_device(module) == torch.device("meta")

    # offloaded in empty weights context
    module = ExampleModule()
    offloaded_dispatch(
        module,
        execution_device=torch.device("cpu"),
        offload_device=torch.device("cuda:0"),
    )
    with init_empty_weights():
        assert get_offloaded_device(module) == torch.device("cuda:0")


@requires_gpu
@requires_accelerate()
def test_get_offloaded_device_model():
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.a = torch.nn.Linear(1, 2)
            self.b = torch.nn.Linear(2, 2, device="cuda:0")

        def forward(self, x):
            return self.b(self.a(x).to("cuda:0"))

    model = Model()
    assert get_offloaded_device(model) == torch.device("cpu")

    offloaded_dispatch(
        model.a,
        execution_device=torch.device("cpu"),
        offload_device=torch.device("cuda:0"),
    )
    assert get_offloaded_device(model) == torch.device("cuda:0")
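An editor's note on why offloading model.a flips the result above (illustration, assumes the Model class from the test): modules() yields the root first and then submodules in registration order, so the first module with direct parameters, model.a, decides the returned device; model.b on cuda:0 is never consulted.

model = Model()
print([name for name, _ in model.named_modules()])  # prints ['', 'a', 'b']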


@requires_accelerate()
def test_register_offload_parameter():
    from accelerate import init_empty_weights