Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit b96a89a

Browse files
authored
Support exporting > 2Gb transformer models (#1514)
* initial commit
* initial commit
* Delete helpers.py
* cleanup
* fix an error in the logic
* focus on opt when it comes to tasks
* initial commit
* Delete model.py
* cleanup
* Apply suggestions from code review
1 parent 5a8a333 commit b96a89a

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

src/sparseml/transformers/export.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,13 @@
9191
__all__ = ["export_transformer_to_onnx", "load_task_model"]
9292

9393
MODEL_ONNX_NAME = "model.onnx"
94-
DEPLOYMENT_FILES: List[str] = [
94+
EXTERNAL_ONNX_DATA_NAME = "model.data"
95+
MANDATORY_DEPLOYMENT_FILES: List[str] = [
9596
MODEL_ONNX_NAME,
96-
"tokenizer.json",
9797
"tokenizer_config.json",
9898
"config.json",
9999
]
100+
OPTIONAL_DEPLOYMENT_FILES: List[str] = [EXTERNAL_ONNX_DATA_NAME, "tokenizer.json"]
100101

101102
_LOGGER = logging.getLogger(__name__)
102103

@@ -403,7 +404,9 @@ def create_deployment_folder(
403404

404405
if deployment_files is None:
405406
# set deployment files to default values
406-
deployment_files = copy.deepcopy(DEPLOYMENT_FILES)
407+
deployment_files = copy.deepcopy(
408+
MANDATORY_DEPLOYMENT_FILES + OPTIONAL_DEPLOYMENT_FILES
409+
)
407410
if onnx_file_name != MODEL_ONNX_NAME:
408411
# replace the default onnx model name with the custom one
409412
deployment_files[deployment_files.index(MODEL_ONNX_NAME)] = onnx_file_name
@@ -418,6 +421,12 @@ def create_deployment_folder(
418421
expected_file_path = os.path.join(training_directory, file_name)
419422
deployment_file_path = os.path.join(deployment_folder_dir, file_name)
420423
if not os.path.exists(expected_file_path):
424+
if file_name in OPTIONAL_DEPLOYMENT_FILES:
425+
_LOGGER.warning(
426+
f"Optional file {file_name} not found in {training_directory}. "
427+
f"Skipping copying to deployment folder."
428+
)
429+
continue
421430
raise ValueError(
422431
f"Attempting to copy {file_name} file from {expected_file_path},"
423432
f"but the file does not exits. Make sure that {training_directory} "
@@ -426,6 +435,9 @@ def create_deployment_folder(
426435
if file_name == MODEL_ONNX_NAME:
427436
# moving onnx file from training to deployment directory
428437
shutil.move(expected_file_path, deployment_file_path)
438+
elif file_name == EXTERNAL_ONNX_DATA_NAME:
439+
# moving external onnx tensors from training to deployment directory
440+
shutil.move(expected_file_path, deployment_file_path)
429441
else:
430442
# copying remaining `deployment_files` from training to deployment directory
431443
shutil.copyfile(expected_file_path, deployment_file_path)

0 commit comments

Comments
 (0)