Skip to content

Commit bfd32d4

Browse files
Merge branch 'main' into feature/aqua-v1.0.5
2 parents 7c019ca + ef6184d commit bfd32d4

File tree

6 files changed

+37
-7
lines changed

6 files changed

+37
-7
lines changed

ads/aqua/extension/deployment_handler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ def post(self, *args, **kwargs):
130130
memory_in_gbs = input_data.get("memory_in_gbs")
131131
model_file = input_data.get("model_file")
132132
private_endpoint_id = input_data.get("private_endpoint_id")
133+
container_image_uri = input_data.get("container_image_uri")
134+
cmd_var = input_data.get("cmd_var")
133135

134136
self.finish(
135137
AquaDeploymentApp().create(
@@ -153,6 +155,8 @@ def post(self, *args, **kwargs):
153155
memory_in_gbs=memory_in_gbs,
154156
model_file=model_file,
155157
private_endpoint_id=private_endpoint_id,
158+
container_image_uri=container_image_uri,
159+
cmd_var=cmd_var,
156160
)
157161
)
158162

ads/aqua/extension/model_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def post(self, *args, **kwargs):
128128
download_from_hf = (
129129
str(input_data.get("download_from_hf", "false")).lower() == "true"
130130
)
131+
inference_container_uri = input_data.get("inference_container_uri")
131132

132133
return self.finish(
133134
AquaModelApp().register(
@@ -139,6 +140,7 @@ def post(self, *args, **kwargs):
139140
compartment_id=compartment_id,
140141
project_id=project_id,
141142
model_file=model_file,
143+
inference_container_uri=inference_container_uri,
142144
)
143145
)
144146

ads/dataset/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def _repr_html_(self):
202202
self.sampled_df.head(5)
203203
.style.set_table_styles(utils.get_dataframe_styles())
204204
.set_table_attributes("class=table")
205-
.hide_index()
205+
.hide()
206206
.to_html()
207207
)
208208
)
@@ -261,7 +261,7 @@ def _repr_html_(self):
261261
utils.horizontal_scrollable_div(
262262
self.style.set_table_styles(utils.get_dataframe_styles())
263263
.set_table_attributes("class=table")
264-
.hide_index()
264+
.hide()
265265
.to_html()
266266
)
267267
)

ads/dataset/factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def list_snapshots(snapshot_dir=None, name="", storage_options=None, **kwargs):
366366
display(
367367
HTML(
368368
list_df.style.set_table_attributes("class=table")
369-
.hide_index()
369+
.hide()
370370
.to_html()
371371
)
372372
)

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ def test_post(self, mock_create):
154154
ocpus=None,
155155
model_file=None,
156156
private_endpoint_id=None,
157+
container_image_uri=None,
158+
cmd_var=None,
157159
)
158160

159161

tests/unitary/with_extras/aqua/test_model_handler.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from huggingface_hub.hf_api import HfApi, ModelInfo
1111
from huggingface_hub.utils import GatedRepoError
1212
from notebook.base.handlers import IPythonHandler
13+
from parameterized import parameterized
1314

1415
from ads.aqua.common.errors import AquaRuntimeError
1516
from ads.aqua.common.utils import get_hf_model_info
@@ -129,9 +130,25 @@ def test_list(self, mock_list):
129130
compartment_id=None, project_id=None, model_type=None
130131
)
131132

133+
@parameterized.expand(
134+
[
135+
(None, None, False, None),
136+
("odsc-llm-fine-tuning", None, False, None),
137+
(None, "test.gguf", True, None),
138+
(None, None, True, "iad.ocir.io/<namespace>/<image>:<tag>"),
139+
],
140+
)
132141
@patch("notebook.base.handlers.APIHandler.finish")
133142
@patch("ads.aqua.model.AquaModelApp.register")
134-
def test_register(self, mock_register, mock_finish):
143+
def test_register(
144+
self,
145+
finetuning_container,
146+
model_file,
147+
download_from_hf,
148+
inference_container_uri,
149+
mock_register,
150+
mock_finish,
151+
):
135152
mock_register.return_value = AquaModel(
136153
id="test_id",
137154
inference_container="odsc-tgi-serving",
@@ -144,18 +161,23 @@ def test_register(self, mock_register, mock_finish):
144161
model="test_model_name",
145162
os_path="test_os_path",
146163
inference_container="odsc-tgi-serving",
164+
finetuning_container=finetuning_container,
165+
model_file=model_file,
166+
download_from_hf=download_from_hf,
167+
inference_container_uri=inference_container_uri,
147168
)
148169
)
149170
result = self.model_handler.post()
150171
mock_register.assert_called_with(
151172
model="test_model_name",
152173
os_path="test_os_path",
153174
inference_container="odsc-tgi-serving",
154-
finetuning_container=None,
175+
finetuning_container=finetuning_container,
155176
compartment_id=None,
156177
project_id=None,
157-
model_file=None,
158-
download_from_hf=False,
178+
model_file=model_file,
179+
download_from_hf=download_from_hf,
180+
inference_container_uri=inference_container_uri,
159181
)
160182
assert result["id"] == "test_id"
161183
assert result["inference_container"] == "odsc-tgi-serving"

0 commit comments

Comments
 (0)