Skip to content

Commit cdd0aa1

Browse files
Merge branch 'main' into auto_transform_bug_fix
2 parents e3ad075 + f61fe2e commit cdd0aa1

File tree

4 files changed

+34
-4
lines changed

4 files changed

+34
-4
lines changed

ads/aqua/extension/deployment_handler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def post(self, *args, **kwargs):
103103
memory_in_gbs = input_data.get("memory_in_gbs")
104104
model_file = input_data.get("model_file")
105105
private_endpoint_id = input_data.get("private_endpoint_id")
106+
container_image_uri = input_data.get("container_image_uri")
107+
cmd_var = input_data.get("cmd_var")
106108

107109
self.finish(
108110
AquaDeploymentApp().create(
@@ -126,6 +128,8 @@ def post(self, *args, **kwargs):
126128
memory_in_gbs=memory_in_gbs,
127129
model_file=model_file,
128130
private_endpoint_id=private_endpoint_id,
131+
container_image_uri=container_image_uri,
132+
cmd_var=cmd_var,
129133
)
130134
)
131135

ads/aqua/extension/model_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def post(self, *args, **kwargs):
123123
download_from_hf = (
124124
str(input_data.get("download_from_hf", "false")).lower() == "true"
125125
)
126+
inference_container_uri = input_data.get("inference_container_uri")
126127

127128
return self.finish(
128129
AquaModelApp().register(
@@ -134,6 +135,7 @@ def post(self, *args, **kwargs):
134135
compartment_id=compartment_id,
135136
project_id=project_id,
136137
model_file=model_file,
138+
inference_container_uri=inference_container_uri,
137139
)
138140
)
139141

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ def test_post(self, mock_create):
130130
ocpus=None,
131131
model_file=None,
132132
private_endpoint_id=None,
133+
container_image_uri=None,
134+
cmd_var=None,
133135
)
134136

135137

tests/unitary/with_extras/aqua/test_model_handler.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from huggingface_hub.hf_api import HfApi, ModelInfo
1111
from huggingface_hub.utils import GatedRepoError
1212
from notebook.base.handlers import IPythonHandler
13+
from parameterized import parameterized
1314

1415
from ads.aqua.common.errors import AquaRuntimeError
1516
from ads.aqua.common.utils import get_hf_model_info
@@ -90,9 +91,25 @@ def test_list(self, mock_list):
9091
compartment_id=None, project_id=None, model_type=None
9192
)
9293

94+
@parameterized.expand(
95+
[
96+
(None, None, False, None),
97+
("odsc-llm-fine-tuning", None, False, None),
98+
(None, "test.gguf", True, None),
99+
(None, None, True, "iad.ocir.io/<namespace>/<image>:<tag>"),
100+
],
101+
)
93102
@patch("notebook.base.handlers.APIHandler.finish")
94103
@patch("ads.aqua.model.AquaModelApp.register")
95-
def test_register(self, mock_register, mock_finish):
104+
def test_register(
105+
self,
106+
finetuning_container,
107+
model_file,
108+
download_from_hf,
109+
inference_container_uri,
110+
mock_register,
111+
mock_finish,
112+
):
96113
mock_register.return_value = AquaModel(
97114
id="test_id",
98115
inference_container="odsc-tgi-serving",
@@ -105,18 +122,23 @@ def test_register(self, mock_register, mock_finish):
105122
model="test_model_name",
106123
os_path="test_os_path",
107124
inference_container="odsc-tgi-serving",
125+
finetuning_container=finetuning_container,
126+
model_file=model_file,
127+
download_from_hf=download_from_hf,
128+
inference_container_uri=inference_container_uri,
108129
)
109130
)
110131
result = self.model_handler.post()
111132
mock_register.assert_called_with(
112133
model="test_model_name",
113134
os_path="test_os_path",
114135
inference_container="odsc-tgi-serving",
115-
finetuning_container=None,
136+
finetuning_container=finetuning_container,
116137
compartment_id=None,
117138
project_id=None,
118-
model_file=None,
119-
download_from_hf=False,
139+
model_file=model_file,
140+
download_from_hf=download_from_hf,
141+
inference_container_uri=inference_container_uri,
120142
)
121143
assert result["id"] == "test_id"
122144
assert result["inference_container"] == "odsc-tgi-serving"

0 commit comments

Comments
 (0)