
Commit 934ec0e (parent: ce43c7d)

Enable export paths for LLMs [CodeGen, OPT, Bloom] (#1562)

* initial commit
* missing config arg

File tree: 2 files changed (+48, −0)

* src/sparseml/transformers/export.py
* src/sparseml/transformers/utils/model.py


src/sparseml/transformers/export.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -136,6 +136,13 @@ def load_task_model(task: str, model_path: str, config: Any) -> Module:
             model_type="model",
         )
 
+    if task == "text-generation":
+        return SparseAutoModel.text_generation_from_pretrained(
+            model_name_or_path=model_path,
+            config=config,
+            model_type="model",
+        )
+
     raise ValueError(f"unrecognized task given of {task}")
 
 
```
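With this branch in place, `load_task_model` routes the `"text-generation"` task to the new causal-LM loader. A minimal usage sketch, assuming `load_task_model` is importable from `sparseml.transformers.export` and using an example OPT checkpoint (neither is part of this commit):

```python
from transformers import AutoConfig

from sparseml.transformers.export import load_task_model

model_path = "facebook/opt-350m"  # example checkpoint; any causal LM should work
config = AutoConfig.from_pretrained(model_path)

# Dispatches to SparseAutoModel.text_generation_from_pretrained for this task
model = load_task_model(task="text-generation", model_path=model_path, config=config)
```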

```diff
@@ -263,6 +270,9 @@ def export_transformer_to_onnx(
     tokenizer = AutoTokenizer.from_pretrained(
         model_path, model_max_length=sequence_length
     )
+    if task == "text-generation":
+        tokenizer.pad_token = tokenizer.eos_token
+
     model = load_task_model(task, model_path, config)
     _LOGGER.info(f"loaded model, config, and tokenizer from {model_path}")
 
```
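Decoder-only checkpoints such as OPT, Bloom, and CodeGen typically ship without a dedicated pad token, so padding the fixed-length export inputs would otherwise fail; reusing the EOS token as the pad token is the standard workaround. A standalone illustration with plain `transformers` (the checkpoint and length are examples):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m", model_max_length=384)
if tokenizer.pad_token is None:  # common for GPT-style tokenizers
    tokenizer.pad_token = tokenizer.eos_token

# Pads up to model_max_length so every export sample has a static shape
batch = tokenizer("def fib(n):", padding="max_length", return_tensors="pt")
print(batch["input_ids"].shape)  # torch.Size([1, 384])
```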

```diff
@@ -353,12 +363,14 @@ def export_transformer_to_onnx(
     # run export
     model = model.eval()
     onnx_file_path = os.path.join(model_path, onnx_file_name)
+    kwargs = {"input_names": list(inputs.keys())} if task == "text-generation" else {}
 
     export_onnx(
         model,
         inputs,
         onnx_file_path,
         convert_qat=convert_qat,
+        **kwargs,
     )
     _LOGGER.info(f"ONNX exported to {onnx_file_path}")
 
```
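Forwarding `input_names` keeps the ONNX graph inputs named after the tokenizer's keys (e.g. `input_ids`, `attention_mask`) rather than the exporter's positional defaults, which downstream runtimes rely on to bind inputs by name. A hedged end-to-end sketch; the import path and local checkpoint directory are assumptions, while the keyword names follow this diff:

```python
from sparseml.transformers.export import export_transformer_to_onnx

# Loads config/tokenizer/model from model_path, then writes the ONNX file there
export_transformer_to_onnx(
    task="text-generation",
    model_path="./opt-350m-pruned",  # example local checkpoint directory
    sequence_length=384,
)
```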

src/sparseml/transformers/utils/model.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@
 import torch
 from torch.nn import Module
 from transformers import (
+    AutoModelForCausalLM,
     AutoModelForMaskedLM,
     AutoModelForQuestionAnswering,
     AutoModelForSequenceClassification,
```
```diff
@@ -235,6 +236,41 @@ def text_classification_from_pretrained_distil(
 
         return model, teacher
 
+    @staticmethod
+    def text_generation_from_pretrained(
+        model_name_or_path: str,
+        model_type: str,
+        **kwargs,
+    ) -> Module:
+        """
+        :param model_name_or_path: the name of or path to the model to load
+        :param model_type: specify the type of model loaded for logging;
+            ex one of [model, student, teacher]
+        :param kwargs: keyword arguments to pass through to the AutoModel call
+        :return: the created model for text generation
+        """
+        SparseAutoModel._check_tf(model_name_or_path)
+        if not kwargs:
+            kwargs = {}
+        kwargs["from_tf"] = False
+        delayed = False
+        if "state_dict" not in kwargs:
+            kwargs["state_dict"], delayed = SparseAutoModel._loadable_state_dict(
+                model_name_or_path
+            )
+        # Export decoder model without kv cache support
+        kwargs["config"].is_decoder = True
+        kwargs["config"].use_cache = False
+        kwargs["config"].use_past = False
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path,
+            **kwargs,
+        )
+        SparseAutoModel.log_model_load(model, model_name_or_path, model_type, delayed)
+
+        return model
+
     @staticmethod
     def token_classification_from_pretrained(
         model_name_or_path: str,
```
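Note that the method mutates `kwargs["config"]` before loading, so a `config` object must be passed through the keyword arguments (likely what the "missing config arg" note in the commit message refers to); setting `use_cache = False` gives the traced forward a static signature with no KV-cache inputs or outputs, which is what this export path needs. A minimal sketch, assuming `SparseAutoModel` is exported from `sparseml.transformers.utils` and using an example Bloom checkpoint:

```python
from transformers import AutoConfig

from sparseml.transformers.utils import SparseAutoModel

name = "bigscience/bloom-560m"  # example checkpoint
config = AutoConfig.from_pretrained(name)

model = SparseAutoModel.text_generation_from_pretrained(
    model_name_or_path=name,
    model_type="model",
    config=config,  # mutated in place: is_decoder=True, use_cache/use_past=False
)
model.eval()  # export-ready: decoder mode, no KV cache
```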
