
Commit 0e20b0a

Merge commit '5487f6243528d0b31338484002837507642266c0' from develop
2 parents: e3f1d60 + 5487f62

21 files changed: +266 -147 lines

.github/workflows/run-unittests-default_setup.yml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ on:
       - develop
     paths:
       - "ads/**"
+      - "!ads/ads_version.json"
       - setup.py
       - "**requirements.txt"
      - .github/workflows/run-unittests.yml

.github/workflows/run-unittests.yml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ on:
       - develop
     paths:
       - "ads/**"
+      - "!ads/ads_version.json"
       - setup.py
       - "**requirements.txt"
      - .github/workflows/run-unittests.yml
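
Note: both workflow files gain the same negated pattern, so a commit that only bumps ads/ads_version.json no longer triggers the unit-test runs. In GitHub Actions path filters, pattern order matters: a matching "!" pattern that comes after a positive match excludes the path again. A rough Python sketch of that rule, using fnmatch as a stand-in for GitHub's own matcher (fnmatch's "*" crosses "/" boundaries, so "**" is handled only loosely here):

    from fnmatch import fnmatch

    def path_triggers(path, patterns):
        # Last matching pattern wins; "!" patterns exclude.
        triggered = False
        for pattern in patterns:
            negative = pattern.startswith("!")
            if fnmatch(path, pattern.lstrip("!")):
                triggered = not negative
        return triggered

    filters = ["ads/**", "!ads/ads_version.json", "setup.py", "**requirements.txt"]
    assert path_triggers("ads/opctl/cli.py", filters)
    assert not path_triggers("ads/ads_version.json", filters)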

ads/ads_version.json

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 {
-    "version": "2.8.3"
+    "version": "2.8.4"
 }

ads/dataflow/dataflow.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ class SPARK_VERSION(str):
 class DataFlow:
     @deprecated(
         "2.6.3",
-        details="Use ads.jobs.DataFlow class for creating DataFlow applications and runs. Check https://accelerated-data-science.readthedocs.io/en/latest/user_guide/apachespark/dataflow.html#create-run-data-flow-application-using-ads-python-sdk",
+        details="Use ads.jobs.DataFlow class for creating Data Flow applications and runs. Check https://accelerated-data-science.readthedocs.io/en/latest/user_guide/apachespark/dataflow.html#create-run-data-flow-application-using-ads-python-sdk",
     )
     def __init__(
         self,

ads/opctl/backend/base.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ class Backend:

     def __init__(self, config: Dict) -> None:
         self.config = config
-        self.auth_type = config["execution"].get("auth")
+        self.auth_type = config["execution"].get("auth", "api_key")
         self.profile = config["execution"].get("oci_profile", None)
         self.oci_config = config["execution"].get("oci_config", None)
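
Note: with the new default, an opctl config whose execution section omits auth now resolves to API-key auth instead of None. A minimal sketch of the effect (trimmed to the relevant lines; the sample config is made up):

    from typing import Dict

    class Backend:
        # Trimmed-down version of ads.opctl.backend.base.Backend.__init__.
        def __init__(self, config: Dict) -> None:
            self.config = config
            self.auth_type = config["execution"].get("auth", "api_key")
            self.profile = config["execution"].get("oci_profile", None)

    backend = Backend({"execution": {}})  # no "auth" key supplied
    print(backend.auth_type)              # "api_key" rather than None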

ads/opctl/backend/local.py

Lines changed: 48 additions & 51 deletions
@@ -52,6 +52,7 @@
 )
 from ads.pipeline.ads_pipeline import Pipeline, PipelineStep
 from ads.common.oci_client import OCIClientFactory
+from ads.config import NO_CONTAINER

 class CondaPackNotFound(Exception):  # pragma: no cover
     pass
@@ -218,6 +219,7 @@ def _run_with_conda_pack(
         )
         if os.path.exists(os.path.join(conda_pack_path, "spark-defaults.conf")):
             env_vars["SPARK_CONF_DIR"] = os.path.join(DEFAULT_IMAGE_CONDA_DIR, slug)
+        logger.info(f"Running with conda pack in a container with command {command}")
         return self._activate_conda_env_and_run(
             image, slug, command, bind_volumes, env_vars
         )
@@ -679,9 +681,11 @@ def predict(self) -> None:
         None
             Nothing.
         """
+
+        # model artifact in artifact directory
         artifact_directory = self.config["execution"].get("artifact_directory")
         ocid = self.config["execution"].get("ocid")
-        data = self.config["execution"].get("payload")
+
         model_folder = os.path.expanduser(
             self.config["execution"].get("model_save_folder", DEFAULT_MODEL_FOLDER)
         )
@@ -698,79 +702,72 @@ def predict(self) -> None:
         )

         _download_model(
-            oci_auth=self.oci_auth,
+            auth=self.auth_type,
+            profile=self.profile,
             ocid=ocid,
             artifact_directory=artifact_directory,
             region=region,
             bucket_uri=bucket_uri,
             timeout=timeout,
             force_overwrite=True,
         )
-        conda_slug, conda_path = None, None
-        if ocid:
+
+        # conda
+        conda_slug, conda_path = self.config["execution"].get("conda_slug"), self.config["execution"].get("conda_path")
+        if not conda_slug and not conda_path and ocid:
             conda_slug, conda_path = self._get_conda_info_from_custom_metadata(ocid)
-        if not conda_path:
-            if (
-                not os.path.exists(artifact_directory)
-                or len(os.listdir(artifact_directory)) == 0
-            ):
-                raise ValueError(
-                    f"`artifact_directory` {artifact_directory} does not exist or is empty."
-                )
+        if not conda_slug and not conda_path:
             conda_slug, conda_path = self._get_conda_info_from_runtime(
                 artifact_dir=artifact_directory
             )
-        if not conda_path or not conda_slug:
-            raise ValueError("Conda information cannot be detected.")
-        compartment_id = self.config["execution"].get(
-            "compartment_id", self.config["infrastructure"].get("compartment_id")
-        )
-        project_id = self.config["execution"].get(
-            "project_id", self.config["infrastructure"].get("project_id")
-        )
-        if not compartment_id or not project_id:
-            raise ValueError("`compartment_id` and `project_id` must be provided.")
-        extra_cmd = (
-            DEFAULT_MODEL_DEPLOYMENT_FOLDER
-            + " "
-            + data
-            + " "
-            + compartment_id
-            + " "
-            + project_id
-        )
+        if 'conda_slug' not in self.config["execution"]:
+            self.config["execution"]["conda_slug"] = conda_path.split("/")[-1] if conda_path else conda_slug
+
+        self.config["execution"]["image"] = ML_JOB_IMAGE
+
+        # bind_volumes
         bind_volumes = {}
+        SCRIPT = "script.py"
+        dir_path = os.path.dirname(os.path.realpath(__file__))
         if not is_in_notebook_session():
             bind_volumes = {
                 os.path.expanduser(
                     os.path.dirname(self.config["execution"]["oci_config"])
                 ): {"bind": os.path.join(DEFAULT_IMAGE_HOME_DIR, ".oci")}
             }
-            dir_path = os.path.dirname(os.path.realpath(__file__))
-            script = "script.py"
+
             self.config["execution"]["source_folder"] = os.path.abspath(
                 os.path.join(dir_path, "..")
             )
-            self.config["execution"]["entrypoint"] = script
+            self.config["execution"]["entrypoint"] = SCRIPT
             bind_volumes[artifact_directory] = {"bind": DEFAULT_MODEL_DEPLOYMENT_FOLDER}
-        if self.config["execution"].get("conda_slug", conda_slug):
-            self.config["execution"]["image"] = ML_JOB_IMAGE
-            if not self.config["execution"].get("conda_slug"):
-                self.config["execution"]["conda_slug"] = conda_slug
-                self.config["execution"]["slug"] = conda_slug
-            self.config["execution"]["conda_path"] = conda_path
-            exit_code = self._run_with_conda_pack(
-                bind_volumes, extra_cmd, install=True, conda_uri=conda_path
-            )
+
+        # extra cmd
+        data = self.config["execution"].get("payload")
+        extra_cmd = f"--payload '{data}' " + f"--auth {self.auth_type} "
+        if self.auth_type != "resource_principal":
+            extra_cmd += f"--profile {self.profile}"
+
+        if is_in_notebook_session() or NO_CONTAINER:
+            # _run_with_conda_pack has code to handle the notebook session case;
+            # however, it activates the conda pack and then runs the script.
+            # For the deployment we just take the current conda env and run it,
+            # so the notebook case is handled directly here.
+            script_path = os.path.join(os.path.join(dir_path, ".."), SCRIPT)
+            cmd = f"python {script_path} " + f"--artifact-directory {artifact_directory} " + extra_cmd
+            logger.info(f"Running in a notebook or NO_CONTAINER with command {cmd}")
+            run_command(cmd=cmd, shell=True)
         else:
-            raise ValueError("Either conda pack info or image should be specified.")
-
-        if exit_code != 0:
-            raise RuntimeError(
-                f"`predict` did not complete successfully. Exit code: {exit_code}. "
-                f"Run with the --debug argument to view container logs."
-            )
-
+            extra_cmd = f"--artifact-directory {DEFAULT_MODEL_DEPLOYMENT_FOLDER} " + extra_cmd
+            exit_code = self._run_with_conda_pack(
+                bind_volumes, extra_cmd, install=True, conda_uri=conda_path
+            )
+            if exit_code != 0:
+                raise RuntimeError(
+                    f"`predict` did not complete successfully. Exit code: {exit_code}. "
+                    f"Run with the --debug argument to view container logs."
+                )
+
     def _get_conda_info_from_custom_metadata(self, ocid):
         """
         Get conda env info from custom metadata from model catalog.
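
Note: the rewritten predict() resolves the conda environment in a fixed order: an explicit conda_slug/conda_path from the config, then the model's custom metadata (when an OCID is available), then the runtime.yaml inside the artifact directory. A minimal sketch of that precedence; the two lookup callables are hypothetical stand-ins for the class methods used above:

    def resolve_conda_info(execution, ocid, from_metadata, from_runtime):
        # Precedence: explicit config -> custom metadata -> runtime.yaml.
        conda_slug = execution.get("conda_slug")
        conda_path = execution.get("conda_path")
        if not conda_slug and not conda_path and ocid:
            conda_slug, conda_path = from_metadata(ocid)
        if not conda_slug and not conda_path:
            conda_slug, conda_path = from_runtime()
        # Derive the slug from the path when only a path was supplied.
        if "conda_slug" not in execution:
            execution["conda_slug"] = conda_path.split("/")[-1] if conda_path else conda_slug
        return conda_slug, conda_path

predict() also now branches on is_in_notebook_session() or NO_CONTAINER, running the verification script in the current environment instead of inside a container.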

ads/opctl/cli.py

Lines changed: 24 additions & 8 deletions
@@ -16,6 +16,7 @@
 import ads.opctl.model.cli
 import ads.opctl.spark.cli
 from ads.common import auth as authutil
+from ads.common.auth import AuthType
 from ads.opctl.cmds import activate as activate_cmd
 from ads.opctl.cmds import cancel as cancel_cmd
 from ads.opctl.cmds import configure as configure_cmd
@@ -29,12 +30,8 @@
 from ads.opctl.cmds import run_diagnostics as run_diagnostics_cmd
 from ads.opctl.cmds import watch as watch_cmd
 from ads.opctl.config.merger import ConfigMerger
-from ads.opctl.constants import (
-    BACKEND_NAME,
-    DEFAULT_MODEL_FOLDER,
-    RESOURCE_TYPE,
-    RUNTIME_TYPE,
-)
+from ads.opctl.constants import (BACKEND_NAME, DEFAULT_MODEL_FOLDER,
+                                 RESOURCE_TYPE, RUNTIME_TYPE)
 from ads.opctl.utils import build_image as build_image_cmd
 from ads.opctl.utils import publish_image as publish_image_cmd
 from ads.opctl.utils import suppress_traceback
@@ -544,6 +541,7 @@ def init(debug: bool, **kwargs: Dict[str, Any]) -> None:
     suppress_traceback(debug)(init_cmd)(**kwargs)


+@commands.command()
 @click.option(
     "--ocid",
     nargs=1,
@@ -597,7 +595,13 @@ def init(debug: bool, **kwargs: Dict[str, Any]) -> None:
     "--conda-slug",
     nargs=1,
     required=False,
-    help="The conda env used to load the model and conduct the prediction. This is only used when model id is passed to `ocid` and a local predict is conducted. It should match the inference conda env specified in the runtime.yaml file which is the conda pack being used when conducting real model deployment.",
+    help="The conda slug used to load the model and conduct the prediction. This is only used when model id is passed to `ocid` and a local predict is conducted. It should match the inference conda env specified in the runtime.yaml file which is the conda pack being used when conducting real model deployment.",
+)
+@click.option(
+    "--conda-path",
+    nargs=1,
+    required=False,
+    help="The conda path used to load the model and conduct the prediction. This is only used when model id is passed to `ocid` and a local predict is conducted. It should match the inference conda env specified in the runtime.yaml file which is the conda pack being used when conducting real model deployment.",
 )
 @click.option(
     "--model-version",
@@ -611,10 +615,22 @@ def init(debug: bool, **kwargs: Dict[str, Any]) -> None:
     required=False,
     help="When the `inference_server='triton'`, the name of the model to invoke. This can only be used when model deployment id is passed in. For the other cases, it will be ignored.",
 )
+@click.option(
+    "--auth",
+    "-a",
+    help="authentication method",
+    type=click.Choice(AuthType.values()),
+    default=None,
+)
+@click.option(
+    "--oci-profile",
+    help="oci profile",
+    default=None,
+)
 @click.option("--debug", "-d", help="set debug mode", is_flag=True, default=False)
 def predict(**kwargs):
     """
-    Deactivates a data science service.
+    Make a prediction using the model with the payload.
     """
     suppress_traceback(kwargs["debug"])(predict_cmd)(**kwargs)
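
Note: a quick way to confirm the new options are wired up, assuming the ads package is installed locally; invoking --help never touches OCI resources:

    from click.testing import CliRunner
    from ads.opctl.cli import predict

    runner = CliRunner()
    # Should now list --conda-path, --auth/-a, and --oci-profile.
    result = runner.invoke(predict, ["--help"])
    print(result.output)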

ads/opctl/model/cmds.py

Lines changed: 19 additions & 23 deletions
@@ -1,12 +1,12 @@
 import os
 import shutil

-from ads.common.auth import create_signer
+from ads.common.auth import AuthContext
 from ads.model.datascience_model import DataScienceModel
 from ads.opctl import logger
-from ads.opctl.constants import DEFAULT_MODEL_FOLDER
 from ads.opctl.config.base import ConfigProcessor
 from ads.opctl.config.merger import ConfigMerger
+from ads.opctl.constants import DEFAULT_MODEL_FOLDER


 def download_model(**kwargs):
@@ -15,12 +15,6 @@ def download_model(**kwargs):

     auth_type = p.config["execution"].get("auth")
     profile = p.config["execution"].get("oci_profile", None)
-    oci_config = p.config["execution"].get("oci_config", None)
-    oci_auth = create_signer(
-        auth_type,
-        oci_config,
-        profile,
-    )
     model_folder = os.path.expanduser(
         p.config["execution"].get("model_save_folder", DEFAULT_MODEL_FOLDER)
     )
@@ -44,7 +38,8 @@ def download_model(**kwargs):
             bucket_uri=bucket_uri,
             timeout=timeout,
             force_overwrite=force_overwrite,
-            oci_auth=oci_auth,
+            auth=auth_type,
+            profile=profile,
         )
     else:
         logger.error(f"Model already exists. Set `force_overwrite=True` to overwrite.")
@@ -54,23 +49,24 @@ def download_model(**kwargs):


 def _download_model(
-    ocid, artifact_directory, oci_auth, region, bucket_uri, timeout, force_overwrite
+    ocid, artifact_directory, region, bucket_uri, timeout, force_overwrite, auth, profile=None
 ):
     os.makedirs(artifact_directory, exist_ok=True)
-    os.chmod(artifact_directory, 777)
-
+    kwargs = {"auth": auth}
+    if profile:
+        kwargs["profile"] = profile
     try:
-        dsc_model = DataScienceModel.from_id(ocid)
-        dsc_model.download_artifact(
-            target_dir=artifact_directory,
-            force_overwrite=force_overwrite,
-            overwrite_existing_artifact=True,
-            remove_existing_artifact=True,
-            auth=oci_auth,
-            region=region,
-            timeout=timeout,
-            bucket_uri=bucket_uri,
-        )
+        with AuthContext(**kwargs):
+            dsc_model = DataScienceModel.from_id(ocid)
+            dsc_model.download_artifact(
+                target_dir=artifact_directory,
+                force_overwrite=force_overwrite,
+                overwrite_existing_artifact=True,
+                remove_existing_artifact=True,
+                region=region,
+                timeout=timeout,
+                bucket_uri=bucket_uri,
+            )
 except Exception as e:
     print(type(e))
     shutil.rmtree(artifact_directory, ignore_errors=True)
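
Note: authentication now rides on the AuthContext context manager rather than a pre-built signer passed around explicitly. A minimal sketch of the pattern, assuming AuthContext accepts the auth and profile keywords the diff passes in; the values shown are placeholders:

    from ads.common.auth import AuthContext

    kwargs = {"auth": "api_key"}
    profile = "DEFAULT"  # placeholder
    if profile:
        kwargs["profile"] = profile

    # ADS calls made inside the block (e.g. DataScienceModel.from_id)
    # pick up this auth configuration; the previous global auth is
    # restored on exit.
    with AuthContext(**kwargs):
        pass  # model-catalog calls go here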

ads/opctl/script.py

Lines changed: 16 additions & 10 deletions
@@ -1,31 +1,37 @@
+import argparse
 import json
-import sys
-import tempfile

+from ads.common.auth import AuthContext
 from ads.model.generic_model import GenericModel


-def verify(artifact_dir, data, compartment_id, project_id):  # pragma: no cover
-    with tempfile.TemporaryDirectory() as td:
+def verify(artifact_dir, payload, auth, profile):  # pragma: no cover
+    kwargs = {"auth": auth}
+    if profile != 'None':
+        kwargs["profile"] = profile
+    with AuthContext(**kwargs):
         model = GenericModel.from_model_artifact(
             uri=artifact_dir,
             artifact_dir=artifact_dir,
             force_overwrite=True,
-            compartment_id=compartment_id,
-            project_id=project_id,
         )

     try:
-        data = json.loads(data)
+        payload = json.loads(payload)
     except:
         pass
-    print(model.verify(data, auto_serialize_data=False))
+    print(model.verify(payload, auto_serialize_data=False))


 def main():  # pragma: no cover
-    args = sys.argv[1:]
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--payload", type=str, required=True)
+    parser.add_argument("--artifact-directory", type=str, required=True)
+    parser.add_argument("--auth", type=str, required=True)
+    parser.add_argument("--profile", type=str, required=False)
+    args = parser.parse_args()
     verify(
-        artifact_dir=args[0], data=args[1], compartment_id=args[2], project_id=args[3]
+        artifact_dir=args.artifact_directory, payload=args.payload, auth=args.auth, profile=args.profile
     )
     return 0
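
Note: for reference, a sketch of how the new flag-based interface parses the kind of command line that LocalBackend.predict assembles; the values are placeholders:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--payload", type=str, required=True)
    parser.add_argument("--artifact-directory", type=str, required=True)
    parser.add_argument("--auth", type=str, required=True)
    parser.add_argument("--profile", type=str, required=False)

    # Equivalent to:
    #   python script.py --payload '{"data": [1]}' \
    #       --artifact-directory /opt/ds/model/deployed_model --auth api_key
    args = parser.parse_args([
        "--payload", '{"data": [1]}',
        "--artifact-directory", "/opt/ds/model/deployed_model",
        "--auth", "api_key",
    ])
    print(args.payload, args.auth, args.profile)  # profile defaults to None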
