Skip to content

Commit 6ce0a10

Browse files
committed
write_score_code param add.
1 parent 830140b commit 6ce0a10

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

src/sasctl/pzmm/write_score_code.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def write_score_code(
177177
pickle_type,
178178
mojo_model="mojo_model" in kwargs,
179179
binary_h2o_model="binary_h2o_model" in kwargs,
180+
pytorch_model="pytorch_model" in kwargs,
180181
tf_model="tf_keras_model" in kwargs or "tf_core_model" in kwargs,
181182
binary_string=binary_string,
182183
)
@@ -188,6 +189,7 @@ def write_score_code(
188189
model_file_name,
189190
pickle_type=pickle_type,
190191
mojo_model="mojo_model" in kwargs,
192+
pytorch_model="pytorch_model" in kwargs,
191193
binary_h2o_model="binary_h2o_model" in kwargs,
192194
)
193195
# As above, but for SAS Viya 4 models
@@ -197,6 +199,7 @@ def write_score_code(
197199
pickle_type=pickle_type,
198200
mojo_model="mojo_model" in kwargs,
199201
binary_h2o_model="binary_h2o_model" in kwargs,
202+
pytorch_model="pytorch_model" in kwargs,
200203
tf_keras_model="tf_keras_model" in kwargs,
201204
tf_core_model="tf_core_model" in kwargs,
202205
)
@@ -501,6 +504,7 @@ def _viya35_model_load(
501504
model_id: str,
502505
model_file_name: str,
503506
pickle_type: Optional[str] = None,
507+
pytorch_model: Optional[str] = None,
504508
mojo_model: Optional[bool] = False,
505509
binary_h2o_model: Optional[bool] = False,
506510
) -> str:
@@ -557,6 +561,13 @@ def _viya35_model_load(
557561
f"{'':8}model = h2o.load(str(Path(\"/models/resources/viya/"
558562
f'{model_id}/{model_file_name}")))'
559563
)
564+
elif pytorch_model:
565+
cls.score_code += ("model = torch.load(path) ")
566+
return (
567+
f"{'':8}model_path = Path(\"/models/resources/viya/{model_id}\")\n"
568+
f"{'':8}with open(model_path / \"model.pth\", \"rb\") as torch_model:\n"
569+
f"{'':12}model = torch.load(torch_model)"
570+
)
560571
else:
561572
cls.score_code += (
562573
f'model_path = Path("/models/resources/viya/{model_id}'
@@ -644,7 +655,12 @@ def _viya4_model_load(
644655
f"{model_file_name}))\n\n"
645656
)
646657
elif pytorch_model:
647-
cls.score_code += ( "model = torch.load(path) ")
658+
cls.score_code += ("model = torch.load(path) ")
659+
return (
660+
f"{'':8}model_path = Path(\"/models/resources/viya/{model_id}\")\n"
661+
f"{'':8}with open(model_path / \"model.pth\", \"rb\") as torch_model:\n"
662+
f"{'':12}model = torch.load(torch_model)"
663+
)
648664
elif tf_keras_model:
649665
cls.score_code += (
650666
f"model = tf.keras.models.load_model(Path(settings.pickle_path) / "

0 commit comments

Comments
 (0)