Skip to content

Commit b67bed5

Browse files
add deployment handler tests
1 parent 48c9557 commit b67bed5

File tree

1 file changed

+143
-0
lines changed

1 file changed

+143
-0
lines changed
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2024 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
import os
8+
import unittest
9+
from unittest.mock import MagicMock, patch
10+
from importlib import reload
11+
from notebook.base.handlers import IPythonHandler
12+
13+
import ads.config
14+
import ads.aqua
15+
from ads.aqua.extension.deployment_handler import (
16+
AquaDeploymentHandler,
17+
AquaDeploymentInferenceHandler,
18+
)
19+
from ads.aqua.deployment import AquaDeploymentApp, MDInferenceResponse
20+
21+
22+
class TestDataset:
23+
USER_COMPARTMENT_ID = "ocid1.compartment.oc1..<USER_COMPARTMENT_OCID>"
24+
USER_PROJECT_ID = "ocid1.datascienceproject.oc1.iad.<USER_PROJECT_OCID>"
25+
deployment_request = {
26+
"model_id": "ocid1.datasciencemodel.oc1.iad.<OCID>",
27+
"instance_shape": "VM.GPU.A10.1",
28+
"display_name": "test-deployment-name",
29+
}
30+
inference_request = {
31+
"prompt": "What is 1+1?",
32+
"endpoint": "https://modeldeployment.customer-oci.com/ocid1.datasciencemodeldeployment.oc1.<region>.<MD_OCID>/predict",
33+
"model_params": {
34+
"model": "odsc-llm",
35+
"max_tokens": 500,
36+
"temperature": 0.8,
37+
"top_p": 0.8,
38+
"top_k": 10,
39+
},
40+
}
41+
42+
43+
class TestAquaDeploymentHandler(unittest.TestCase):
44+
@patch.object(IPythonHandler, "__init__")
45+
def setUp(self, ipython_init_mock) -> None:
46+
ipython_init_mock.return_value = None
47+
self.deployment_handler = AquaDeploymentHandler(MagicMock(), MagicMock())
48+
self.deployment_handler.request = MagicMock()
49+
self.deployment_handler.finish = MagicMock()
50+
51+
@classmethod
52+
def setUpClass(cls):
53+
os.environ["PROJECT_COMPARTMENT_OCID"] = TestDataset.USER_COMPARTMENT_ID
54+
os.environ["PROJECT_OCID"] = TestDataset.USER_PROJECT_ID
55+
reload(ads.config)
56+
reload(ads.aqua)
57+
reload(ads.aqua.extension.deployment_handler)
58+
59+
@classmethod
60+
def tearDownClass(cls):
61+
os.environ.pop("PROJECT_COMPARTMENT_OCID", None)
62+
os.environ.pop("PROJECT_OCID", None)
63+
reload(ads.config)
64+
reload(ads.aqua)
65+
reload(ads.aqua.extension.deployment_handler)
66+
67+
@patch.object(AquaDeploymentApp, "get_deployment_config")
68+
def test_get_deployment_config(self, mock_get_deployment_config):
69+
"""Test get method to return deployment config"""
70+
self.deployment_handler.request.path = "aqua/deployments/config"
71+
self.deployment_handler.get(id="mock-model-id")
72+
mock_get_deployment_config.assert_called()
73+
74+
@unittest.skip("fix this test after exception handler is updated.")
75+
@patch("ads.aqua.extension.base_handler.AquaAPIhandler.write_error")
76+
def test_get_deployment_config_without_id(self, mock_error):
77+
"""Test get method to return deployment config"""
78+
# todo: exception handler needs to be revisited
79+
self.deployment_handler.request.path = "aqua/deployments/config"
80+
mock_error.return_value = MagicMock(status=400)
81+
result = self.deployment_handler.get(id="")
82+
mock_error.assert_called_once()
83+
assert result["status"] == 400
84+
85+
@patch.object(AquaDeploymentApp, "get")
86+
def test_get_deployment(self, mock_get):
87+
"""Test get method to return deployment information."""
88+
self.deployment_handler.request.path = "aqua/deployments"
89+
self.deployment_handler.get(id="mock-model-id")
90+
mock_get.assert_called()
91+
92+
@patch.object(AquaDeploymentApp, "list")
93+
def test_list_deployment(self, mock_list):
94+
"""Test get method to return a list of model deployments."""
95+
self.deployment_handler.request.path = "aqua/deployments"
96+
self.deployment_handler.get(id="")
97+
mock_list.assert_called_with(
98+
compartment_id=TestDataset.USER_COMPARTMENT_ID, project_id=None
99+
)
100+
101+
@patch.object(AquaDeploymentApp, "create")
102+
def test_post(self, mock_create):
103+
"""Test post method to create a model deployment."""
104+
self.deployment_handler.get_json_body = MagicMock(
105+
return_value=TestDataset.deployment_request
106+
)
107+
108+
self.deployment_handler.post()
109+
mock_create.assert_called_with(
110+
compartment_id=TestDataset.USER_COMPARTMENT_ID,
111+
project_id=TestDataset.USER_PROJECT_ID,
112+
model_id=TestDataset.deployment_request["model_id"],
113+
display_name=TestDataset.deployment_request["display_name"],
114+
description=None,
115+
instance_count=None,
116+
instance_shape=TestDataset.deployment_request["instance_shape"],
117+
log_group_id=None,
118+
access_log_id=None,
119+
predict_log_id=None,
120+
bandwidth_mbps=None,
121+
)
122+
123+
124+
class TestAquaDeploymentInferenceHandler(unittest.TestCase):
125+
@patch.object(IPythonHandler, "__init__")
126+
def setUp(self, ipython_init_mock) -> None:
127+
ipython_init_mock.return_value = None
128+
self.inference_handler = AquaDeploymentInferenceHandler(
129+
MagicMock(), MagicMock()
130+
)
131+
self.inference_handler.request = MagicMock()
132+
self.inference_handler.finish = MagicMock()
133+
134+
@patch.object(MDInferenceResponse, "get_model_deployment_response")
135+
def test_post(self, mock_get_model_deployment_response):
136+
"""Test post method to return model deployment response."""
137+
self.inference_handler.get_json_body = MagicMock(
138+
return_value=TestDataset.inference_request
139+
)
140+
self.inference_handler.post()
141+
mock_get_model_deployment_response.assert_called_with(
142+
TestDataset.inference_request["endpoint"]
143+
)

0 commit comments

Comments
 (0)