Skip to content

Commit 7dc182b

Browse files
committed
Merge remote-tracking branch 'origin' into kylesayrs/transform_construct_cache_device
2 parents 8e36540 + 98a0cd7 commit 7dc182b

File tree

9 files changed

+382
-142
lines changed

9 files changed

+382
-142
lines changed

.github/actions/test/action.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,23 @@ runs:
2222
name: compressed
2323
extra: "[dev,accelerate]"
2424

25+
- name: clean up
26+
run: |
27+
echo "cleaning up disk space..."
28+
find . -type f -name '*.whl' -exec rm -rf {} \;
29+
python -m pip cache purge
30+
sudo rm -rf /usr/local/.ghcup
31+
sudo rm -rf /opt/hostedtoolcache/CodeQL
32+
sudo rm -rf /usr/local/lib/android/sdk/ndk
33+
sudo rm -rf /usr/share/dotnet
34+
sudo rm -rf /opt/ghc
35+
sudo rm -rf /usr/local/share/boost
36+
if [[ "$(cat /etc/issue)" =~ Ubuntu ]]; then
37+
sudo apt-get clean
38+
fi
39+
df -h
40+
shell: bash
41+
2542
- name: test
2643
id: test
2744
run: |

.github/workflows/report.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ jobs:
120120
shell: bash
121121

122122
- name: report to reportportal
123-
uses: neuralmagic/nm-actions/actions/reportportal_submit_execution_results@v1.15.0
123+
uses: neuralmagic/nm-actions/actions/reportportal_submit_execution_results@v1.22.0
124124
with:
125125
droute_username: ${{ secrets.DROUTE_USERNAME }}
126126
droute_password: ${{ secrets.DROUTE_PASSWORD }}

.github/workflows/test.yml

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ on:
2222
whl:
2323
description: "whl to test (variable appears late binding so unusable outside 'download artifact')"
2424
type: string
25-
required: true
25+
run_id:
26+
description: run id of the BUILD job that generated the assets
27+
type: string
2628

2729
# makes workflow manually callable
2830
workflow_dispatch:
@@ -44,9 +46,11 @@ on:
4446
type: string
4547
required: true
4648
whl:
47-
description: "whl to test (variable appears late binding so unusable outside 'download artifact')"
49+
description: "whl to test (provide either whl or run_id)"
50+
type: string
51+
run_id:
52+
description: run id of the BUILD job that generated the assets
4853
type: string
49-
required: true
5054

5155
jobs:
5256

@@ -87,11 +91,33 @@ jobs:
8791

8892
- name: download whl
8993
id: download
94+
if: ${{ inputs.whl != '' }}
9095
uses: actions/download-artifact@v4
9196
with:
9297
name: ${{ inputs.whl }}
9398
path: ${{ inputs.whl }}
9499

100+
# GCP
101+
- name: 'Authenticate to Google Cloud'
102+
id: auth
103+
uses: google-github-actions/auth@v2.1.3
104+
with:
105+
project_id: ${{ secrets.GCP_PROJECT }}
106+
workload_identity_provider: ${{ secrets.GCP_WORKLOAD_IDENTITY_PROVIDER }}
107+
service_account: ${{ secrets.GCP_GHA_SA }}
108+
109+
- name: 'Set up Cloud SDK'
110+
uses: 'google-github-actions/setup-gcloud@v2'
111+
with:
112+
version: '>= 473.0.0'
113+
114+
- name: download assets
115+
if: ${{ inputs.run_id != '' }}
116+
uses: neuralmagic/nm-actions/actions/gcp-download-assets@v1.1.0
117+
with:
118+
bucket_source: ${{ secrets.GCP_BUILD_ML_ASSETS2 }}
119+
run_id: ${{ inputs.run_id }}
120+
95121
- name: run tests
96122
id: test
97123
uses: ./.github/actions/test/
@@ -109,20 +135,6 @@ jobs:
109135
whl: ${{ inputs.whl }}
110136
test_status: ${{ steps.test.outputs.status }}
111137

112-
# GCP
113-
- name: 'Authenticate to Google Cloud'
114-
id: auth
115-
uses: google-github-actions/auth@v2.1.3
116-
with:
117-
project_id: ${{ secrets.GCP_PROJECT }}
118-
workload_identity_provider: ${{ secrets.GCP_WORKLOAD_IDENTITY_PROVIDER }}
119-
service_account: ${{ secrets.GCP_GHA_SA }}
120-
121-
- name: 'Set up Cloud SDK'
122-
uses: 'google-github-actions/setup-gcloud@v2'
123-
with:
124-
version: '>= 473.0.0'
125-
126138
- name: copy results to GCP
127139
run: |
128140
gcloud storage cp test-results/report.xml ${{ secrets.GCP_BUILD_ML_ASSETS2 }}/${{ github.run_id }}/test-results/report-${{ inputs.test_label }}.xml

