
Commit 941deac

Merge remote-tracking branch 'origin' into kylesayrs/deepseek-v3

2 parents: 0dc2381 + 50bb656

File tree

4 files changed: 9 additions & 5 deletions


src/llmcompressor/args/dataset_arguments.py

Lines changed: 1 addition & 0 deletions

@@ -171,6 +171,7 @@ class DatasetArguments(CustomDatasetArguments):
             "will execute code present on the Hub on your local machine."
         },
     )
+    # --- pipeline arguments --- #
     pipeline: Optional[str] = field(
         default="independent",
         metadata={
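The field being sectioned off here follows the standard `dataclasses.field(default=..., metadata=...)` pattern used by the args module. A minimal self-contained sketch of that pattern (the class name and help text below are illustrative, not the library's):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class PipelineArgsSketch:  # illustrative stand-in for DatasetArguments
    # --- pipeline arguments --- #
    pipeline: Optional[str] = field(
        default="independent",
        metadata={"help": "illustrative help text for the pipeline choice"},
    )


args = PipelineArgsSketch()
print(args.pipeline)  # -> "independent"
```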

src/llmcompressor/pipelines/layer_sequential/helpers.py

Lines changed: 2 additions & 1 deletion

@@ -44,6 +44,7 @@ def capture_first_layer_intermediates(
     model: Module,
     first_layer: Module,
     dataloader: DataLoader,
+    model_device: torch.device = torch.device("cpu"),
     mask_padding: bool = True,
 ) -> IntermediatesCache:
     """
@@ -68,7 +69,7 @@ def capture_first_layer_intermediates(
     desc = "Preparing intermediates cache"
     for batch_index, batch in enumerate(tqdm.tqdm(dataloader, desc=desc)):
         batch = apply_pad_mask_to_batch(batch) if mask_padding else batch
-        batch = tensors_to_device(batch, torch.device("cpu"))
+        batch = tensors_to_device(batch, model_device)
 
         try:
             model(**batch)
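The behavioral change: calibration batches are now moved to a caller-supplied `model_device` (defaulting to CPU for backward compatibility) instead of being pinned to CPU unconditionally. A minimal sketch of the `tensors_to_device` semantics this relies on (the helper below is a hypothetical stand-in, not the library function):

```python
import torch


def tensors_to_device_sketch(batch: dict, device: torch.device) -> dict:
    # Move every tensor value in the batch to `device`; pass through the rest.
    return {
        key: val.to(device) if isinstance(val, torch.Tensor) else val
        for key, val in batch.items()
    }


batch = {"input_ids": torch.zeros(1, 8, dtype=torch.long), "meta": "kept-as-is"}
moved = tensors_to_device_sketch(batch, torch.device("cpu"))  # or the model's device
print(moved["input_ids"].device)  # -> cpu
```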

src/llmcompressor/pipelines/layer_sequential/pipeline.py

Lines changed: 3 additions & 2 deletions

@@ -2,7 +2,7 @@
 
 import torch
 import tqdm
-from compressed_tensors.utils import disable_offloading
+from compressed_tensors.utils import disable_offloading, get_execution_device
 from torch.utils.data.dataloader import DataLoader
 
 from llmcompressor.core import LifecycleCallbacks, active_session
@@ -60,6 +60,7 @@ def __call__(
 
     # prepare model for sequential onloading
     dispatch_for_sequential(model)
+    model_device = get_execution_device(model)
 
     # find layers
     modifiers = session.get_modifiers()
@@ -71,7 +72,7 @@ def __call__(
     with calibration_forward_context(model), DisableQuantization(model):
         # prepare intermediates cache
         intermediates: IntermediatesCache = capture_first_layer_intermediates(
-            model, layers[0], dataloader
+            model, layers[0], dataloader, model_device
         )
 
     num_layers = len(layers)
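`get_execution_device` comes from `compressed_tensors.utils` and is queried once, right after `dispatch_for_sequential`, so the result reflects the dispatched placement. A rough stand-in for what the pipeline relies on (the function below is a simplified sketch, not the real implementation, which also handles accelerate-offloaded modules):

```python
import torch
from torch.nn import Module


def get_execution_device_sketch(model: Module) -> torch.device:
    # Simplified: report the device of the first parameter,
    # falling back to CPU for a parameter-less model.
    for param in model.parameters():
        return param.device
    return torch.device("cpu")


model = torch.nn.Linear(4, 4)
model_device = get_execution_device_sketch(model)  # cpu here; e.g. cuda:0 after dispatch
```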

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 3 additions & 2 deletions

@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING
 
 import torch
-from compressed_tensors.utils import disable_offloading
+from compressed_tensors.utils import disable_offloading, get_execution_device
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
 
@@ -54,6 +54,7 @@ def __call__(
 
     # prepare model for sequential onloading
     dispatch_for_sequential(model)
+    model_device = get_execution_device(model)
 
     # prepare to trace subgraphs
     modifiers = session.get_modifiers()
@@ -69,7 +70,7 @@ def __call__(
 
     with calibration_forward_context(model), DisableQuantization(model):
         # prepare intermediates cache
-        activations = IntermediatesCache.from_dataloader(dataloader)
+        activations = IntermediatesCache.from_dataloader(dataloader, model_device)
 
     for subgraph_index, subgraph in enumerate(subgraphs):
         # prepare tqdm description texts
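`IntermediatesCache.from_dataloader` now receives the model's execution device as well, so cached batches can live on that device rather than defaulting to CPU. A minimal illustrative cache under that assumption (class and loader below are hypothetical, shown only to fix the calling pattern):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class IntermediatesCacheSketch:
    # Hypothetical stand-in: eagerly moves each batch to `device` at build time.
    def __init__(self, batches):
        self.batches = batches

    @classmethod
    def from_dataloader(cls, dataloader: DataLoader, device: torch.device):
        return cls([tuple(t.to(device) for t in batch) for batch in dataloader])


loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
cache = IntermediatesCacheSketch.from_dataloader(loader, torch.device("cpu"))
print(len(cache.batches))  # -> 4
```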
