Skip to content

Commit de53f9e

Browse files
Merge branch 'feature/aquav1.0.1' into update_telemetry_for_model_shapes
2 parents c1da982 + 1c34cbb commit de53f9e

File tree

7 files changed

+103
-61
lines changed

7 files changed

+103
-61
lines changed

ads/aqua/__init__.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66

7-
import logging
87
import os
9-
import sys
108

11-
from ads import set_auth
9+
from ads import logger, set_auth
1210
from ads.aqua.utils import fetch_service_compartment
13-
from ads.config import NB_SESSION_OCID, OCI_RESOURCE_PRINCIPAL_VERSION
11+
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
1412

1513
ENV_VAR_LOG_LEVEL = "ADS_AQUA_LOG_LEVEL"
1614

@@ -21,25 +19,7 @@ def get_logger_level():
2119
return level
2220

2321

24-
def configure_aqua_logger():
25-
"""Configures the AQUA logger."""
26-
log_level = get_logger_level()
27-
logger = logging.getLogger(__name__)
28-
logger.setLevel(log_level)
29-
30-
handler = logging.StreamHandler(sys.stdout)
31-
formatter = logging.Formatter(
32-
"%(asctime)s - %(name)s.%(module)s - %(levelname)s - %(message)s"
33-
)
34-
handler.setFormatter(formatter)
35-
handler.setLevel(log_level)
36-
37-
logger.addHandler(handler)
38-
logger.propagate = False
39-
return logger
40-
41-
42-
logger = configure_aqua_logger()
22+
logger.setLevel(get_logger_level())
4323

4424

4525
def set_log_level(log_level: str):
@@ -56,9 +36,3 @@ def set_log_level(log_level: str):
5636
ODSC_MODEL_COMPARTMENT_OCID = (
5737
os.environ.get("ODSC_MODEL_COMPARTMENT_OCID") or fetch_service_compartment()
5838
)
59-
if not ODSC_MODEL_COMPARTMENT_OCID:
60-
if NB_SESSION_OCID:
61-
logger.error(
62-
f"Aqua is not available for this notebook session {NB_SESSION_OCID}."
63-
)
64-
exit()

ads/aqua/cli.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,19 @@
44
# Copyright (c) 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66
import os
7-
8-
from ads.aqua import ENV_VAR_LOG_LEVEL, set_log_level
7+
import sys
8+
9+
from ads.aqua import (
10+
ENV_VAR_LOG_LEVEL,
11+
set_log_level,
12+
ODSC_MODEL_COMPARTMENT_OCID,
13+
logger,
14+
)
915
from ads.aqua.deployment import AquaDeploymentApp
1016
from ads.aqua.evaluation import AquaEvaluationApp
1117
from ads.aqua.finetune import AquaFineTuningApp
1218
from ads.aqua.model import AquaModelApp
19+
from ads.config import NB_SESSION_OCID
1320

1421

