Skip to content

Commit d4c013c

Browse files
committed
Update score code generation for Pytorch to include proper paths and returns for inner score function usage
1 parent 6ce0a10 commit d4c013c

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

src/sasctl/pzmm/write_score_code.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -562,11 +562,13 @@ def _viya35_model_load(
562562
f'{model_id}/{model_file_name}")))'
563563
)
564564
elif pytorch_model:
565-
cls.score_code += ("model = torch.load(path) ")
565+
cls.score_code += (
566+
f"model = torch.load(\"/models/resources/viya/{model_id}/\" + "
567+
f"{model_file_name})\n\n"
568+
)
566569
return (
567-
f"{'':8}model_path = Path(\"/models/resources/viya/{model_id}\")\n"
568-
f"{'':8}with open(model_path / \"model.pth\", \"rb\") as torch_model:\n"
569-
f"{'':12}model = torch.load(torch_model)"
570+
f"{'':8}model = torch.load(\"/models/resources/viya/{model_id}/\" + "
571+
f"{model_file_name})\n\n"
570572
)
571573
else:
572574
cls.score_code += (
@@ -655,11 +657,13 @@ def _viya4_model_load(
655657
f"{model_file_name}))\n\n"
656658
)
657659
elif pytorch_model:
658-
cls.score_code += ("model = torch.load(path) ")
660+
cls.score_code += (
661+
f"model = torch.load(Path(settings.pickle_path) / "
662+
f"{model_file_name})\n\n"
663+
)
659664
return (
660-
f"{'':8}model_path = Path(\"/models/resources/viya/{model_id}\")\n"
661-
f"{'':8}with open(model_path / \"model.pth\", \"rb\") as torch_model:\n"
662-
f"{'':12}model = torch.load(torch_model)"
665+
f"{'':8}model = torch.load(Path(settings.pickle_path) / "
666+
f"{model_file_name})\n\n"
663667
)
664668
elif tf_keras_model:
665669
cls.score_code += (

0 commit comments

Comments
 (0)