
Commit 3fb2844

Authored by Kyle Sayers <kylesayrs@gmail.com>
[Accelerate] Fix offloaded_dispatch, implement disable_offloading (#355)
* fix offloaded_dispatch, implement disable_offloading
* update params
* small speedup

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 852b1fa commit 3fb2844

File tree: 2 files changed (+95, −13 lines)

src/compressed_tensors/utils/offload.py

Lines changed: 39 additions & 5 deletions
@@ -31,9 +31,10 @@
 import warnings
 from functools import wraps
 from operator import attrgetter
-from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
+from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Union
 
 import torch
+from compressed_tensors.utils import patch_attr
 
 
 try:
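The newly imported `patch_attr` is what scopes the hook patch in `disable_offloading` below. Its implementation is not part of this diff; a minimal sketch of a helper with the contract assumed here (swap an attribute for the duration of a context, restore it on exit) could look like:

```python
import contextlib

@contextlib.contextmanager
def patch_attr(obj, name, value):
    # hypothetical sketch only: the real compressed_tensors.utils.patch_attr
    # may differ. Temporarily replace obj.<name>, restoring the original
    # attribute even if the body raises.
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)
```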
@@ -83,6 +84,7 @@
     "register_offload_module",
     "delete_offload_module",
     "offloaded_dispatch",
+    "disable_offloading",
 ]
 
 
@@ -214,7 +216,7 @@ def register_offload_parameter(
 def update_offload_parameter(
     module: torch.nn.Module,
     name: str,
-    data: Optional[torch.Tensor],
+    data: torch.Tensor,
     offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
 ):
     """
@@ -227,15 +229,15 @@ def update_offload_parameter(
     :param offload_device: device on which weight will be offloaded to. If None is
         provided, then infer device from parameters on module
     """
-    param = getattr(module, name)
+    param: torch.nn.Parameter = getattr(module, name)
     if param.data.shape != data.shape:
         warnings.warn(
             f"Shape of parameter being updated {param.data.shape} does not match shape "
             f"of update data {data.shape}"
         )
 
     # copy data into onloaded parameter if applicable
-    if param.device != torch.device("meta"):
+    if param.device != torch.device("meta") and data is not param.data:
         param.data.copy_(data)
 
     # update offload dict
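The added `data is not param.data` check matters because `disable_offloading` (below) calls `update_offload_parameter(module, name, param.data)`, passing each parameter's own storage back in; without the guard that would trigger a redundant self-`copy_`. A standalone illustration of the condition, outside the library:

```python
import torch

param = torch.nn.Parameter(torch.zeros(2))
data = param.data  # aliases the parameter's own storage

# mirrors the updated condition: copy only when the parameter is materialized
# (not on the meta device) and the update is a distinct tensor
if param.device != torch.device("meta") and data is not param.data:
    param.data.copy_(data)  # skipped here, since data aliases param.data
```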
@@ -501,7 +503,9 @@ def offloaded_dispatch(
         raise NotImplementedError("Disk offloading is not currently supported")
 
     # create weights map
-    weights_map = OffloadedWeightsLoader(state_dict=module.state_dict(), device="cpu")
+    state_dict = module.state_dict()
+    state_dict = {key: val.to(offload_device) for key, val in state_dict.items()}
+    weights_map = OffloadedWeightsLoader(state_dict=state_dict, device=offload_device)
 
     # create tied params map
     tied_params = find_tied_parameters(module)
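With this change, offloaded weights are moved to the caller's `offload_device` rather than always landing on CPU. A hedged usage sketch (argument order mirrors the tests below; assumes a CUDA device is available):

```python
import torch
from compressed_tensors.utils import offloaded_dispatch

# execute forward passes on GPU while weights rest on CPU; passing
# torch.device("cuda:0") as the third argument would keep them on GPU instead
module = torch.nn.Linear(16, 16)
module = offloaded_dispatch(module, torch.device("cuda:0"), torch.device("cpu"))

assert module.weight.device == torch.device("meta")  # data lives in weights_map
module(torch.empty(1, 16, device="cuda:0"))  # hook onloads, runs, re-offloads
```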
@@ -522,6 +526,36 @@
     return module
 
 
+@contextlib.contextmanager
+def disable_offloading():
+    """
+    Keep modules onloaded and disable offloading until this context exits.
+    Affects modules which have been hooked with accelerate's `AlignDevicesHook`
+    """
+    original_pre_forward = AlignDevicesHook.pre_forward
+    onloaded_modules: Dict[torch.nn.Module, Tuple[AlignDevicesHook, bool]] = dict()
+
+    # onload once and disable any future onloading/offloading steps
+    def keep_onload_pre_forward(self: AlignDevicesHook, module, *args, **kwargs):
+        ret = original_pre_forward(self, module, *args, **kwargs)
+        if module not in onloaded_modules:
+            onloaded_modules[module] = (self, self.offload)
+            self.offload = False
+        return ret
+
+    # use the patched pre_forward function within the context
+    with patch_attr(AlignDevicesHook, "pre_forward", keep_onload_pre_forward):
+        yield
+
+    # manually offload all modules that were onloaded
+    # update any parameters which may have changed
+    for module, (hook, offload) in onloaded_modules.items():
+        hook.offload = offload
+        for name, param in module.named_parameters():
+            update_offload_parameter(module, name, param.data)
+        hook.post_forward(module, None)
+
+
 """ Upstreamed Functions """
 
 
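Taken together: inside the context each hooked module is onloaded at most once and its `offload` flag is forced off; on exit the saved flags are restored, changed parameters are flushed back through `update_offload_parameter`, and `post_forward` re-offloads the weights. A hedged usage sketch, assuming a CUDA device:

```python
import torch
from compressed_tensors.utils import disable_offloading, offloaded_dispatch

module = offloaded_dispatch(
    torch.nn.Linear(16, 16), torch.device("cuda:0"), torch.device("cpu")
)

with disable_offloading():
    # first forward onloads the weights; later iterations reuse them, which is
    # the speedup for loops that repeatedly call the same offloaded module
    for _ in range(10):
        module(torch.empty(1, 16, device="cuda:0"))

# on exit, parameter changes are written back and the weights re-offload
assert module.weight.device == torch.device("meta")
```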
tests/test_utils/test_offload.py

Lines changed: 56 additions & 8 deletions
@@ -19,6 +19,7 @@
     delete_offload_module,
     delete_offload_parameter,
     disable_hf_hook,
+    disable_offloading,
     get_execution_device,
     has_offloaded_params,
     offloaded_dispatch,
@@ -397,29 +398,37 @@ def test_delete_offload_module(exec_device):
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")])
-def test_offloaded_dispatch(exec_device):
+@pytest.mark.parametrize(
+    "exec_device,offload_device",
+    [
+        (torch.device("cpu"), torch.device("cpu")),
+        (torch.device("cpu"), torch.device("cuda:0")),
+        (torch.device("cuda:0"), torch.device("cpu")),
+        (torch.device("cuda:0"), torch.device("cuda:0")),
+    ],
+)
+def test_offloaded_dispatch(exec_device, offload_device):
     # single module
-    module = torch.nn.Linear(1, 2)
-    module = offloaded_dispatch(module, exec_device)
+    module = torch.nn.Linear(1, 2, device=offload_device)
+    module = offloaded_dispatch(module, exec_device, offload_device)
     assert has_offloaded_params(module)
     assert module._hf_hook.offload
     assert module.weight.device == torch.device("meta")
-    assert "weight" in module._hf_hook.weights_map
+    assert module._hf_hook.weights_map["weight"].device == offload_device
     assert module._hf_hook.tied_params_map is not None
 
     # can run
     module(torch.empty(1, device=exec_device))
 
     # model
     model = ExampleModel()
-    model = offloaded_dispatch(model, exec_device)
+    model = offloaded_dispatch(model, exec_device, offload_device)
     assert not has_offloaded_params(model)
 
     assert has_offloaded_params(model.linear)
     assert model.linear._hf_hook.offload
     assert model.linear.weight.device == torch.device("meta")
-    assert "weight" in model.linear._hf_hook.weights_map
+    assert model.linear._hf_hook.weights_map["weight"].device == offload_device
     assert model.linear._hf_hook.tied_params_map is not None
 
     # can run
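`ExampleModel` is defined earlier in the test file and not shown in this diff; the assertions only require a `linear` submodule, so a plausible (hypothetical) stand-in for reading the test is:

```python
import torch

class ExampleModel(torch.nn.Module):
    # hypothetical reconstruction of the fixture used above
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 2)

    def forward(self, x):
        return self.linear(x)
```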
@@ -429,4 +438,43 @@ def test_offloaded_dispatch(exec_device):
     parameter = torch.nn.Parameter(torch.tensor(1.0))
     register_offload_parameter(module, "new_param", parameter)
     assert module.new_param.device == torch.device("meta")
-    assert module._hf_hook.weights_map["new_param"].device == torch.device("cpu")
+    assert module._hf_hook.weights_map["new_param"].device == offload_device
+
+
+@requires_gpu
+@requires_accelerate()
+@pytest.mark.parametrize(
+    "exec_device,offload_device",
+    [
+        (torch.device("cpu"), torch.device("cpu")),
+        (torch.device("cpu"), torch.device("cuda:0")),
+        (torch.device("cuda:0"), torch.device("cpu")),
+        (torch.device("cuda:0"), torch.device("cuda:0")),
+    ],
+)
+def test_disable_offloading(exec_device, offload_device):
+    module = torch.nn.Linear(1, 2, device=exec_device)
+
+    # non-offloaded modules are unaffected
+    with disable_offloading():
+        output = module(torch.empty(1, device=exec_device))
+        assert module.weight.device == exec_device
+        assert output.device == exec_device
+
+    # offloaded modules stay on device until context exit
+    offloaded_dispatch(module, exec_device, offload_device)
+    assert module.weight.device == torch.device("meta")
+    assert module._hf_hook.weights_map["weight"].device == offload_device
+
+    with disable_offloading():
+        assert module.weight.device == torch.device("meta")
+        output = module(torch.empty(1, device=exec_device))
+        assert module.weight.device == exec_device
+        assert output.device == exec_device
+
+        output = module(torch.empty(1, device=exec_device))
+        assert module.weight.device == exec_device
+        assert output.device == exec_device
+
+    assert module.weight.device == torch.device("meta")
+    assert module._hf_hook.weights_map["weight"].device == offload_device
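Note the semantics the second context block pins down: onloading remains lazy under `disable_offloading` (weights stay on the meta device until the first forward call), but once onloaded they persist across calls, and only on context exit are parameters flushed back to the `weights_map` and re-offloaded.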
