@@ -405,6 +405,7 @@ def _write_imports(
405
405
mojo_model : Optional [bool ] = False ,
406
406
binary_h2o_model : Optional [bool ] = False ,
407
407
tf_model : Optional [bool ] = False ,
408
+ pytorch_model : Optional [bool ] = False ,
408
409
binary_string : Optional [str ] = None ,
409
410
) -> None :
410
411
"""
@@ -427,6 +428,9 @@ def _write_imports(
427
428
tf_model : bool, optional
428
429
Flag to indicate that the model is a tensorflow model. The default value
429
430
is None.
431
+ pytorch_model : bool, optional
432
+ Flag to indicate that the model is a pytorch model. The default value
433
+ is None.
430
434
binary_string : str, optional
431
435
A binary representation of the Python model object. The default value is
432
436
None.
@@ -475,6 +479,8 @@ def _write_imports(
475
479
import tensorflow as tf
476
480
477
481
"""
482
+ elif pytorch_model :
483
+ cls .score_code += "import math\n import torch\n import pandas as pd\n import numpy as np\n from pathlib import Path\n \n "
478
484
elif binary_string :
479
485
cls .score_code += (
480
486
f'import codecs\n \n binary_string = "{ binary_string } "'
@@ -578,6 +584,7 @@ def _viya4_model_load(
578
584
pickle_type : Optional [str ] = None ,
579
585
mojo_model : Optional [bool ] = False ,
580
586
binary_h2o_model : Optional [bool ] = False ,
587
+ pytorch_model : Optional [bool ] = False ,
581
588
tf_keras_model : Optional [bool ] = False ,
582
589
tf_core_model : Optional [bool ] = False ,
583
590
) -> str :
@@ -598,6 +605,9 @@ def _viya4_model_load(
598
605
binary_h2o_model : boolean, optional
599
606
Flag to indicate that the model is a H2O.ai binary model. The default value
600
607
is None.
608
+ pytorch_model : boolean, optional
609
+ Flag to indicate that the model is a pytorch model. The default value
610
+ is None.
601
611
tf_keras_model : boolean, optional
602
612
Flag to indicate that the model is a tensorflow keras model. The default
603
613
value is False.
@@ -633,6 +643,8 @@ def _viya4_model_load(
633
643
f"{ '' :8} model = h2o.load(str(Path(settings.pickle_path) / "
634
644
f"{ model_file_name } ))\n \n "
635
645
)
646
+ elif pytorch_model :
647
+ cls .score_code += ( "model = torch.load(path) " )
636
648
elif tf_keras_model :
637
649
cls .score_code += (
638
650
f"model = tf.keras.models.load_model(Path(settings.pickle_path) / "
0 commit comments