4
4
# Copyright (c) 2024 Oracle and/or its affiliates.
5
5
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
7
+ import os
7
8
import logging
8
9
import subprocess
9
10
from unittest import TestCase
10
11
from unittest .mock import patch
11
-
12
+ from importlib import reload
12
13
from parameterized import parameterized
13
14
15
+ import ads .aqua
16
+ import ads .config
14
17
from ads .aqua .cli import AquaCommand
15
18
16
19
17
20
class TestAquaCLI (TestCase ):
18
21
"""Tests the AQUA CLI."""
19
22
20
- DEFAUL_AQUA_CLI_LOGGING_LEVEL = "ERROR"
23
+ DEFAULT_AQUA_CLI_LOGGING_LEVEL = "ERROR"
21
24
logger = logging .getLogger (__name__ )
22
25
logging .basicConfig (
23
26
format = "%(asctime)s %(module)s %(levelname)s: %(message)s" ,
24
27
datefmt = "%m/%d/%Y %I:%M:%S %p" ,
25
28
level = logging .INFO ,
26
29
)
30
+ SERVICE_COMPARTMENT_ID = "ocid1.compartment.oc1..<OCID>"
27
31
28
32
def test_entrypoint (self ):
29
33
"""Tests CLI entrypoint."""
@@ -33,15 +37,55 @@ def test_entrypoint(self):
33
37
34
38
@parameterized .expand (
35
39
[
36
- ("default" , None , DEFAUL_AQUA_CLI_LOGGING_LEVEL ),
40
+ ("default" , None , DEFAULT_AQUA_CLI_LOGGING_LEVEL ),
37
41
("set logging level" , "info" , "info" ),
38
42
]
39
43
)
40
- @patch ("ads.aqua.cli.set_log_level" )
41
- def test_aquacommand (self , name , arg , expected , mock_setting_log ):
42
- """Tests aqua command initailzation."""
43
- if arg :
44
- AquaCommand (arg )
45
- else :
46
- AquaCommand ()
47
- mock_setting_log .assert_called_with (expected )
44
+ def test_aquacommand (self , name , arg , expected ):
45
+ """Tests aqua command initialization."""
46
+ with patch .dict (
47
+ os .environ ,
48
+ {"ODSC_MODEL_COMPARTMENT_OCID" : TestAquaCLI .SERVICE_COMPARTMENT_ID },
49
+ ):
50
+ reload (ads .config )
51
+ reload (ads .aqua )
52
+ reload (ads .aqua .cli )
53
+ with patch ("ads.aqua.cli.set_log_level" ) as mock_setting_log :
54
+ if arg :
55
+ AquaCommand (arg )
56
+ else :
57
+ AquaCommand ()
58
+ mock_setting_log .assert_called_with (expected )
59
+
60
+ @parameterized .expand (
61
+ [
62
+ ("default" , None ),
63
+ ("using jupyter instance" , "nb-session-ocid" ),
64
+ ]
65
+ )
66
+ def test_aqua_command_without_compartment_env_var (self , name , session_ocid ):
67
+ """Test whether exit is called when ODSC_MODEL_COMPARTMENT_OCID is not set. Also check if NB_SESSION_OCID is
68
+ set then log the appropriate message."""
69
+
70
+ with patch ("sys.exit" ) as mock_exit :
71
+ env_dict = {"ODSC_MODEL_COMPARTMENT_OCID" : "" }
72
+ if session_ocid :
73
+ env_dict .update ({"NB_SESSION_OCID" : session_ocid })
74
+ with patch .dict (os .environ , env_dict ):
75
+ reload (ads .config )
76
+ reload (ads .aqua )
77
+ reload (ads .aqua .cli )
78
+ with patch ("ads.aqua.cli.set_log_level" ) as mock_setting_log :
79
+ with patch ("ads.aqua.logger.error" ) as mock_logger_error :
80
+ AquaCommand ()
81
+ mock_setting_log .assert_called_with (
82
+ TestAquaCLI .DEFAULT_AQUA_CLI_LOGGING_LEVEL
83
+ )
84
+ mock_logger_error .assert_any_call (
85
+ "ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua."
86
+ )
87
+ if session_ocid :
88
+ mock_logger_error .assert_any_call (
89
+ f"Aqua is not available for the notebook session { session_ocid } ."
90
+ )
91
+ mock_exit .assert_called_with (1 )
0 commit comments