Skip to content

Commit 1c34cbb

Browse files
Graceful exit for aqua cli (#802)
2 parents 7fbf21d + 817f97e commit 1c34cbb

File tree

4 files changed

+76
-21
lines changed

4 files changed

+76
-21
lines changed

ads/aqua/__init__.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from ads import logger, set_auth
1010
from ads.aqua.utils import fetch_service_compartment
11-
from ads.config import NB_SESSION_OCID, OCI_RESOURCE_PRINCIPAL_VERSION
11+
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
1212

1313
ENV_VAR_LOG_LEVEL = "ADS_AQUA_LOG_LEVEL"
1414

@@ -36,9 +36,3 @@ def set_log_level(log_level: str):
3636
ODSC_MODEL_COMPARTMENT_OCID = (
3737
os.environ.get("ODSC_MODEL_COMPARTMENT_OCID") or fetch_service_compartment()
3838
)
39-
if not ODSC_MODEL_COMPARTMENT_OCID:
40-
if NB_SESSION_OCID:
41-
logger.error(
42-
f"Aqua is not available for this notebook session {NB_SESSION_OCID}."
43-
)
44-
exit()

ads/aqua/cli.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,19 @@
44
# Copyright (c) 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66
import os
7-
8-
from ads.aqua import ENV_VAR_LOG_LEVEL, set_log_level
7+
import sys
8+
9+
from ads.aqua import (
10+
ENV_VAR_LOG_LEVEL,
11+
set_log_level,
12+
ODSC_MODEL_COMPARTMENT_OCID,
13+
logger,
14+
)
915
from ads.aqua.deployment import AquaDeploymentApp
1016
from ads.aqua.evaluation import AquaEvaluationApp
1117
from ads.aqua.finetune import AquaFineTuningApp
1218
from ads.aqua.model import AquaModelApp
19+
from ads.config import NB_SESSION_OCID
1320

1421

1522
class AquaCommand:
@@ -41,3 +48,13 @@ def __init__(
4148
'WARNING', 'ERROR', and 'CRITICAL'.
4249
"""
4350
set_log_level(log_level)
51+
# gracefully exit if env var is not set
52+
if not ODSC_MODEL_COMPARTMENT_OCID:
53+
logger.error(
54+
"ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua."
55+
)
56+
if NB_SESSION_OCID:
57+
logger.error(
58+
f"Aqua is not available for the notebook session {NB_SESSION_OCID}."
59+
)
60+
sys.exit(1)

ads/aqua/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ def fetch_service_compartment() -> Union[str, None]:
582582
)
583583
except AquaFileNotFoundError:
584584
logger.error(
585-
f"ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua."
585+
f"Config file {config_file_name}/{CONTAINER_INDEX} to fetch service compartment OCID could not be found."
586586
)
587587
return
588588
compartment_mapping = config.get(COMPARTMENT_MAPPING_KEY)

tests/unitary/with_extras/aqua/test_cli.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,30 @@
44
# Copyright (c) 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import os
78
import logging
89
import subprocess
910
from unittest import TestCase
1011
from unittest.mock import patch
11-
12+
from importlib import reload
1213
from parameterized import parameterized
1314

15+
import ads.aqua
16+
import ads.config
1417
from ads.aqua.cli import AquaCommand
1518

1619

1720
class TestAquaCLI(TestCase):
1821
"""Tests the AQUA CLI."""
1922

20-
DEFAUL_AQUA_CLI_LOGGING_LEVEL = "ERROR"
23+
DEFAULT_AQUA_CLI_LOGGING_LEVEL = "ERROR"
2124
logger = logging.getLogger(__name__)
2225
logging.basicConfig(
2326
format="%(asctime)s %(module)s %(levelname)s: %(message)s",
2427
datefmt="%m/%d/%Y %I:%M:%S %p",
2528
level=logging.INFO,
2629
)
30+
SERVICE_COMPARTMENT_ID = "ocid1.compartment.oc1..<OCID>"
2731

2832
def test_entrypoint(self):
2933
"""Tests CLI entrypoint."""
@@ -33,15 +37,55 @@ def test_entrypoint(self):
3337

3438
@parameterized.expand(
3539
[
36-
("default", None, DEFAUL_AQUA_CLI_LOGGING_LEVEL),
40+
("default", None, DEFAULT_AQUA_CLI_LOGGING_LEVEL),
3741
("set logging level", "info", "info"),
3842
]
3943
)
40-
@patch("ads.aqua.cli.set_log_level")
41-
def test_aquacommand(self, name, arg, expected, mock_setting_log):
42-
"""Tests aqua command initailzation."""
43-
if arg:
44-
AquaCommand(arg)
45-
else:
46-
AquaCommand()
47-
mock_setting_log.assert_called_with(expected)
44+
def test_aquacommand(self, name, arg, expected):
45+
"""Tests aqua command initialization."""
46+
with patch.dict(
47+
os.environ,
48+
{"ODSC_MODEL_COMPARTMENT_OCID": TestAquaCLI.SERVICE_COMPARTMENT_ID},
49+
):
50+
reload(ads.config)
51+
reload(ads.aqua)
52+
reload(ads.aqua.cli)
53+
with patch("ads.aqua.cli.set_log_level") as mock_setting_log:
54+
if arg:
55+
AquaCommand(arg)
56+
else:
57+
AquaCommand()
58+
mock_setting_log.assert_called_with(expected)
59+
60+
@parameterized.expand(
61+
[
62+
("default", None),
63+
("using jupyter instance", "nb-session-ocid"),
64+
]
65+
)
66+
def test_aqua_command_without_compartment_env_var(self, name, session_ocid):
67+
"""Test whether exit is called when ODSC_MODEL_COMPARTMENT_OCID is not set. Also check if NB_SESSION_OCID is
68+
set then log the appropriate message."""
69+
70+
with patch("sys.exit") as mock_exit:
71+
env_dict = {"ODSC_MODEL_COMPARTMENT_OCID": ""}
72+
if session_ocid:
73+
env_dict.update({"NB_SESSION_OCID": session_ocid})
74+
with patch.dict(os.environ, env_dict):
75+
reload(ads.config)
76+
reload(ads.aqua)
77+
reload(ads.aqua.cli)
78+
with patch("ads.aqua.cli.set_log_level") as mock_setting_log:
79+
with patch("ads.aqua.logger.error") as mock_logger_error:
80+
AquaCommand()
81+
mock_setting_log.assert_called_with(
82+
TestAquaCLI.DEFAULT_AQUA_CLI_LOGGING_LEVEL
83+
)
84+
mock_logger_error.assert_any_call(
85+
"ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua."
86+
)
87+
if session_ocid:
88+
mock_logger_error.assert_any_call(
89+
f"Aqua is not available for the notebook session {session_ocid}."
90+
)
91+
mock_exit.assert_called_with(1)

0 commit comments

Comments
 (0)