168
168
"training_script" : None ,
169
169
}
170
170
171
+ INFERENCE_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
172
+ TRAINING_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
173
+ DEFAULT_PYTHON_VERSION = "3.8"
174
+ MODEL_FILE_NAME = "fake_model_name"
175
+ FAKE_MD_URL = "http://<model-deployment-url>"
176
+
177
+
178
+ def _prepare (model ):
179
+ model .prepare (
180
+ inference_conda_env = INFERENCE_CONDA_ENV ,
181
+ inference_python_version = DEFAULT_PYTHON_VERSION ,
182
+ training_conda_env = TRAINING_CONDA_ENV ,
183
+ training_python_version = DEFAULT_PYTHON_VERSION ,
184
+ model_file_name = MODEL_FILE_NAME ,
185
+ force_overwrite = True ,
186
+ )
187
+
171
188
172
189
class TestEstimator :
173
190
def predict (self , x ):
174
191
return x ** 2
175
192
176
193
177
194
class TestGenericModel :
178
-
179
195
iris = load_iris ()
180
196
X , y = iris .data , iris .target
181
197
X_train , X_test , y_train , y_test = train_test_split (X , y )
@@ -298,16 +314,22 @@ def test_prepare_both_conda_env(self, mock_signer):
298
314
)
299
315
300
316
@patch ("ads.common.auth.default_signer" )
301
- def test_verify_without_reload (self , mock_signer ):
302
- """Test verify input data without reload artifacts ."""
317
+ def test_prepare_with_custom_scorepy (self , mock_signer ):
318
+ """Test prepare a trained model with custom score.py ."""
303
319
self .generic_model .prepare (
304
- inference_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1" ,
305
- inference_python_version = "3.6" ,
306
- training_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1" ,
307
- training_python_version = "3.7" ,
320
+ INFERENCE_CONDA_ENV ,
308
321
model_file_name = "fake_model_name" ,
309
- force_overwrite = True ,
322
+ score_py_uri = f" { os . path . dirname ( os . path . abspath ( __file__ )) } /test_files/custom_score.py" ,
310
323
)
324
+ assert os .path .exists (os .path .join ("fake_folder" , "score.py" ))
325
+
326
+ prediction = self .generic_model .verify (data = "test" )["prediction" ]
327
+ assert prediction == "This is a custom score.py."
328
+
329
+ @patch ("ads.common.auth.default_signer" )
330
+ def test_verify_without_reload (self , mock_signer ):
331
+ """Test verify input data without reload artifacts."""
332
+ _prepare (self .generic_model )
311
333
self .generic_model .verify (self .X_test .tolist ())
312
334
313
335
with patch ("ads.model.artifact.ModelArtifact.reload" ) as mock_reload :
@@ -317,20 +339,10 @@ def test_verify_without_reload(self, mock_signer):
317
339
@patch ("ads.common.auth.default_signer" )
318
340
def test_verify (self , mock_signer ):
319
341
"""Test verify input data"""
320
- self .generic_model .prepare (
321
- inference_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1" ,
322
- inference_python_version = "3.6" ,
323
- training_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1" ,
324
- training_python_version = "3.7" ,
325
- model_file_name = "fake_model_name" ,
326
- force_overwrite = True ,
327
- )
342
+ _prepare (self .generic_model )
328
343
prediction_1 = self .generic_model .verify (self .X_test .tolist ())
329
344
assert isinstance (prediction_1 , dict ), "Failed to verify json payload."
330
345
331
- prediction_2 = self .generic_model .verify (self .X_test .tolist ())
332
- assert isinstance (prediction_2 , dict ), "Failed to verify input data."
333
-
334
346
def test_reload (self ):
335
347
"""test the reload."""
336
348
pass
@@ -622,11 +634,31 @@ def test_deploy_with_default_display_name(self, mock_deploy):
622
634
== random_name [:- 9 ]
623
635
)
624
636
637
+ @pytest .mark .parametrize ("input_data" , [(X_test .tolist ())])
638
+ @patch ("ads.common.auth.default_signer" )
639
+ def test_predict_locally (self , mock_signer , input_data ):
640
+ _prepare (self .generic_model )
641
+ test_result = self .generic_model .predict (data = input_data , local = True )
642
+ expected_result = self .generic_model .estimator .predict (input_data ).tolist ()
643
+ assert (
644
+ test_result ["prediction" ] == expected_result
645
+ ), "Failed to verify input data."
646
+
647
+ with patch ("ads.model.artifact.ModelArtifact.reload" ) as mock_reload :
648
+ self .generic_model .predict (
649
+ data = input_data , local = True , reload_artifacts = False
650
+ )
651
+ mock_reload .assert_not_called ()
652
+
625
653
@patch .object (ModelDeployment , "predict" )
626
654
@patch ("ads.common.auth.default_signer" )
627
655
@patch ("ads.common.oci_client.OCIClientFactory" )
656
+ @patch (
657
+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
658
+ return_value = FAKE_MD_URL ,
659
+ )
628
660
def test_predict_with_not_active_deployment_fail (
629
- self , mock_client , mock_signer , mock_predict
661
+ self , mock_url , mock_client , mock_signer , mock_predict
630
662
):
631
663
"""Ensures predict model fails in case of model deployment is not in an active state."""
632
664
with pytest .raises (NotActiveDeploymentError ):
@@ -646,7 +678,11 @@ def test_predict_with_not_active_deployment_fail(
646
678
647
679
@patch ("ads.common.auth.default_signer" )
648
680
@patch ("ads.common.oci_client.OCIClientFactory" )
649
- def test_predict_bytes_success (self , mock_client , mock_signer ):
681
+ @patch (
682
+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
683
+ return_value = FAKE_MD_URL ,
684
+ )
685
+ def test_predict_bytes_success (self , mock_url , mock_client , mock_signer ):
650
686
"""Ensures predict model passes with bytes input."""
651
687
with patch .object (
652
688
ModelDeployment , "state" , new_callable = PropertyMock
@@ -655,7 +691,7 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
655
691
with patch .object (ModelDeployment , "predict" ) as mock_predict :
656
692
mock_predict .return_value = {"result" : "result" }
657
693
self .generic_model .model_deployment = ModelDeployment (
658
- model_deployment_id = "test"
694
+ model_deployment_id = "test" ,
659
695
)
660
696
# self.generic_model.model_deployment.current_state = ModelDeploymentState.ACTIVE
661
697
self .generic_model ._as_onnx = False
@@ -668,7 +704,11 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
668
704
669
705
@patch ("ads.common.auth.default_signer" )
670
706
@patch ("ads.common.oci_client.OCIClientFactory" )
671
- def test_predict_success (self , mock_client , mock_signer ):
707
+ @patch (
708
+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
709
+ return_value = FAKE_MD_URL ,
710
+ )
711
+ def test_predict_success (self , mock_url , mock_client , mock_signer ):
672
712
"""Ensures predict model passes with valid input parameters."""
673
713
with patch .object (
674
714
ModelDeployment , "state" , new_callable = PropertyMock
@@ -785,7 +825,11 @@ def test_from_model_artifact(
785
825
786
826
@patch ("ads.common.auth.default_signer" )
787
827
@patch ("ads.common.oci_client.OCIClientFactory" )
788
- def test_predict_success__serialize_input (self , mock_client , mock_signer ):
828
+ @patch (
829
+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
830
+ return_value = FAKE_MD_URL ,
831
+ )
832
+ def test_predict_success__serialize_input (self , mock_url , mock_client , mock_signer ):
789
833
"""Ensures predict model passes with valid input parameters."""
790
834
791
835
df = pd .DataFrame ([1 , 2 , 3 ])
@@ -795,7 +839,6 @@ def test_predict_success__serialize_input(self, mock_client, mock_signer):
795
839
with patch .object (
796
840
GenericModel , "get_data_serializer"
797
841
) as mock_get_data_serializer :
798
-
799
842
mock_get_data_serializer .return_value .data = df .to_json ()
800
843
mock_state .return_value = ModelDeploymentState .ACTIVE
801
844
with patch .object (ModelDeployment , "predict" ) as mock_predict :
@@ -1782,7 +1825,6 @@ def test_upload_artifact_fail(self):
1782
1825
def test_upload_artifact_success (self ):
1783
1826
"""Tests uploading model artifacts to the provided `uri`."""
1784
1827
with tempfile .TemporaryDirectory () as tmp_dir :
1785
-
1786
1828
# copy test artifacts to the temp folder
1787
1829
shutil .copytree (
1788
1830
os .path .join (self .curr_dir , "test_files/valid_model_artifacts" ),
0 commit comments