Skip to content

Commit 07476cf

Browse files
authored
Merge branch 'develop' into jobs_yaml
2 parents 5a8c311 + 230cde6 commit 07476cf

File tree

7 files changed

+180
-39
lines changed

7 files changed

+180
-39
lines changed

.github/workflows/publish-to-pypi.yml

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
name: "[DO NOT TRIGGER] Publish to PyPI"
2+
3+
# To run this workflow manually from the Actions tab
4+
on: workflow_dispatch
5+
6+
jobs:
7+
build-n-publish:
8+
name: Build and publish Python 🐍 distribution 📦 to PyPI
9+
runs-on: ubuntu-latest
10+
11+
steps:
12+
- uses: actions/checkout@v3
13+
- name: Set up Python
14+
uses: actions/setup-python@v4
15+
with:
16+
python-version: "3.x"
17+
- name: Build distribution 📦
18+
run: |
19+
pip install wheel
20+
make dist
21+
- name: Validate
22+
run: |
23+
pip install dist/*.whl
24+
python -c "import ads;"
25+
## To run publish to test PyPI secret with token needs to be added,
26+
## this one GH_ADS_TESTPYPI_TOKEN - removed after initial test.
27+
## Project name also needed to be updated in setup.py - setup(name="test_oracle_ads", ...),
28+
## regular name is occupied by former developer and can't be used for testing
29+
# - name: Publish distribution 📦 to Test PyPI
30+
# env:
31+
# TWINE_USERNAME: __token__
32+
# TWINE_PASSWORD: ${{ secrets.GH_ADS_TESTPYPI_TOKEN }}
33+
# run: |
34+
# pip install twine
35+
# twine upload -r testpypi dist/* -u $TWINE_USERNAME -p $TWINE_PASSWORD
36+
- name: Publish distribution 📦 to PyPI
37+
env:
38+
TWINE_USERNAME: __token__
39+
TWINE_PASSWORD: ${{ secrets.GH_ADS_PYPI_TOKEN }}
40+
run: |
41+
pip install twine
42+
twine upload dist/* -u $TWINE_USERNAME -p $TWINE_PASSWORD

ads/model/generic_model.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,7 @@ def prepare(
789789
ignore_pending_changes: bool = True,
790790
max_col_num: int = DATA_SCHEMA_MAX_COL_NUM,
791791
ignore_conda_error: bool = False,
792+
score_py_uri: str = None,
792793
**kwargs: Dict,
793794
) -> "GenericModel":
794795
"""Prepare and save the score.py, serialized model and runtime.yaml file.
@@ -841,6 +842,10 @@ def prepare(
841842
number of features(columns).
842843
ignore_conda_error: (bool, optional). Defaults to False.
843844
Parameter to ignore error when collecting conda information.
845+
score_py_uri: (str, optional). Defaults to None.
846+
The uri of the customized score.py, which can be local path or OCI object storage URI.
847+
When provided with this attribute, the `score.py` will not be auto generated, and the
848+
provided `score.py` will be added into artifact_dir.
844849
kwargs:
845850
impute_values: (dict, optional).
846851
The dictionary where the key is the column index(or names is accepted
@@ -1001,13 +1006,22 @@ def prepare(
10011006
jinja_template_filename = (
10021007
"score-pkl" if self._serialize else "score_generic"
10031008
)
1004-
self.model_artifact.prepare_score_py(
1005-
jinja_template_filename=jinja_template_filename,
1006-
model_file_name=self.model_file_name,
1007-
data_deserializer=self.model_input_serializer.name,
1008-
model_serializer=self.model_save_serializer.name,
1009-
**{**kwargs, **self._score_args},
1010-
)
1009+
1010+
if score_py_uri:
1011+
utils.copy_file(
1012+
uri_src=score_py_uri,
1013+
uri_dst=os.path.join(self.artifact_dir, "score.py"),
1014+
force_overwrite=force_overwrite,
1015+
auth=self.auth
1016+
)
1017+
else:
1018+
self.model_artifact.prepare_score_py(
1019+
jinja_template_filename=jinja_template_filename,
1020+
model_file_name=self.model_file_name,
1021+
data_deserializer=self.model_input_serializer.name,
1022+
model_serializer=self.model_save_serializer.name,
1023+
**{**kwargs, **self._score_args},
1024+
)
10111025

10121026
self._summary_status.update_status(
10131027
detail="Generated score.py", status=ModelState.DONE.value
@@ -2483,6 +2497,7 @@ def predict(
24832497
self,
24842498
data: Any = None,
24852499
auto_serialize_data: bool = False,
2500+
local: bool = False,
24862501
**kwargs,
24872502
) -> Dict[str, Any]:
24882503
"""Returns prediction of input data run against the model deployment endpoint.
@@ -2507,6 +2522,8 @@ def predict(
25072522
Whether to auto serialize input data. Defaults to `False` for GenericModel, and `True` for other frameworks.
25082523
`data` required to be json serializable if `auto_serialize_data=False`.
25092524
If `auto_serialize_data` set to True, data will be serialized before sending to model deployment endpoint.
2525+
local: bool.
2526+
Whether to invoke the prediction locally. Defaults to False.
25102527
kwargs:
25112528
content_type: str, used to indicate the media type of the resource.
25122529
image: PIL.Image Object or uri for the image.
@@ -2525,10 +2542,21 @@ def predict(
25252542
NotActiveDeploymentError
25262543
If model deployment process was not started or not finished yet.
25272544
ValueError
2528-
If `data` is empty or not JSON serializable.
2545+
If model is not deployed yet or the endpoint information is not available.
25292546
"""
2530-
if not self.model_deployment:
2531-
raise ValueError("Use `deploy()` method to start model deployment.")
2547+
if local:
2548+
return self.verify(
2549+
data=data, auto_serialize_data=auto_serialize_data, **kwargs
2550+
)
2551+
2552+
if not (self.model_deployment and self.model_deployment.url):
2553+
raise ValueError(
2554+
"Error invoking the remote endpoint as the model is not "
2555+
"deployed yet or the endpoint information is not available. "
2556+
"Use `deploy()` method to start model deployment. "
2557+
"If you intend to invoke inference using locally available "
2558+
"model artifact, set parameter `local=True`"
2559+
)
25322560

25332561
current_state = self.model_deployment.state.name.upper()
25342562
if current_state != ModelDeploymentState.ACTIVE.name:

dev-requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ xlrd>=1.2.0
1313
lxml
1414
fastparquet
1515
imbalanced-learn
16-
pyarrow
16+
pyarrow
17+
mysql-connector-python

docs/source/user_guide/model_registration/model_artifact.rst

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Auto generation of ``score.py`` with framework specific code for loading models
3030

3131
To accommodate other frameworks that are unknown to ADS, a template code for ``score.py`` is generated in the provided artifact directory location.
3232

33+
3334
Prepare the Model Artifact
3435
--------------------------
3536

@@ -98,8 +99,25 @@ ADS automatically captures:
9899
* ``UseCaseType`` in ``metadata_taxonomy`` cannot be automatically populated. One way to populate the use case is to pass ``use_case_type`` to the ``prepare`` method.
99100
* Model introspection is automatically triggered.
100101

101-
.. include:: _template/score.rst
102+
Prepare with custom ``score.py``
103+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
104+
105+
.. versionadded:: 2.8.4
102106

107+
You could provide the location of your own ``score.py`` by ``score_py_uri`` in :py:meth:`~ads.model.GenericModel.prepare`.
108+
The provided ``score.py`` will be added into the model artifact.
109+
110+
.. code-block:: python3
111+
112+
tf_model.prepare(
113+
inference_conda_env="generalml_p38_cpu_v1",
114+
use_case_type=UseCaseType.MULTINOMIAL_CLASSIFICATION,
115+
X_sample=trainx,
116+
y_sample=trainy,
117+
score_py_uri="/path/to/score.py"
118+
)
119+
120+
.. include:: _template/score.rst
103121

104122
Model Introspection
105123
-------------------

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@
6969
"nbformat",
7070
"inflection",
7171
],
72-
"mysql": ["mysql-connector-python"],
7372
"bds": ["ibis-framework[impala]", "hdfs[kerberos]", "sqlalchemy"],
7473
"spark": ["pyspark>=3.0.0"],
7574
"huggingface": ["transformers"],
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# THIS IS A CUSTOM SCORE.PY
2+
3+
model_name = "model.pkl"
4+
5+
6+
def load_model(model_file_name=model_name):
7+
return model_file_name
8+
9+
10+
def predict(data, model=load_model()):
11+
return {"prediction": "This is a custom score.py."}

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 68 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,30 @@
168168
"training_script": None,
169169
}
170170

171+
INFERENCE_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
172+
TRAINING_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
173+
DEFAULT_PYTHON_VERSION = "3.8"
174+
MODEL_FILE_NAME = "fake_model_name"
175+
FAKE_MD_URL = "http://<model-deployment-url>"
176+
177+
178+
def _prepare(model):
179+
model.prepare(
180+
inference_conda_env=INFERENCE_CONDA_ENV,
181+
inference_python_version=DEFAULT_PYTHON_VERSION,
182+
training_conda_env=TRAINING_CONDA_ENV,
183+
training_python_version=DEFAULT_PYTHON_VERSION,
184+
model_file_name=MODEL_FILE_NAME,
185+
force_overwrite=True,
186+
)
187+
171188

172189
class TestEstimator:
173190
def predict(self, x):
174191
return x**2
175192

176193

177194
class TestGenericModel:
178-
179195
iris = load_iris()
180196
X, y = iris.data, iris.target
181197
X_train, X_test, y_train, y_test = train_test_split(X, y)
@@ -298,16 +314,22 @@ def test_prepare_both_conda_env(self, mock_signer):
298314
)
299315

300316
@patch("ads.common.auth.default_signer")
301-
def test_verify_without_reload(self, mock_signer):
302-
"""Test verify input data without reload artifacts."""
317+
def test_prepare_with_custom_scorepy(self, mock_signer):
318+
"""Test prepare a trained model with custom score.py."""
303319
self.generic_model.prepare(
304-
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1",
305-
inference_python_version="3.6",
306-
training_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1",
307-
training_python_version="3.7",
320+
INFERENCE_CONDA_ENV,
308321
model_file_name="fake_model_name",
309-
force_overwrite=True,
322+
score_py_uri=f"{os.path.dirname(os.path.abspath(__file__))}/test_files/custom_score.py",
310323
)
324+
assert os.path.exists(os.path.join("fake_folder", "score.py"))
325+
326+
prediction = self.generic_model.verify(data="test")["prediction"]
327+
assert prediction == "This is a custom score.py."
328+
329+
@patch("ads.common.auth.default_signer")
330+
def test_verify_without_reload(self, mock_signer):
331+
"""Test verify input data without reload artifacts."""
332+
_prepare(self.generic_model)
311333
self.generic_model.verify(self.X_test.tolist())
312334

313335
with patch("ads.model.artifact.ModelArtifact.reload") as mock_reload:
@@ -317,20 +339,10 @@ def test_verify_without_reload(self, mock_signer):
317339
@patch("ads.common.auth.default_signer")
318340
def test_verify(self, mock_signer):
319341
"""Test verify input data"""
320-
self.generic_model.prepare(
321-
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1",
322-
inference_python_version="3.6",
323-
training_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1",
324-
training_python_version="3.7",
325-
model_file_name="fake_model_name",
326-
force_overwrite=True,
327-
)
342+
_prepare(self.generic_model)
328343
prediction_1 = self.generic_model.verify(self.X_test.tolist())
329344
assert isinstance(prediction_1, dict), "Failed to verify json payload."
330345

331-
prediction_2 = self.generic_model.verify(self.X_test.tolist())
332-
assert isinstance(prediction_2, dict), "Failed to verify input data."
333-
334346
def test_reload(self):
335347
"""test the reload."""
336348
pass
@@ -622,11 +634,31 @@ def test_deploy_with_default_display_name(self, mock_deploy):
622634
== random_name[:-9]
623635
)
624636

637+
@pytest.mark.parametrize("input_data", [(X_test.tolist())])
638+
@patch("ads.common.auth.default_signer")
639+
def test_predict_locally(self, mock_signer, input_data):
640+
_prepare(self.generic_model)
641+
test_result = self.generic_model.predict(data=input_data, local=True)
642+
expected_result = self.generic_model.estimator.predict(input_data).tolist()
643+
assert (
644+
test_result["prediction"] == expected_result
645+
), "Failed to verify input data."
646+
647+
with patch("ads.model.artifact.ModelArtifact.reload") as mock_reload:
648+
self.generic_model.predict(
649+
data=input_data, local=True, reload_artifacts=False
650+
)
651+
mock_reload.assert_not_called()
652+
625653
@patch.object(ModelDeployment, "predict")
626654
@patch("ads.common.auth.default_signer")
627655
@patch("ads.common.oci_client.OCIClientFactory")
656+
@patch(
657+
"ads.model.deployment.model_deployment.ModelDeployment.url",
658+
return_value=FAKE_MD_URL,
659+
)
628660
def test_predict_with_not_active_deployment_fail(
629-
self, mock_client, mock_signer, mock_predict
661+
self, mock_url, mock_client, mock_signer, mock_predict
630662
):
631663
"""Ensures predict model fails in case of model deployment is not in an active state."""
632664
with pytest.raises(NotActiveDeploymentError):
@@ -646,7 +678,11 @@ def test_predict_with_not_active_deployment_fail(
646678

647679
@patch("ads.common.auth.default_signer")
648680
@patch("ads.common.oci_client.OCIClientFactory")
649-
def test_predict_bytes_success(self, mock_client, mock_signer):
681+
@patch(
682+
"ads.model.deployment.model_deployment.ModelDeployment.url",
683+
return_value=FAKE_MD_URL,
684+
)
685+
def test_predict_bytes_success(self, mock_url, mock_client, mock_signer):
650686
"""Ensures predict model passes with bytes input."""
651687
with patch.object(
652688
ModelDeployment, "state", new_callable=PropertyMock
@@ -655,7 +691,7 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
655691
with patch.object(ModelDeployment, "predict") as mock_predict:
656692
mock_predict.return_value = {"result": "result"}
657693
self.generic_model.model_deployment = ModelDeployment(
658-
model_deployment_id="test"
694+
model_deployment_id="test",
659695
)
660696
# self.generic_model.model_deployment.current_state = ModelDeploymentState.ACTIVE
661697
self.generic_model._as_onnx = False
@@ -668,7 +704,11 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
668704

669705
@patch("ads.common.auth.default_signer")
670706
@patch("ads.common.oci_client.OCIClientFactory")
671-
def test_predict_success(self, mock_client, mock_signer):
707+
@patch(
708+
"ads.model.deployment.model_deployment.ModelDeployment.url",
709+
return_value=FAKE_MD_URL,
710+
)
711+
def test_predict_success(self, mock_url, mock_client, mock_signer):
672712
"""Ensures predict model passes with valid input parameters."""
673713
with patch.object(
674714
ModelDeployment, "state", new_callable=PropertyMock
@@ -785,7 +825,11 @@ def test_from_model_artifact(
785825

786826
@patch("ads.common.auth.default_signer")
787827
@patch("ads.common.oci_client.OCIClientFactory")
788-
def test_predict_success__serialize_input(self, mock_client, mock_signer):
828+
@patch(
829+
"ads.model.deployment.model_deployment.ModelDeployment.url",
830+
return_value=FAKE_MD_URL,
831+
)
832+
def test_predict_success__serialize_input(self, mock_url, mock_client, mock_signer):
789833
"""Ensures predict model passes with valid input parameters."""
790834

791835
df = pd.DataFrame([1, 2, 3])
@@ -795,7 +839,6 @@ def test_predict_success__serialize_input(self, mock_client, mock_signer):
795839
with patch.object(
796840
GenericModel, "get_data_serializer"
797841
) as mock_get_data_serializer:
798-
799842
mock_get_data_serializer.return_value.data = df.to_json()
800843
mock_state.return_value = ModelDeploymentState.ACTIVE
801844
with patch.object(ModelDeployment, "predict") as mock_predict:
@@ -1782,7 +1825,6 @@ def test_upload_artifact_fail(self):
17821825
def test_upload_artifact_success(self):
17831826
"""Tests uploading model artifacts to the provided `uri`."""
17841827
with tempfile.TemporaryDirectory() as tmp_dir:
1785-
17861828
# copy test artifacts to the temp folder
17871829
shutil.copytree(
17881830
os.path.join(self.curr_dir, "test_files/valid_model_artifacts"),

0 commit comments

Comments
 (0)