.github/workflows/trigger-all.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ jobs:
3232
wf_category: ${{ inputs.wf_category || 'NIGHTLY' }}
3333
gitref: ${{ inputs.gitref || 'main' }}
3434
push_to_pypi: ${{ (github.event.schedule == '30 0 * * *') || inputs.push_to_pypi || false }}
35-
test_configs: '[{"python":"3.11.4","label":"ubuntu-22.04","timeout":"40"},
36-
{"python":"3.10.12","label":"ubuntu-24.04","timeout":"40"},
35+
test_configs: '[{"python":"3.11.4","label":"ubuntu-24.04","timeout":"40"},
36+
{"python":"3.10.12","label":"ubuntu-22.04","timeout":"40"},
3737
{"python":"3.9.17","label":"k8s-h100-solo","timeout":"40"},
3838
{"python":"3.12.6","label":"k8s-a100-duo","timeout":"40"}]'
3939

src/compressed_tensors/utils/offload.py

Lines changed: 103 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@
3131
import warnings
3232
from functools import wraps
3333
from operator import attrgetter
34-
from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
34+
from typing import Any, Callable, Dict, Iterable, Literal, Optional, Tuple, Union
3535

3636
import torch
37+
from compressed_tensors.utils import patch_attr
3738

3839

3940
try:
@@ -83,6 +84,8 @@
8384
"register_offload_module",
8485
"delete_offload_module",
8586
"offloaded_dispatch",
87+
"disable_offloading",
88+
"remove_dispatch",
8689
]
8790

8891

