Skip to content

Commit 2e59904

Browse files
committed
write_file_metadata_json - added optional parameter for pytorch models
1 parent 9e1e2ae commit 2e59904

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/sasctl/pzmm/write_json_files.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ def write_file_metadata_json(
470470
model_prefix: str,
471471
json_path: Union[str, Path, None] = None,
472472
is_h2o_model: Optional[bool] = False,
473+
is_pytorch_model: Optional[bool] = False,
473474
is_tf_keras_model: Optional[bool] = False,
474475
) -> Union[dict, None]:
475476
"""
@@ -489,6 +490,10 @@ def write_file_metadata_json(
489490
Sets whether the model metadata is associated with an H2O.ai model. If set
490491
as True, the MOJO model file will be set as a score resource. The default
491492
value is False.
493+
is_pytorch_model : bool, optional
494+
Sets whether the model metadata is associated with an H2O.ai model. If set
495+
as True, the .pth file will be set as a score resource. The default
496+
value is False.
492497
493498
Returns
494499
-------
@@ -505,6 +510,8 @@ def write_file_metadata_json(
505510
dict_list.append({"role": "scoreResource", "name": model_prefix + ".mojo"})
506511
elif is_tf_keras_model:
507512
dict_list.append({"role": "scoreResource", "name": model_prefix + ".h5"})
513+
elif is_pytorch_model:
514+
dict_list.append(({"role": "scoreResource", "name": model_prefix + ".pth"}))
508515
else:
509516
dict_list.append(
510517
{"role": "scoreResource", "name": model_prefix + ".pickle"}

0 commit comments

Comments
 (0)