4 files changed (+9, -5 lines changed)

File 1 of 4:

@@ -171,6 +171,7 @@ class DatasetArguments(CustomDatasetArguments):
             "will execute code present on the Hub on your local machine."
         },
     )
+    # --- pipeline arguments --- #
     pipeline: Optional[str] = field(
         default="independent",
         metadata={
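The hunk above only introduces a section comment ahead of the existing pipeline field. For reference, here is a minimal, self-contained sketch of the dataclass-field pattern being extended; the class name PipelineArgsSketch and the help text are illustrative assumptions, not the repository's code.

from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class PipelineArgsSketch:  # hypothetical name, for illustration only
    pipeline: Optional[str] = field(
        default="independent",
        metadata={"help": "Calibration pipeline to use, e.g. 'independent'."},
    )


args = PipelineArgsSketch()
for f in fields(args):
    # HF-style argument parsers read the "help" entry from this metadata dict
    print(f.name, "->", f.metadata.get("help"))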
File 2 of 4:

@@ -44,6 +44,7 @@ def capture_first_layer_intermediates(
     model: Module,
     first_layer: Module,
     dataloader: DataLoader,
+    model_device: torch.device = torch.device("cpu"),
     mask_padding: bool = True,
 ) -> IntermediatesCache:
     """
@@ -68,7 +69,7 @@ def capture_first_layer_intermediates(
     desc = "Preparing intermediates cache"
     for batch_index, batch in enumerate(tqdm.tqdm(dataloader, desc=desc)):
         batch = apply_pad_mask_to_batch(batch) if mask_padding else batch
-        batch = tensors_to_device(batch, torch.device("cpu"))
+        batch = tensors_to_device(batch, model_device)

         try:
             model(**batch)
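This change threads a model_device parameter into capture_first_layer_intermediates so each calibration batch lands on the model's execution device instead of always on CPU. A minimal sketch of that idea follows; tensors_to_device_sketch is a hypothetical stand-in for the helper used in the diff, shown only to illustrate the device move.

from typing import Dict

import torch


def tensors_to_device_sketch(
    batch: Dict[str, torch.Tensor], device: torch.device
) -> Dict[str, torch.Tensor]:
    # hypothetical stand-in for the tensors_to_device helper used in the diff
    return {key: value.to(device) for key, value in batch.items()}


model_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch = {
    "input_ids": torch.randint(0, 100, (1, 8)),
    "attention_mask": torch.ones(1, 8, dtype=torch.long),
}
# previously the batch was pinned to torch.device("cpu"); now it follows the model
batch = tensors_to_device_sketch(batch, model_device)
print({key: value.device for key, value in batch.items()})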
File 3 of 4:

@@ -2,7 +2,7 @@

 import torch
 import tqdm
-from compressed_tensors.utils import disable_offloading
+from compressed_tensors.utils import disable_offloading, get_execution_device
 from torch.utils.data.dataloader import DataLoader

 from llmcompressor.core import LifecycleCallbacks, active_session
@@ -60,6 +60,7 @@ def __call__(

         # prepare model for sequential onloading
         dispatch_for_sequential(model)
+        model_device = get_execution_device(model)

         # find layers
         modifiers = session.get_modifiers()
@@ -71,7 +72,7 @@ def __call__(
         with calibration_forward_context(model), DisableQuantization(model):
             # prepare intermediates cache
             intermediates: IntermediatesCache = capture_first_layer_intermediates(
-                model, layers[0], dataloader
+                model, layers[0], dataloader, model_device
             )

             num_layers = len(layers)
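Here the pipeline imports get_execution_device, queries the model's device right after dispatch_for_sequential, and forwards it to the intermediates helper. The sketch below illustrates the intent under simplified assumptions; get_execution_device_sketch is not the compressed_tensors helper itself, which this sketch assumes also understands offloaded modules.

import torch
from torch import nn


def get_execution_device_sketch(model: nn.Module) -> torch.device:
    # simplified assumption: the device of the first parameter is the device
    # the model executes on; falls back to CPU for parameter-free modules
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device("cpu")


model = nn.Linear(4, 4)
model_device = get_execution_device_sketch(model)
print(model_device)  # cpu here; cuda:N once the model is dispatched to a GPU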
File 4 of 4:

@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING

 import torch
-from compressed_tensors.utils import disable_offloading
+from compressed_tensors.utils import disable_offloading, get_execution_device
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm

@@ -54,6 +54,7 @@ def __call__(

         # prepare model for sequential onloading
         dispatch_for_sequential(model)
+        model_device = get_execution_device(model)

         # prepare to trace subgraphs
         modifiers = session.get_modifiers()
@@ -69,7 +70,7 @@ def __call__(

         with calibration_forward_context(model), DisableQuantization(model):
             # prepare intermediates cache
-            activations = IntermediatesCache.from_dataloader(dataloader)
+            activations = IntermediatesCache.from_dataloader(dataloader, model_device)

             for subgraph_index, subgraph in enumerate(subgraphs):
                 # prepare tqdm description texts
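In this pipeline the same device is also handed to IntermediatesCache.from_dataloader. A compact sketch of the underlying idea follows, using a hypothetical IntermediatesCacheSketch class rather than the library's implementation: batches drawn from the dataloader are stored on the target device up front, so the first forward pass needs no extra per-batch copy.

from typing import Dict, List

import torch
from torch.utils.data import DataLoader, TensorDataset


class IntermediatesCacheSketch:
    """Hypothetical, simplified cache: stores each batch on a chosen device."""

    def __init__(self, batches: List[Dict[str, torch.Tensor]]):
        self.batches = batches

    @classmethod
    def from_dataloader(cls, dataloader: DataLoader, device: torch.device):
        batches = []
        for (input_ids,) in dataloader:
            # place on the model's execution device, not unconditionally on CPU
            batches.append({"input_ids": input_ids.to(device)})
        return cls(batches)


loader = DataLoader(TensorDataset(torch.randint(0, 100, (4, 8))), batch_size=2)
cache = IntermediatesCacheSketch.from_dataloader(loader, torch.device("cpu"))
print(len(cache.batches), cache.batches[0]["input_ids"].device)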