Skip to content

Commit 830140b

Browse files
committed
added params for pytorch model in viya4model load, import model,
1 parent ae049df commit 830140b

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

src/sasctl/pzmm/write_score_code.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ def _write_imports(
405405
mojo_model: Optional[bool] = False,
406406
binary_h2o_model: Optional[bool] = False,
407407
tf_model: Optional[bool] = False,
408+
pytorch_model: Optional[bool] = False,
408409
binary_string: Optional[str] = None,
409410
) -> None:
410411
"""
@@ -427,6 +428,9 @@ def _write_imports(
427428
tf_model : bool, optional
428429
Flag to indicate that the model is a tensorflow model. The default value
429430
is None.
431+
pytorch_model : bool, optional
432+
Flag to indicate that the model is a pytorch model. The default value
433+
is None.
430434
binary_string : str, optional
431435
A binary representation of the Python model object. The default value is
432436
None.
@@ -475,6 +479,8 @@ def _write_imports(
475479
import tensorflow as tf
476480
477481
"""
482+
elif pytorch_model:
483+
cls.score_code += "import math\nimport torch\nimport pandas as pd\nimport numpy as np\nfrom pathlib import Path\n\n"
478484
elif binary_string:
479485
cls.score_code += (
480486
f'import codecs\n\nbinary_string = "{binary_string}"'
@@ -578,6 +584,7 @@ def _viya4_model_load(
578584
pickle_type: Optional[str] = None,
579585
mojo_model: Optional[bool] = False,
580586
binary_h2o_model: Optional[bool] = False,
587+
pytorch_model: Optional[bool] = False,
581588
tf_keras_model: Optional[bool] = False,
582589
tf_core_model: Optional[bool] = False,
583590
) -> str:
@@ -598,6 +605,9 @@ def _viya4_model_load(
598605
binary_h2o_model : boolean, optional
599606
Flag to indicate that the model is a H2O.ai binary model. The default value
600607
is None.
608+
pytorch_model : boolean, optional
609+
Flag to indicate that the model is a pytorch model. The default value
610+
is None.
601611
tf_keras_model : boolean, optional
602612
Flag to indicate that the model is a tensorflow keras model. The default
603613
value is False.
@@ -633,6 +643,8 @@ def _viya4_model_load(
633643
f"{'':8}model = h2o.load(str(Path(settings.pickle_path) / "
634644
f"{model_file_name}))\n\n"
635645
)
646+
elif pytorch_model:
647+
cls.score_code += ( "model = torch.load(path) ")
636648
elif tf_keras_model:
637649
cls.score_code += (
638650
f"model = tf.keras.models.load_model(Path(settings.pickle_path) / "

0 commit comments

Comments
 (0)