Skip to content

Commit 76be46c

Browse files
Rebasing and fixing UTs
1 parent 99970fa commit 76be46c

File tree

3 files changed

+25
-22
lines changed

3 files changed

+25
-22
lines changed

ads/aqua/extension/model_handler.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,14 @@
1010
from ads.aqua.common.decorator import handle_exceptions
1111
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1212
from ads.aqua.common.utils import (
13-
get_container_config,
1413
get_hf_model_info,
1514
list_hf_models,
1615
)
1716
from ads.aqua.extension.base_handler import AquaAPIhandler
1817
from ads.aqua.extension.errors import Errors
1918
from ads.aqua.model import AquaModelApp
2019
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
21-
from ads.aqua.ui import AquaContainerConfig, ModelFormat
20+
from ads.aqua.ui import ModelFormat
2221

2322

2423
class AquaModelHandler(AquaAPIhandler):
@@ -154,14 +153,11 @@ def put(self, id):
154153
raise HTTPError(400, Errors.NO_INPUT_DATA)
155154

156155
inference_container = input_data.get("inference_container")
157-
containers = list(
158-
AquaContainerConfig.from_container_index_json(
159-
config=get_container_config(), enable_spec=True
160-
).inference.values()
161-
)
162-
family_values = [item.family for item in containers]
163-
164-
if inference_container is not None and inference_container not in family_values:
156+
inference_containers = AquaModelApp.list_valid_inference_containers()
157+
if (
158+
inference_container is not None
159+
and inference_container not in inference_containers
160+
):
165161
raise HTTPError(
166162
400, Errors.INVALID_VALUE_OF_PARAMETER.format("inference_container")
167163
)

ads/aqua/model/model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
copy_model_config,
2727
create_word_icon,
2828
get_artifact_path,
29+
get_container_config,
2930
get_hf_model_info,
3031
list_os_files_with_extension,
3132
load_config,
@@ -718,6 +719,16 @@ def clear_model_list_cache(
718719
}
719720
return res
720721

722+
@staticmethod
723+
def list_valid_inference_containers():
724+
containers = list(
725+
AquaContainerConfig.from_container_index_json(
726+
config=get_container_config(), enable_spec=True
727+
).inference.values()
728+
)
729+
family_values = [item.family for item in containers]
730+
return family_values
731+
721732
def _create_model_catalog_entry(
722733
self,
723734
os_path: str,

tests/unitary/with_extras/aqua/test_model_handler.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
)
2121
from ads.aqua.model import AquaModelApp
2222
from ads.aqua.model.entities import AquaModel, AquaModelSummary, HFModelSummary
23-
from ads.aqua.ui import AquaContainerConfig
2423

2524

2625
class ModelHandlerTestCase(TestCase):
@@ -94,18 +93,15 @@ def test_delete_with_id(self, mock_delete, mock_urlparse):
9493
mock_urlparse.assert_called()
9594
mock_delete.assert_called()
9695

97-
@patch.object(AquaContainerConfig, "from_container_index_json")
96+
@patch.object(AquaModelApp, "list_valid_inference_containers")
9897
@patch.object(AquaModelApp, "edit_registered_model")
99-
def test_put(self, mock_edit, mock_container_index):
98+
def test_put(self, mock_edit, mock_inference_container_list):
10099
mock_edit.return_value = {"state": "EDITED"}
101-
mock_inference = MagicMock()
102-
mock_inference.values.return_value = [
103-
MagicMock(family="odsc-vllm-serving"),
104-
MagicMock(family="odsc-tgi-serving"),
105-
MagicMock(family="odsc-vllm-serving"),
100+
mock_inference_container_list.return_value = [
101+
"odsc-vllm-serving",
102+
"odsc-tgi-serving",
103+
"odsc-llama-cpp-serving",
106104
]
107-
108-
mock_container_index.return_value = MagicMock(inference=mock_inference)
109105
self.model_handler.get_json_body = MagicMock(
110106
return_value=dict(
111107
task="text_generation",
@@ -118,9 +114,9 @@ def test_put(self, mock_edit, mock_container_index):
118114
) as mock_finish:
119115
mock_finish.side_effect = lambda x: x
120116
result = self.model_handler.put(id="ocid1.datasciencemodel.oc1.iad.xxx")
121-
print(f"result: ", result)
122117
assert result["state"] is "EDITED"
123-
mock_edit.assert_called()
118+
mock_edit.assert_called_once()
119+
mock_inference_container_list.assert_called_once()
124120

125121
@patch.object(AquaModelApp, "list")
126122
def test_list(self, mock_list):

0 commit comments

Comments
 (0)