Skip to content

Commit 28fdd79

Browse files
add ui handler tests
1 parent b67bed5 commit 28fdd79

File tree

1 file changed

+168
-0
lines changed

1 file changed

+168
-0
lines changed
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2024 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
import os
8+
import unittest
9+
from unittest.mock import MagicMock, patch
10+
from importlib import reload
11+
from parameterized import parameterized
12+
13+
import ads.config
14+
import ads.aqua
15+
from notebook.base.handlers import IPythonHandler
16+
from ads.aqua.extension.ui_handler import AquaUIHandler
17+
from ads.aqua.ui import AquaUIApp
18+
from ads.aqua.data import Tags
19+
20+
21+
class TestDataset:
22+
USER_COMPARTMENT_ID = "ocid1.compartment.oc1..<USER_COMPARTMENT_OCID>"
23+
USER_PROJECT_ID = "ocid1.datascienceproject.oc1.iad.<USER_PROJECT_OCID>"
24+
DEPLOYMENT_SHAPE_NAME = "VM.GPU.A10.1"
25+
26+
27+
class TestAquaUIHandler(unittest.TestCase):
28+
@patch.object(IPythonHandler, "__init__")
29+
def setUp(self, ipython_init_mock) -> None:
30+
ipython_init_mock.return_value = None
31+
self.ui_handler = AquaUIHandler(MagicMock(), MagicMock())
32+
self.ui_handler.request = MagicMock()
33+
self.ui_handler.finish = MagicMock()
34+
35+
@classmethod
36+
def setUpClass(cls):
37+
os.environ["PROJECT_COMPARTMENT_OCID"] = TestDataset.USER_COMPARTMENT_ID
38+
os.environ["PROJECT_OCID"] = TestDataset.USER_PROJECT_ID
39+
reload(ads.config)
40+
reload(ads.aqua)
41+
reload(ads.aqua.extension.ui_handler)
42+
43+
@classmethod
44+
def tearDownClass(cls):
45+
os.environ.pop("PROJECT_COMPARTMENT_OCID", None)
46+
os.environ.pop("PROJECT_OCID", None)
47+
reload(ads.config)
48+
reload(ads.aqua)
49+
reload(ads.aqua.extension.ui_handler)
50+
51+
@patch.object(AquaUIApp, "list_log_groups")
52+
def test_list_log_groups(self, mock_list_log_groups):
53+
"""Test get method to fetch log groups"""
54+
self.ui_handler.request.path = "aqua/logging"
55+
self.ui_handler.get(id="")
56+
mock_list_log_groups.assert_called_with(
57+
compartment_id=TestDataset.USER_COMPARTMENT_ID
58+
)
59+
60+
@patch.object(AquaUIApp, "list_logs")
61+
def test_list_logs(self, mock_list_logs):
62+
"""Test get method to fetch logs for a given log group."""
63+
self.ui_handler.request.path = "aqua/logging"
64+
self.ui_handler.get(id="mock-log-id")
65+
mock_list_logs.assert_called_with(log_group_id="mock-log-id")
66+
67+
@patch.object(AquaUIApp, "list_compartments")
68+
def test_list_compartments(self, mock_list_compartments):
69+
"""Test get method to fetch logs for a given log group."""
70+
self.ui_handler.request.path = "aqua/compartments"
71+
self.ui_handler.get()
72+
mock_list_compartments.assert_called()
73+
74+
@patch.object(AquaUIApp, "get_default_compartment")
75+
def test_get_default_compartment(self, mock_get_default_compartment):
76+
"""Test get method to fetch logs for a given log group."""
77+
self.ui_handler.request.path = "aqua/compartments/default"
78+
self.ui_handler.get()
79+
mock_get_default_compartment.assert_called()
80+
81+
@patch.object(AquaUIApp, "list_model_version_sets")
82+
def test_list_experiments(self, mock_list_experiments):
83+
"""Test get method to fetch logs for a given log group."""
84+
self.ui_handler.request.path = "aqua/experiment"
85+
self.ui_handler.get()
86+
mock_list_experiments.assert_called_with(
87+
compartment_id=TestDataset.USER_COMPARTMENT_ID,
88+
target_tag=Tags.AQUA_EVALUATION.value,
89+
)
90+
91+
@patch.object(AquaUIApp, "list_model_version_sets")
92+
def test_list_model_version_sets(self, mock_list_model_version_sets):
93+
"""Test get method to fetch logs for a given log group."""
94+
self.ui_handler.request.path = "aqua/versionsets"
95+
self.ui_handler.get()
96+
mock_list_model_version_sets.assert_called_with(
97+
compartment_id=TestDataset.USER_COMPARTMENT_ID,
98+
target_tag=Tags.AQUA_FINE_TUNING.value,
99+
)
100+
101+
@parameterized.expand(["true", ""])
102+
@patch.object(AquaUIApp, "list_buckets")
103+
def test_list_buckets(self, versioned, mock_list_buckets):
104+
"""Test get method to fetch logs for a given log group."""
105+
self.ui_handler.request.path = "aqua/buckets"
106+
args = {"versioned": versioned}
107+
self.ui_handler.get_argument = MagicMock(
108+
side_effect=lambda arg, default=None: args.get(arg, default)
109+
)
110+
self.ui_handler.get()
111+
mock_list_buckets.assert_called_with(
112+
compartment_id=TestDataset.USER_COMPARTMENT_ID,
113+
versioned=True if versioned == "true" else False,
114+
)
115+
116+
@patch.object(AquaUIApp, "list_job_shapes")
117+
def test_list_job_shapes(self, mock_list_job_shapes):
118+
"""Test get method to fetch logs for a given log group."""
119+
self.ui_handler.request.path = "aqua/job/shapes"
120+
self.ui_handler.get()
121+
mock_list_job_shapes.assert_called_with(
122+
compartment_id=TestDataset.USER_COMPARTMENT_ID
123+
)
124+
125+
@patch.object(AquaUIApp, "list_vcn")
126+
def test_list_vcn(self, mock_list_vcn):
127+
"""Test get method to fetch logs for a given log group."""
128+
self.ui_handler.request.path = "aqua/vcn"
129+
self.ui_handler.get()
130+
mock_list_vcn.assert_called_with(compartment_id=TestDataset.USER_COMPARTMENT_ID)
131+
132+
@patch.object(AquaUIApp, "list_subnets")
133+
def test_mock_list_subnets(self, mock_list_subnets):
134+
"""Test get method to fetch logs for a given log group."""
135+
self.ui_handler.request.path = "aqua/subnets"
136+
args = {"vcn_id": "mock-vcn-id"}
137+
self.ui_handler.get_argument = MagicMock(
138+
side_effect=lambda arg, default=None: args.get(arg, default)
139+
)
140+
self.ui_handler.get()
141+
mock_list_subnets.assert_called_with(
142+
compartment_id=TestDataset.USER_COMPARTMENT_ID, vcn_id="mock-vcn-id"
143+
)
144+
145+
@patch.object(AquaUIApp, "get_shape_availability")
146+
def test_get_shape_availability(self, mock_get_shape_availability):
147+
"""Test get method to fetch logs for a given log group."""
148+
self.ui_handler.request.path = "aqua/shapes/limit"
149+
args = {"instance_shape": TestDataset.DEPLOYMENT_SHAPE_NAME}
150+
self.ui_handler.get_argument = MagicMock(
151+
side_effect=lambda arg, default=None: args.get(arg, default)
152+
)
153+
self.ui_handler.get()
154+
mock_get_shape_availability.assert_called_with(
155+
compartment_id=TestDataset.USER_COMPARTMENT_ID,
156+
instance_shape=TestDataset.DEPLOYMENT_SHAPE_NAME,
157+
)
158+
159+
@patch.object(AquaUIApp, "is_bucket_versioned")
160+
def test_is_bucket_versioned(self, mock_is_bucket_versioned):
161+
"""Test get method to fetch logs for a given log group."""
162+
self.ui_handler.request.path = "aqua/bucket/versioning"
163+
args = {"bucket_uri": "oci://<bucket_name>@<namespace>/<prefix>"}
164+
self.ui_handler.get_argument = MagicMock(
165+
side_effect=lambda arg, default=None: args.get(arg, default)
166+
)
167+
self.ui_handler.get()
168+
mock_is_bucket_versioned.assert_called_with(bucket_uri=args["bucket_uri"])

0 commit comments

Comments
 (0)