Skip to content

Commit 51403a2

Browse files
committed
adding auth
1 parent 10887d7 commit 51403a2

File tree

6 files changed

+66
-54
lines changed

6 files changed

+66
-54
lines changed

ads/opctl/backend/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Backend:
2121

2222
def __init__(self, config: Dict) -> None:
2323
self.config = config
24-
self.auth_type = config["execution"].get("auth")
24+
self.auth_type = config["execution"].get("auth", "api_key")
2525
self.profile = config["execution"].get("oci_profile", None)
2626
self.oci_config = config["execution"].get("oci_config", None)
2727

ads/opctl/backend/local.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def _run_with_conda_pack(
219219
)
220220
if os.path.exists(os.path.join(conda_pack_path, "spark-defaults.conf")):
221221
env_vars["SPARK_CONF_DIR"] = os.path.join(DEFAULT_IMAGE_CONDA_DIR, slug)
222+
logger.info(f"Running with conda pack in a container with command {command}")
222223
return self._activate_conda_env_and_run(
223224
image, slug, command, bind_volumes, env_vars
224225
)
@@ -701,7 +702,8 @@ def predict(self) -> None:
701702
)
702703

703704
_download_model(
704-
oci_auth=self.oci_auth,
705+
auth=self.auth_type,
706+
profile=self.profile,
705707
ocid=ocid,
706708
artifact_directory=artifact_directory,
707709
region=region,
@@ -724,32 +726,37 @@ def predict(self) -> None:
724726
# bind_volumes
725727
bind_volumes = {}
726728
SCRIPT = "script.py"
729+
dir_path = os.path.dirname(os.path.realpath(__file__))
727730
if not is_in_notebook_session():
728731
bind_volumes = {
729732
os.path.expanduser(
730733
os.path.dirname(self.config["execution"]["oci_config"])
731734
): {"bind": os.path.join(DEFAULT_IMAGE_HOME_DIR, ".oci")}
732735
}
733-
dir_path = os.path.dirname(os.path.realpath(__file__))
734-
736+
735737
self.config["execution"]["source_folder"] = os.path.abspath(
736738
os.path.join(dir_path, "..")
737739
)
738740
self.config["execution"]["entrypoint"] = SCRIPT
739741
bind_volumes[artifact_directory] = {"bind": DEFAULT_MODEL_DEPLOYMENT_FOLDER}
740742

741-
# payload
743+
# extra cmd
742744
data = self.config["execution"].get("payload")
745+
extra_cmd = f"--payload '{data}' " + f"--auth {self.auth_type} "
746+
if self.auth_type != "resource_principal":
747+
extra_cmd += f"--profile {self.profile}"
743748

744749
if is_in_notebook_session() or NO_CONTAINER:
745-
script_path = os.path.join(self.config['execution']['source_folder'], SCRIPT)
746-
run_command(cmd=f"python {script_path} " + f"{artifact_directory} "+ f"'{data}'", shell=True)
750+
# _run_with_conda_pack has code to handle notebook session case,
751+
# however, it activates the conda pack and then runs the script.
752+
# For the deployment, we just take the current conda env and run it.
753+
# Hence we just handle the notebook case directly here.
754+
script_path = os.path.join(os.path.join(dir_path, ".."), SCRIPT)
755+
cmd = f"python {script_path} " + f"--artifact-directory {artifact_directory} " + extra_cmd
756+
logger.info(f"Running in a notebook or NO_CONTAINER with command {cmd}")
757+
run_command(cmd=cmd, shell=True)
747758
else:
748-
extra_cmd = (
749-
DEFAULT_MODEL_DEPLOYMENT_FOLDER
750-
+ " "
751-
+ data
752-
)
759+
extra_cmd = f"--artifact-directory {DEFAULT_MODEL_DEPLOYMENT_FOLDER} "+ extra_cmd
753760
exit_code = self._run_with_conda_pack(
754761
bind_volumes, extra_cmd, install=True, conda_uri=conda_path
755762
)
@@ -758,7 +765,7 @@ def predict(self) -> None:
758765
f"`predict` did not complete successfully. Exit code: {exit_code}. "
759766
f"Run with the --debug argument to view container logs."
760767
)
761-
768+
762769
def _get_conda_info_from_custom_metadata(self, ocid):
763770
"""
764771
Get conda env info from custom metadata from model catalog.

ads/opctl/cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,11 @@ def init(debug: bool, **kwargs: Dict[str, Any]) -> None:
622622
type=click.Choice(AuthType.values()),
623623
default=None,
624624
)
625+
@click.option(
626+
"--oci-profile",
627+
help="oci profile",
628+
default=None,
629+
)
625630
@click.option("--debug", "-d", help="set debug mode", is_flag=True, default=False)
626631
def predict(**kwargs):
627632
"""

ads/opctl/model/cmds.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import os
22
import shutil
33

4-
from ads.common.auth import create_signer
4+
from ads.common.auth import AuthContext
55
from ads.model.datascience_model import DataScienceModel
66
from ads.opctl import logger
7-
from ads.opctl.constants import DEFAULT_MODEL_FOLDER
87
from ads.opctl.config.base import ConfigProcessor
98
from ads.opctl.config.merger import ConfigMerger
9+
from ads.opctl.constants import DEFAULT_MODEL_FOLDER
1010

