@@ -177,6 +177,7 @@ def write_score_code(
177
177
pickle_type ,
178
178
mojo_model = "mojo_model" in kwargs ,
179
179
binary_h2o_model = "binary_h2o_model" in kwargs ,
180
+ pytorch_model = "pytorch_model" in kwargs ,
180
181
tf_model = "tf_keras_model" in kwargs or "tf_core_model" in kwargs ,
181
182
binary_string = binary_string ,
182
183
)
@@ -188,6 +189,7 @@ def write_score_code(
188
189
model_file_name ,
189
190
pickle_type = pickle_type ,
190
191
mojo_model = "mojo_model" in kwargs ,
192
+ pytorch_model = "pytorch_model" in kwargs ,
191
193
binary_h2o_model = "binary_h2o_model" in kwargs ,
192
194
)
193
195
# As above, but for SAS Viya 4 models
@@ -197,6 +199,7 @@ def write_score_code(
197
199
pickle_type = pickle_type ,
198
200
mojo_model = "mojo_model" in kwargs ,
199
201
binary_h2o_model = "binary_h2o_model" in kwargs ,
202
+ pytorch_model = "pytorch_model" in kwargs ,
200
203
tf_keras_model = "tf_keras_model" in kwargs ,
201
204
tf_core_model = "tf_core_model" in kwargs ,
202
205
)
@@ -501,6 +504,7 @@ def _viya35_model_load(
501
504
model_id : str ,
502
505
model_file_name : str ,
503
506
pickle_type : Optional [str ] = None ,
507
+ pytorch_model : Optional [str ] = None ,
504
508
mojo_model : Optional [bool ] = False ,
505
509
binary_h2o_model : Optional [bool ] = False ,
506
510
) -> str :
@@ -557,6 +561,13 @@ def _viya35_model_load(
557
561
f"{ '' :8} model = h2o.load(str(Path(\" /models/resources/viya/"
558
562
f'{ model_id } /{ model_file_name } ")))'
559
563
)
564
+ elif pytorch_model :
565
+ cls .score_code += ("model = torch.load(path) " )
566
+ return (
567
+ f"{ '' :8} model_path = Path(\" /models/resources/viya/{ model_id } \" )\n "
568
+ f"{ '' :8} with open(model_path / \" model.pth\" , \" rb\" ) as torch_model:\n "
569
+ f"{ '' :12} model = torch.load(torch_model)"
570
+ )
560
571
else :
561
572
cls .score_code += (
562
573
f'model_path = Path("/models/resources/viya/{ model_id } '
@@ -644,7 +655,12 @@ def _viya4_model_load(
644
655
f"{ model_file_name } ))\n \n "
645
656
)
646
657
elif pytorch_model :
647
- cls .score_code += ( "model = torch.load(path) " )
658
+ cls .score_code += ("model = torch.load(path) " )
659
+ return (
660
+ f"{ '' :8} model_path = Path(\" /models/resources/viya/{ model_id } \" )\n "
661
+ f"{ '' :8} with open(model_path / \" model.pth\" , \" rb\" ) as torch_model:\n "
662
+ f"{ '' :12} model = torch.load(torch_model)"
663
+ )
648
664
elif tf_keras_model :
649
665
cls .score_code += (
650
666
f"model = tf.keras.models.load_model(Path(settings.pickle_path) / "
0 commit comments