Skip to content

Commit 3ac19fa

Browse files
committed
Merge remote-tracking branch 'origin' into kylesayrs/reduce-quantized-compression-memory
Merge commit 3ac19fa (2 parents: 1862e0f and 4438d08)

File tree

7 files changed

+82
-14
lines changed

7 files changed

+82
-14
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ jobs:
7676

7777
- name: build
7878
id: build
79-
uses: neuralmagic/nm-actions/actions/build-ml-whl@v1.18.0
79+
uses: neuralmagic/nm-actions/actions/build-ml-whl@c7e5a66c382104e1beadcb7dadf429f8ab15b344 # v1.20.0
8080
with:
8181
dev: false
8282
release: ${{ inputs.wf_category == 'RELEASE' }}

.github/workflows/test-check.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@ jobs:
1212
python-tests:
1313
runs-on: ubuntu-24.04
1414
steps:
15-
- uses: actions/setup-python@v4
15+
- uses: actions/setup-python@v5
1616
with:
1717
python-version: '3.10'
18-
- uses: actions/checkout@v3
18+
- uses: actions/checkout@v4
19+
with:
20+
fetch-depth: 0
21+
fetch-tags: true
1922
- name: Set Env
2023
run: |
2124
pip3 install --upgrade pip && pip3 install --upgrade setuptools
22-
pip3 install virtualenv
23-
virtualenv venv
24-
source venv/bin/activate
2525
- name: "⚙️ Install dependencies"
2626
run: pip3 install .[dev,accelerate]
2727
- name: "🔬 Running tests"

pyproject.toml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
[build-system]
2-
requires = ["setuptools", "wheel", "setuptools_scm>8"]
2+
requires = ["setuptools", "wheel", "setuptools_scm==8.2.0"]
33
build-backend = "setuptools.build_meta"
44

5-
[tool.setuptools_scm]
6-
version_file = "src/compressed_tensors/version.py"
7-
85
[tool.black]
96
line-length = 88
107
target-version = ['py36']

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def _setup_extras() -> Dict:
101101
use_scm_version={
102102
"version_scheme": version_func,
103103
"local_scheme": localversion_func,
104+
"version_file": "src/compressed_tensors/version.py",
104105
},
105106
author="Neuralmagic, Inc.",
106107
author_email="support@neuralmagic.com",

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def decompress(
163163
self,
164164
path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
165165
names_to_scheme: Dict[str, QuantizationScheme],
166-
device: torch.device = "cpu",
166+
device: str = "cpu",
167167
) -> Generator[Tuple[str, Tensor], None, None]:
168168
"""
169169
Reads a compressed state dict located at path_to_model_or_tensors
@@ -172,7 +172,8 @@ def decompress(
172172
:param path_to_model_or_tensors: path to compressed safetensors model (directory
173173
with one or more safetensors files) or compressed tensors file
174174
:param names_to_scheme: quantization scheme for each quantized weight
175-
:param device: optional device to load intermediate weights into
175+
:param device: optional device to load intermediate weights into (must be `str`,
176+
not `torch.device`)
176177
:return: compressed state dict
177178
"""
178179
if isinstance(path_to_model_or_tensors, (str, Path)):
@@ -189,7 +190,7 @@ def _decompress_from_path(
189190
self,
190191
path_to_model: Union[str, Path, Dict[str, Any]],
191192
names_to_scheme: Dict[str, QuantizationScheme],
192-
device: torch.device,
193+
device: str,
193194
):
194195
weight_mappings = get_nested_weight_mappings(
195196
path_to_model, self.compression_param_names

src/compressed_tensors/utils/offload.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import contextlib
2929
import warnings
3030
from functools import wraps
31-
from typing import Any, Callable, Dict, Literal, Optional, Union
31+
from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union
3232

3333
import torch
3434

@@ -67,6 +67,8 @@
6767
"delete_offload_parameter",
6868
"has_offloaded_params",
6969
"disable_hf_hook",
70+
"disable_offload",
71+
"align_modules",
7072
"align_module_device",
7173
]
7274

@@ -344,6 +346,43 @@ def delete_from_weights_map(
344346
)
345347

346348

349+
@contextlib.contextmanager
def disable_offload(module: torch.nn.Module):
    """
    Context manager to disable module onloading and offloading. Parameters will stay on
    their current device

    :param module: module to disable offloading for
    """
    if has_offloaded_params(module):
        # remember the hook's current setting rather than assuming it was True,
        # so nested uses (or hooks created with offload=False) are restored correctly
        original_offload = module._hf_hook.offload
        module._hf_hook.offload = False
        try:
            yield
        finally:
            # restore even if the body raises, otherwise offloading would stay
            # disabled for the rest of the module's lifetime
            module._hf_hook.offload = original_offload
    else:
        # module has no offload hook attached; nothing to disable
        yield
364+
365+
@contextlib.contextmanager
def align_modules(
    modules: Union[torch.nn.Module, Iterable[torch.nn.Module]],
    execution_device: Optional[torch.device] = None,
):
    """
    Context manager for onloading modules to a device, and disabling onload and offload
    attempts triggered by forward calls. Used for sequential onloading of layers

    :param modules: `torch.nn.Module` or iterable of `torch.nn.Module`s to onload
    :param execution_device: device to onload to
    """
    # normalize the single-module case into an iterable
    if isinstance(modules, torch.nn.Module):
        modules = (modules,)

    # an ExitStack lets us enter a variable number of nested contexts:
    # one device-alignment context plus one offload-disable context per module
    with contextlib.ExitStack() as stack:
        for mod in modules:
            stack.enter_context(align_module_device(mod, execution_device))
            stack.enter_context(disable_offload(mod))  # disable redundant onloading
        yield
385+
347386
""" Upstreamed Functions """
348387

349388

tests/test_utils/test_offload.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import torch
1616
from compressed_tensors.utils import (
1717
align_module_device,
18+
align_modules,
1819
delete_offload_parameter,
1920
disable_hf_hook,
2021
get_execution_device,
@@ -248,6 +249,35 @@ def test_disable_hf_hook_model_recurse():
248249
assert hasattr(module2, "_hf_hook")
249250

250251

252+
@requires_accelerate()
def test_align_modules():
    from accelerate.hooks import attach_align_device_hook

    # build a small nested model out of three example modules
    inner0 = ExampleModule()
    inner1 = ExampleModule()
    inner2 = ExampleModule()
    model = torch.nn.Sequential(inner0, torch.nn.Sequential(inner1, inner2))

    # offload all parameters; weights live in the weights map, params go to meta
    attach_align_device_hook(
        model,
        execution_device=torch.device("cpu"),
        offload=True,
        weights_map=model.state_dict(),
    )

    meta = torch.device("meta")
    for mod in (inner0, inner1, inner2):
        assert mod.a.device == meta

    # only the modules passed to align_modules are onloaded; the third stays offloaded
    with align_modules((inner0, inner1)):
        assert inner0.a.device != meta
        assert inner1.a.device != meta
        assert inner2.a.device == meta

    # leaving the context offloads everything again
    for mod in (inner0, inner1, inner2):
        assert mod.a.device == meta
280+
251281
@requires_accelerate()
252282
def test_offload_to_weights_map():
253283
from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset

0 commit comments

Comments (0)