@@ -168,22 +171,22 @@ def update_parameter_data(
168171

169172
def get_execution_device(module: torch.nn.Module) -> torch.device:
170173
"""
171-
Get the device which inputs should be moved to before module execution
174+
Get the device which inputs should be moved to before module execution.
175+
Assume that modules execute in the same order as returned by `module.modules()`
172176
173177
:param module: module to check, may be offloaded
174178
:return: onload device of module
175179
"""
176-
if has_offloaded_params(module):
177-
return module._hf_hook.execution_device
180+
for submodule in module.modules():
181+
if has_offloaded_params(submodule):
182+
return submodule._hf_hook.execution_device
178183

179-
first_param = next(module.parameters(), None)
180-
if first_param is None:
181-
warnings.warn(
182-
f"Unable able to infer execution device of {module}, falling back to CPU"
183-
)
184-
return torch.device("cpu")
184+
param = next(submodule.parameters(recurse=False), None)
185+
if param is not None:
186+
return param.device
185187

186-
return first_param.device
188+
warnings.warn(f"Unable to get execution device of {module}, falling back to CPU")
189+
return torch.device("cpu")
187190

188191

189192
def register_offload_parameter(
@@ -204,17 +207,32 @@ def register_offload_parameter(
204207
has_onload = any(p.device != torch.device("meta") for p in module.parameters())
205208
module.register_parameter(name, parameter)
206209

210+
# do everything AlignDevicesHook.init_hook does
211+
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/hooks.py#L281
207212
if has_offloaded_params(module):
208-
weights_map = module._hf_hook.weights_map
209-
offload_to_weights_map(weights_map, name, parameter.data, offload_device)
213+
hook: AlignDevicesHook = module._hf_hook
214+
assert hook.weights_map is not None
215+
216+
# append to original_devices
217+
hook.original_devices[name] = parameter.device
218+
219+
# append to weights map
220+
offload_to_weights_map(hook.weights_map, name, parameter.data, offload_device)
221+
222+
# append to tied_params_map
223+
offloaded = hook.weights_map[name]
224+
if hook.tied_params_map is not None:
225+
hook.tied_params_map[offloaded.data_ptr()] = {} # (1)
226+
227+
# perform offloading
210228
if not has_onload:
211229
set_module_tensor_to_device(module, name, "meta")
212230

213231

214232
def update_offload_parameter(
215233
module: torch.nn.Module,
216234
name: str,
217-
data: Optional[torch.Tensor],
235+
data: torch.Tensor,
218236
offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
219237
):
220238
"""
@@ -227,15 +245,15 @@ def update_offload_parameter(
227245
:param offload_device: device on which weight will be offloaded to. If None is
228246
provided, then infer device from parameters on module
229247
"""
230-
param = getattr(module, name)
248+
param: torch.nn.Parameter = getattr(module, name)
231249
if param.data.shape != data.shape:
232250
warnings.warn(
233251
f"Shape of parameter being updated {param.data.shape} does not match shape "
234252
f"of update data {data.shape}"
235253
)
236254

237255
# copy data into onloaded parameter if applicable
238-
if param.device != torch.device("meta"):
256+
if param.device != torch.device("meta") and data is not param.data:
239257
param.data.copy_(data)
240258

241259
# update offload dict
@@ -420,7 +438,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
420438
hook: AlignDevicesHook = base._hf_hook
421439
assert hook.offload
422440
assert hook.weights_map is not None
423-
assert hook.tied_params_map is not None
424441

425442
# offloading kwargs for submodule
426443
place_submodules = False
@@ -435,7 +452,8 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
435452
module, include_buffers=offload_buffers, recurse=place_submodules
436453
):
437454
offloaded = param.to(offload_device)
438-
hook.tied_params_map[offloaded.data_ptr()] = {} # (1)
455+
if hook.tied_params_map is not None:
456+
hook.tied_params_map[offloaded.data_ptr()] = {} # (1)
439457
offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded)
440458

441459
# if the parent places submodules, offload here
@@ -463,9 +481,6 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M
463481

464482
base.register_module(name, module)
465483

466-
# (1): Since we cannot know which pointers are shared when we add parameters in an
467-
# online way, assume that all pointers are shared. This comes at no runtime cost
468-
469484

470485
def delete_offload_module(base: torch.nn.Module, name: str):
471486
"""
@@ -500,8 +515,13 @@ def offloaded_dispatch(
500515
if offload_device == "disk":
501516
raise NotImplementedError("Disk offloading is not currently supported")
502517

518+
# remove any existing hooks
519+
remove_dispatch(module)
520+
503521
# create weights map
504-
weights_map = OffloadedWeightsLoader(state_dict=module.state_dict(), device="cpu")
522+
state_dict = module.state_dict()
523+
state_dict = {key: val.to(offload_device) for key, val in state_dict.items()}
524+
weights_map = OffloadedWeightsLoader(state_dict=state_dict, device=offload_device)
505525

506526
# create tied params map
507527
tied_params = find_tied_parameters(module)
@@ -519,9 +539,66 @@ def offloaded_dispatch(
519539
weights_map=weights_map,
520540
tied_params_map=tied_params_map,
521541
)
542+
543+
# when saving a model, `PreTrainedModel.save_pretrained` will only
544+
# onload weights if the following requirements are met
545+
# if (
546+
# hasattr(self, "hf_device_map")
547+
# and len(set(self.hf_device_map.values())) > 1
548+
# and ("cpu" in self.hf_device_map.values()
549+
# or "disk" in self.hf_device_map.values())
550+
# ):
551+
# because this function always offloads, disregard actual devices and
552+
# always use `cpu` and `cuda:0` to guarantee this condition passes
553+
setattr(module, "hf_device_map", {"fake_offload": "cpu", "fake_exec": "cuda:0"})
554+
522555
return module
523556

524557

558+
def remove_dispatch(module: torch.nn.Module) -> torch.nn.Module:
559+
"""
560+
Remove any existing dispatches from module
561+
562+
:param module: module which may be dispatched with hf hooks
563+
:return: module without dispatch
564+
"""
565+
remove_hook_from_module(module, recurse=True)
566+
if hasattr(module, "hf_device_map"):
567+
delattr(module, "hf_device_map")
568+
569+
return module
570+
571+
572+
@contextlib.contextmanager
573+
def disable_offloading():
574+
"""
575+
Keep modules onloaded and disable offloading until this context exits.
576+
Affects modules which have been hooked with accelerate's `AlignDevicesHook`
577+
"""
578+
original_pre_forward = AlignDevicesHook.pre_forward
579+
onloaded_modules: Dict[torch.nn.Module, Tuple[AlignDevicesHook, bool]] = dict()
580+
581+
# onload once and disable any future onloading/offloading steps
582+
def keep_onload_pre_forward(self: AlignDevicesHook, module, *args, **kwargs):
583+
ret = original_pre_forward(self, module, *args, **kwargs)
584+
if module not in onloaded_modules:
585+
onloaded_modules[module] = (self, self.offload)
586+
self.offload = False
587+
return ret
588+
589+
# use the patched pre_forward function within the context
590+
with patch_attr(AlignDevicesHook, "pre_forward", keep_onload_pre_forward):
591+
yield
592+
593+
# manually offload all modules that were onloaded
594+
# update any parameters which may have changed
595+
for module, (hook, offload) in onloaded_modules.items():
596+
hook.offload = offload
597+
for name, param in module.named_parameters(recurse=False):
598+
update_offload_parameter(module, name, param.data)
599+
hook.post_forward(module, None)
600+
601+
525602
""" Upstreamed Functions """
526603

527604

@@ -589,3 +666,7 @@ def align_module_device(
589666

590667
else:
591668
yield
669+
670+
671+
# (1): Since we cannot know which pointers are shared when we add parameters in an
672+
# online way, assume that all pointers are shared. This has virtually no runtime cost

0 commit comments

Comments
 (0)