1522
class AquaCommand:
@@ -41,3 +48,13 @@ def __init__(
4148
'WARNING', 'ERROR', and 'CRITICAL'.
4249
"""
4350
set_log_level(log_level)
51+
# gracefully exit if env var is not set
52+
if not ODSC_MODEL_COMPARTMENT_OCID:
53+
logger.error(
54+
"ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua."
55+
)
56+
if NB_SESSION_OCID:
57+
logger.error(
58+
f"Aqua is not available for the notebook session {NB_SESSION_OCID}."
59+
)
60+
sys.exit(1)

ads/aqua/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ def fetch_service_compartment() -> Union[str, None]:
582582
)
583583
except AquaFileNotFoundError:
584584
logger.error(
585-
f"ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua."
585+
f"Config file {config_file_name}/{CONTAINER_INDEX} to fetch service compartment OCID could not be found."
586586
)
587587
return
588588
compartment_mapping = config.get(COMPARTMENT_MAPPING_KEY)

ads/cli.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import sys
99

1010
import fire
11+
from dataclasses import is_dataclass
1112
from ads.common import logger
1213

1314
try:
@@ -70,11 +71,28 @@ def _SeparateFlagArgs(args):
7071
fire.core.parser.SeparateFlagArgs = _SeparateFlagArgs
7172

7273

74+
def serialize(data):
75+
"""Serialize dataclass objects or lists of dataclass objects.
76+
Parameters:
77+
data: A dataclass object or a list of dataclass objects.
78+
Returns:
79+
None
80+
Prints:
81+
The string representation of each dataclass object.
82+
"""
83+
if isinstance(data, list):
84+
[print(str(item)) for item in data]
85+
else:
86+
print(str(data))
87+
88+
7389
def cli():
7490
if len(sys.argv) > 1 and sys.argv[1] == "aqua":
7591
from ads.aqua.cli import AquaCommand
7692

77-
fire.Fire(AquaCommand, command=sys.argv[2:], name="ads aqua")
93+
fire.Fire(
94+
AquaCommand, command=sys.argv[2:], name="ads aqua", serialize=serialize
95+
)
7896
else:
7997
click_cli()
8098

ads/common/serializer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,10 @@ def to_json(
195195
`None` in case when `uri` provided.
196196
"""
197197
json_string = json.dumps(
198-
self.to_dict(**kwargs), cls=encoder, default=default or self.serialize
198+
self.to_dict(**kwargs),
199+
cls=encoder,
200+
default=default or self.serialize,
201+
indent=4,
199202
)
200203
if uri:
201204
self._write_to_file(s=json_string, uri=uri, **kwargs)
@@ -463,9 +466,7 @@ def from_dict(
463466
"These fields will be ignored."
464467
)
465468

466-
obj = cls(
467-
**{key: obj_dict.get(key) for key in allowed_fields}
468-
)
469+
obj = cls(**{key: obj_dict.get(key) for key in allowed_fields})
469470

470471
for key, value in obj_dict.items():
471472
if (

tests/unitary/with_extras/aqua/test_cli.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,30 @@
44
# Copyright (c) 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import os
78
import logging
89
import subprocess
910
from unittest import TestCase
1011
from unittest.mock import patch
11-
12+
from importlib import reload
1213
from parameterized import parameterized
1314

15+
import ads.aqua
16+
import ads.config
1417
from ads.aqua.cli import AquaCommand
1518

1619

1720
class TestAquaCLI(TestCase):
1821
"""Tests the AQUA CLI."""
1922

20-
DEFAUL_AQUA_CLI_LOGGING_LEVEL = "ERROR"
23+
DEFAULT_AQUA_CLI_LOGGING_LEVEL = "ERROR"
2124
logger = logging.getLogger(__name__)
2225
logging.basicConfig(
2326
format="%(asctime)s %(module)s %(levelname)s: %(message)s",
2427
datefmt="%m/%d/%Y %I:%M:%S %p",
2528
level=logging.INFO,
2629
)
30+
SERVICE_COMPARTMENT_ID = "ocid1.compartment.oc1..<OCID>"
2731

2832
def test_entrypoint(self):
2933
"""Tests CLI entrypoint."""
@@ -33,15 +37,55 @@ def test_entrypoint(self):
3337

3438
@parameterized.expand(
3539
[
36-
("default", None, DEFAUL_AQUA_CLI_LOGGING_LEVEL),
40+
("default", None, DEFAULT_AQUA_CLI_LOGGING_LEVEL),
3741
("set logging level", "info", "info"),
3842
]
3943
)
40-
@patch("ads.aqua.cli.set_log_level")
41-
def test_aquacommand(self, name, arg, expected, mock_setting_log):
42-
"""Tests aqua command initailzation."""
43-
if arg:
44-
AquaCommand(arg)
45-
else:
46-
AquaCommand()
47-
mock_setting_log.assert_called_with(expected)
44+
def test_aquacommand(self, name, arg, expected):
45+
"""Tests aqua command initialization."""
46+
with patch.dict(
47+
os.environ,
48+
{"ODSC_MODEL_COMPARTMENT_OCID": TestAquaCLI.SERVICE_COMPARTMENT_ID},
49+
):
50+
reload(ads.config)
51+
reload(ads.aqua)
52+
reload(ads.aqua.cli)
53+
with patch("ads.aqua.cli.set_log_level") as mock_setting_log:
54+
if arg:
55+
AquaCommand(arg)
56+
else:
57+
AquaCommand()
58+
mock_setting_log.assert_called_with(expected)
59+
60+
@parameterized.expand(
61+
[
62+
("default", None),
63+
("using jupyter instance", "nb-session-ocid"),
64+
]
65+
)
66+
def test_aqua_command_without_compartment_env_var(self, name, session_ocid):
67+
"""Test whether exit is called when ODSC_MODEL_COMPARTMENT_OCID is not set. Also check if NB_SESSION_OCID is
68+
set then log the appropriate message."""
69+
70+
with patch("sys.exit") as mock_exit:
71+
env_dict = {"ODSC_MODEL_COMPARTMENT_OCID": ""}
72+
if session_ocid:
73+
env_dict.update({"NB_SESSION_OCID": session_ocid})
74+
with patch.dict(os.environ, env_dict):
75+
reload(ads.config)
76+
reload(ads.aqua)
77+
reload(ads.aqua.cli)
78+
with patch("ads.aqua.cli.set_log_level") as mock_setting_log:
79+
with patch("ads.aqua.logger.error") as mock_logger_error:
80+
AquaCommand()
81+
mock_setting_log.assert_called_with(
82+
TestAquaCLI.DEFAULT_AQUA_CLI_LOGGING_LEVEL
83+
)
84+
mock_logger_error.assert_any_call(
85+
"ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua."
86+
)
87+
if session_ocid:
88+
mock_logger_error.assert_any_call(
89+
f"Aqua is not available for the notebook session {session_ocid}."
90+
)
91+
mock_exit.assert_called_with(1)

tests/unitary/with_extras/aqua/test_global.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import unittest
77
from unittest.mock import MagicMock, patch
88

9-
from ads.aqua import configure_aqua_logger, get_logger_level, set_log_level
9+
from ads.aqua import get_logger_level, set_log_level
1010

1111

1212
class TestAquaLogging(unittest.TestCase):
@@ -22,18 +22,6 @@ def test_get_logger_level_from_env(self):
2222
"""Test log level is correctly read from environment variable."""
2323
self.assertEqual(get_logger_level(), "DEBUG")
2424

25-
@patch("logging.getLogger")
26-
@patch("logging.StreamHandler")
27-
def test_configure_aqua_logger(self, mock_handler, mock_get_logger):
28-
"""Test that logger is correctly configured."""
29-
mock_logger = MagicMock()
30-
mock_get_logger.return_value = mock_logger
31-
32-
logger = configure_aqua_logger()
33-
34-
mock_get_logger.assert_called_once_with("ads.aqua")
35-
mock_logger.setLevel.assert_called_with(self.DEFAULT_AQUA_LOG_LEVEL)
36-
3725
@patch("ads.aqua.logger", create=True)
3826
def test_set_log_level(self, mock_logger):
3927
"""Test that the log level of the logger is set correctly."""

0 commit comments

Comments
 (0)