Commit 50bb656
[Pipelines] infer model device with optional override (#1572)
## Purpose ##
* Fix support for DeepSeek-V2.5
* Add more robust inference of model devices when calibrating

## Prerequisites ##
* neuralmagic/compressed-tensors#363

## Background ##
Normally, starting model inputs on the CPU is not an issue for the sequential pipeline, since the pipeline offloads models, and offloaded models automatically place inputs on the proper devices. However, the DeepSeek-V2.5 model is an exception: this model [performs an add operation](https://huggingface.co/deepseek-ai/DeepSeek-V2.5/blob/main/modeling_deepseek.py#L886) between a module output (`attn_weights`) and a model input (`attention_mask`) before the model input has a chance to be placed on the proper device.

## Changes ##
* Use `model_device` when deciding the onload device for model inputs

## Testing ##
* Ran the DeepSeek-V2.5 example to completion
* TODO: run nightly to confirm other models work with the new input device placement

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
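To make the background concrete, here is a minimal sketch (not from this repo; the shapes and device are illustrative) of the failure mode and the fix: when a module output already lives on the accelerator but a raw model input is still on the CPU, the add raises a device-mismatch error, while onloading the input to the model's execution device first avoids it.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# module output produced on the model's execution device
attn_weights = torch.randn(1, 8, 16, 16, device=device)

# raw model input still sitting on the CPU
attention_mask = torch.zeros(1, 1, 16, 16)

# attn_weights + attention_mask  # RuntimeError on CUDA: cuda vs cpu tensors

# the fix: place inputs on the model's execution device before the forward pass
attention_mask = attention_mask.to(device)
out = attn_weights + attention_mask  # both tensors now share a device
```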
1 parent 6800f81 commit 50bb656

File tree

4 files changed (+9 lines, -5 lines)

src/llmcompressor/args/dataset_arguments.py (1 addition, 0 deletions)

@@ -171,6 +171,7 @@ class DatasetArguments(CustomDatasetArguments):
             "will execute code present on the Hub on your local machine."
         },
     )
+    # --- pipeline arguments --- #
     pipeline: Optional[str] = field(
         default="independent",
         metadata={

src/llmcompressor/pipelines/layer_sequential/helpers.py (2 additions, 1 deletion)

@@ -44,6 +44,7 @@ def capture_first_layer_intermediates(
     model: Module,
     first_layer: Module,
     dataloader: DataLoader,
+    model_device: torch.device = torch.device("cpu"),
     mask_padding: bool = True,
 ) -> IntermediatesCache:
     """
@@ -68,7 +69,7 @@ def capture_first_layer_intermediates(
     desc = "Preparing intermediates cache"
     for batch_index, batch in enumerate(tqdm.tqdm(dataloader, desc=desc)):
         batch = apply_pad_mask_to_batch(batch) if mask_padding else batch
-        batch = tensors_to_device(batch, torch.device("cpu"))
+        batch = tensors_to_device(batch, model_device)

         try:
             model(**batch)
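For reference, `tensors_to_device` moves every tensor in a batch onto the target device. A minimal sketch of the behavior the loop above relies on (a simplified stand-in assuming a flat dict batch; the real utility handles nested structures as well):

```python
from typing import Any, Dict

import torch


def tensors_to_device_sketch(
    batch: Dict[str, Any], device: torch.device
) -> Dict[str, Any]:
    # Move tensor values to the target device; leave everything else untouched.
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }


batch = {"input_ids": torch.tensor([[1, 2, 3]]), "num_samples": 1}
batch = tensors_to_device_sketch(batch, torch.device("cpu"))
```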

src/llmcompressor/pipelines/layer_sequential/pipeline.py (3 additions, 2 deletions)

@@ -2,7 +2,7 @@

 import torch
 import tqdm
-from compressed_tensors.utils import disable_offloading
+from compressed_tensors.utils import disable_offloading, get_execution_device
 from torch.utils.data.dataloader import DataLoader

 from llmcompressor.core import LifecycleCallbacks, active_session
@@ -60,6 +60,7 @@ def __call__(

         # prepare model for sequential onloading
         dispatch_for_sequential(model)
+        model_device = get_execution_device(model)

         # find layers
         modifiers = session.get_modifiers()
@@ -71,7 +72,7 @@ def __call__(
         with calibration_forward_context(model), DisableQuantization(model):
             # prepare intermediates cache
             intermediates: IntermediatesCache = capture_first_layer_intermediates(
-                model, layers[0], dataloader
+                model, layers[0], dataloader, model_device
             )

             num_layers = len(layers)
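`get_execution_device` comes from compressed-tensors (extended in the prerequisite PR above). Conceptually it answers "where does this model actually run?", which can differ from where parameters are stored when modules are offloaded. A simplified stand-in for intuition only (the real helper also inspects accelerate offload hooks):

```python
import torch
from torch import nn


def execution_device_sketch(model: nn.Module) -> torch.device:
    # Naive inference: use the device of the first parameter found,
    # falling back to the CPU for parameterless models.
    for param in model.parameters():
        return param.device
    return torch.device("cpu")


model = nn.Linear(8, 8)
print(execution_device_sketch(model))  # cpu for a freshly built module
```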

src/llmcompressor/pipelines/sequential/pipeline.py (3 additions, 2 deletions)

@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING

 import torch
-from compressed_tensors.utils import disable_offloading
+from compressed_tensors.utils import disable_offloading, get_execution_device
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm

@@ -54,6 +54,7 @@ def __call__(

         # prepare model for sequential onloading
         dispatch_for_sequential(model)
+        model_device = get_execution_device(model)

         # prepare to trace subgraphs
         modifiers = session.get_modifiers()
@@ -69,7 +70,7 @@ def __call__(

         with calibration_forward_context(model), DisableQuantization(model):
             # prepare intermediates cache
-            activations = IntermediatesCache.from_dataloader(dataloader)
+            activations = IntermediatesCache.from_dataloader(dataloader, model_device)

             for subgraph_index, subgraph in enumerate(subgraphs):
                 # prepare tqdm description texts
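Here `IntermediatesCache.from_dataloader` additionally receives the model device, so cached inputs can be handed to the model on the right device. A toy stand-in showing the idea (heavily simplified and hypothetical in its details; the real cache in llmcompressor also offloads values between uses and handles non-tensor entries):

```python
from typing import Dict, Iterable, List

import torch


class ToyIntermediatesCache:
    def __init__(
        self,
        batches: List[Dict[str, torch.Tensor]],
        model_device: torch.device,
    ):
        self.batches = batches
        self.model_device = model_device

    @classmethod
    def from_dataloader(
        cls, dataloader: Iterable, model_device: torch.device = torch.device("cpu")
    ) -> "ToyIntermediatesCache":
        # Snapshot each batch so every subgraph can replay the same inputs.
        return cls([dict(batch) for batch in dataloader], model_device)

    def fetch(self, batch_index: int) -> Dict[str, torch.Tensor]:
        # Onload cached tensors to the model device at fetch time.
        return {
            key: value.to(self.model_device)
            for key, value in self.batches[batch_index].items()
        }
```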
