Commit 9c7c285

Keep model in CPU during ONNX export (#1586)
1 parent 53e2f8d

2 files changed: +15 -4 lines

src/sparseml/transformers/export.py
Lines changed: 14 additions & 1 deletion

@@ -75,10 +75,12 @@
 import math
 import os
 import shutil
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union

 from torch.nn import Module
 from transformers import AutoConfig, AutoTokenizer
+from transformers import TrainingArguments as HFTrainingArgs
 from transformers.tokenization_utils_base import PaddingStrategy

 from sparseml.optim import parse_recipe_variables
@@ -107,6 +109,14 @@
 _LOGGER = logging.getLogger(__name__)


+@dataclass
+class DeviceCPUTrainingArgs(HFTrainingArgs):
+    @property
+    def place_model_on_device(self):
+        # Ensure model remains in CPU during ONNX export
+        return False
+
+
 def load_task_model(task: str, model_path: str, config: Any) -> Module:
     if task == "masked-language-modeling" or task == "mlm":
         return SparseAutoModel.masked_language_modeling_from_pretrained(
@@ -294,15 +304,18 @@ def export_transformer_to_onnx(
     _LOGGER.info(f"loaded validation dataset for args {data_args}")

     model = model.train()
+
+    args = DeviceCPUTrainingArgs(output_dir="tmp_trainer")
     trainer = Trainer(
         model=model,
+        args=args,
         model_state_path=model_path,
         eval_dataset=eval_dataset,
         recipe=None,
         recipe_args=None,
         teacher=None,
     )
-    model = model.cpu()
+
     applied = trainer.apply_manager(epoch=math.inf, checkpoint=None)

     if not applied:
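The key piece of the change above is the place_model_on_device override: Hugging Face's Trainer consults this TrainingArguments property in its constructor and only moves the model to args.device when it returns True. Below is a minimal, self-contained sketch of the same pattern outside of SparseML; the checkpoint name and output_dir are placeholders for illustration, not values taken from the diff.

# Minimal sketch (illustrative, not part of the commit): keep a model on CPU
# by overriding TrainingArguments.place_model_on_device, the same pattern the
# diff uses for DeviceCPUTrainingArgs.
from dataclasses import dataclass

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments


@dataclass
class CPUOnlyTrainingArgs(TrainingArguments):
    @property
    def place_model_on_device(self) -> bool:
        # Trainer checks this flag in __init__; returning False skips the
        # model.to(args.device) call, so the weights stay where they are.
        return False


# "distilbert-base-uncased" and "tmp_trainer" are placeholder values.
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
trainer = Trainer(model=model, args=CPUOnlyTrainingArgs(output_dir="tmp_trainer"))

# Parameters remain on CPU even when a CUDA device is available.
print(next(trainer.model.parameters()).device)  # -> cpu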

src/sparseml/transformers/sparsification/trainer.py
Lines changed: 1 addition & 3 deletions

@@ -125,7 +125,7 @@ def __init__(
                 training_args_dict=training_args.to_dict(),
                 data_args_dict=asdict(data_args) if data_args else {},
             )
-            if training_args
+            if training_args and metadata_args
             else None
         )

@@ -762,7 +762,6 @@ def _get_fake_dataloader(
         num_samples: int,
         tokenizer: "PreTrainedTokenizerBase",  # noqa: F821
     ):
-
        # Rearrange inputs' keys to match those defined by model foward func, which
        # seem to define how the order of inputs is determined in the exported model
        forward_args_spec = inspect.getfullargspec(self.model.__class__.forward)
@@ -820,7 +819,6 @@ def __init__(
         teacher: Optional[Union[Module, str]] = None,
         **kwargs,
     ):
-
        super().__init__(
            model=model,
            model_state_path=model_state_path,