1111

1212
def download_model(**kwargs):
@@ -15,12 +15,6 @@ def download_model(**kwargs):
1515

1616
auth_type = p.config["execution"].get("auth")
1717
profile = p.config["execution"].get("oci_profile", None)
18-
oci_config = p.config["execution"].get("oci_config", None)
19-
oci_auth = create_signer(
20-
auth_type,
21-
oci_config,
22-
profile,
23-
)
2418
model_folder = os.path.expanduser(
2519
p.config["execution"].get("model_save_folder", DEFAULT_MODEL_FOLDER)
2620
)
@@ -44,7 +38,8 @@ def download_model(**kwargs):
4438
bucket_uri=bucket_uri,
4539
timeout=timeout,
4640
force_overwrite=force_overwrite,
47-
oci_auth=oci_auth,
41+
auth=auth_type,
42+
profile=profile
4843
)
4944
else:
5045
logger.error(f"Model already exists. Set `force_overwrite=True` to overwrite.")
@@ -54,22 +49,22 @@ def download_model(**kwargs):
5449

5550

5651
def _download_model(
57-
ocid, artifact_directory, oci_auth, region, bucket_uri, timeout, force_overwrite
52+
ocid, artifact_directory, region, bucket_uri, timeout, force_overwrite, auth, profile
5853
):
5954
os.makedirs(artifact_directory, exist_ok=True)
6055

6156
try:
62-
dsc_model = DataScienceModel.from_id(ocid)
63-
dsc_model.download_artifact(
64-
target_dir=artifact_directory,
65-
force_overwrite=force_overwrite,
66-
overwrite_existing_artifact=True,
67-
remove_existing_artifact=True,
68-
auth=oci_auth,
69-
region=region,
70-
timeout=timeout,
71-
bucket_uri=bucket_uri,
72-
)
57+
with AuthContext(auth=auth, profile=profile):
58+
dsc_model = DataScienceModel.from_id(ocid)
59+
dsc_model.download_artifact(
60+
target_dir=artifact_directory,
61+
force_overwrite=force_overwrite,
62+
overwrite_existing_artifact=True,
63+
remove_existing_artifact=True,
64+
region=region,
65+
timeout=timeout,
66+
bucket_uri=bucket_uri,
67+
)
7368
except Exception as e:
7469
print(type(e))
7570
shutil.rmtree(artifact_directory, ignore_errors=True)

ads/opctl/script.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,37 @@
1+
import argparse
12
import json
2-
import sys
33

4+
from ads.common.auth import AuthContext
45
from ads.model.generic_model import GenericModel
56

67

7-
def verify(artifact_dir, data): # pragma: no cover
8+
def verify(artifact_dir, payload, auth, profile): # pragma: no cover
9+
kwargs = {"auth": auth}
10+
if profile != 'None':
11+
kwargs["profile"] = profile
12+
with AuthContext(**kwargs):
13+
model = GenericModel.from_model_artifact(
14+
uri=artifact_dir,
15+
artifact_dir=artifact_dir,
16+
force_overwrite=True,
17+
)
818

9-
model = GenericModel.from_model_artifact(
10-
uri=artifact_dir,
11-
artifact_dir=artifact_dir,
12-
force_overwrite=True,
13-
)
14-
15-
try:
16-
data = json.loads(data)
17-
except:
18-
pass
19-
print(model.verify(data, auto_serialize_data=False))
19+
try:
20+
payload = json.loads(payload)
21+
except:
22+
pass
23+
print(model.verify(payload, auto_serialize_data=False))
2024

2125

2226
def main(): # pragma: no cover
23-
args = sys.argv[1:]
27+
parser = argparse.ArgumentParser()
28+
parser.add_argument("--payload", type=str, required=True)
29+
parser.add_argument("--artifact-directory", type=str, required=True)
30+
parser.add_argument("--auth", type=str, required=True)
31+
parser.add_argument("--profile", type=str,required=False)
32+
args = parser.parse_args()
2433
verify(
25-
artifact_dir=args[0], data=args[1]
34+
artifact_dir=args.artifact_directory, payload=args.payload, auth=args.auth, profile=args.profile
2635
)
2736
return 0
2837

tests/unitary/with_extras/opctl/test_opctl_model.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from unittest.mock import ANY, call, patch
44
from ads.model.datascience_model import DataScienceModel
55
from unittest.mock import MagicMock, Mock
6-
from ads.opctl.model.cmds import create_signer
76
import os
87

98

@@ -36,12 +35,9 @@ def test_model__download_model_error(mock_from_id):
3635

3736

3837
@patch("ads.opctl.model.cmds._download_model")
39-
@patch("ads.opctl.model.cmds.create_signer")
40-
def test_download_model(mock_create_signer, mock__download_model):
38+
def test_download_model( mock__download_model):
4139
auth_mock = MagicMock()
42-
mock_create_signer.return_value = auth_mock
4340
download_model(ocid="fake_model_id")
44-
mock_create_signer.assert_called_once()
4541
mock__download_model.assert_called_once_with(
4642
ocid="fake_model_id",
4743
artifact_directory=os.path.expanduser("~/.ads_ops/models/fake_model_id"),

0 commit comments

Comments
 (0)