Skip to content

Commit 3383bca

Browse files
Merge branch 'main' into feature/forecasting-model-deployments
2 parents 759e6c3 + d55e593 commit 3383bca

File tree

3 files changed

+149
-30
lines changed

3 files changed

+149
-30
lines changed

ads/aqua/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
SUPPORTED_FILE_FORMATS = ["jsonl"]
5656
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
5757

58+
AQUA_CHAT_TEMPLATE_METADATA_KEY = "chat_template"
59+
5860
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
5961
"datasciencemodel": "models",
6062
"datasciencemodeldeployment": "model-deployments",

ads/aqua/extension/model_handler.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
from ads.aqua.common.enums import CustomInferenceContainerTypeFamily
1212
from ads.aqua.common.errors import AquaRuntimeError
1313
from ads.aqua.common.utils import get_hf_model_info, is_valid_ocid, list_hf_models
14+
from ads.aqua.constants import AQUA_CHAT_TEMPLATE_METADATA_KEY
1415
from ads.aqua.extension.base_handler import AquaAPIhandler
1516
from ads.aqua.extension.errors import Errors
1617
from ads.aqua.model import AquaModelApp
1718
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
1819
from ads.config import SERVICE
20+
from ads.model import DataScienceModel
1921
from ads.model.common.utils import MetadataArtifactPathType
22+
from ads.model.service.oci_datascience_model import OCIDataScienceModel
2023

2124

2225
class AquaModelHandler(AquaAPIhandler):
@@ -320,26 +323,65 @@ def post(self, *args, **kwargs): # noqa: ARG002
320323
)
321324

322325

323-
class AquaModelTokenizerConfigHandler(AquaAPIhandler):
326+
class AquaModelChatTemplateHandler(AquaAPIhandler):
324327
def get(self, model_id):
325328
"""
326-
Handles requests for retrieving the Hugging Face tokenizer configuration of a specified model.
327-
Expected request format: GET /aqua/models/<model-ocid>/tokenizer
329+
Handles requests for retrieving the chat template from custom metadata of a specified model.
330+
Expected request format: GET /aqua/models/<model-ocid>/chat-template
328331
329332
"""
330333

331334
path_list = urlparse(self.request.path).path.strip("/").split("/")
332-
# Path should be /aqua/models/ocid1.iad.ahdxxx/tokenizer
333-
# path_list=['aqua','models','<model-ocid>','tokenizer']
335+
# Path should be /aqua/models/ocid1.iad.ahdxxx/chat-template
336+
# path_list=['aqua','models','<model-ocid>','chat-template']
334337
if (
335338
len(path_list) == 4
336339
and is_valid_ocid(path_list[2])
337-
and path_list[3] == "tokenizer"
340+
and path_list[3] == "chat-template"
338341
):
339-
return self.finish(AquaModelApp().get_hf_tokenizer_config(model_id))
342+
try:
343+
oci_data_science_model = OCIDataScienceModel.from_id(model_id)
344+
except Exception as e:
345+
raise HTTPError(404, f"Model not found for id: {model_id}. Details: {str(e)}")
346+
return self.finish(oci_data_science_model.get_custom_metadata_artifact("chat_template"))
340347

341348
raise HTTPError(400, f"The request {self.request.path} is invalid.")
342349

350+
@handle_exceptions
351+
def post(self, model_id: str):
352+
"""
353+
Handles POST requests to add a custom chat_template metadata artifact to a model.
354+
355+
Expected request format:
356+
POST /aqua/models/<model-ocid>/chat-template
357+
Body: { "chat_template": "<your_template_string>" }
358+
359+
"""
360+
try:
361+
input_body = self.get_json_body()
362+
except Exception as e:
363+
raise HTTPError(400, f"Invalid JSON body: {str(e)}")
364+
365+
chat_template = input_body.get("chat_template")
366+
if not chat_template:
367+
raise HTTPError(400, "Missing required field: 'chat_template'")
368+
369+
try:
370+
data_science_model = DataScienceModel.from_id(model_id)
371+
except Exception as e:
372+
raise HTTPError(404, f"Model not found for id: {model_id}. Details: {str(e)}")
373+
374+
try:
375+
result = data_science_model.create_custom_metadata_artifact(
376+
metadata_key_name=AQUA_CHAT_TEMPLATE_METADATA_KEY,
377+
path_type=MetadataArtifactPathType.CONTENT,
378+
artifact_path_or_content=chat_template.encode()
379+
)
380+
except Exception as e:
381+
raise HTTPError(500, f"Failed to create metadata artifact: {str(e)}")
382+
383+
return self.finish(result)
384+
343385

344386
class AquaModelDefinedMetadataArtifactHandler(AquaAPIhandler):
345387
"""
@@ -381,7 +423,7 @@ def post(self, model_id: str, metadata_key: str):
381423
("model/?([^/]*)", AquaModelHandler),
382424
("model/?([^/]*)/license", AquaModelLicenseHandler),
383425
("model/?([^/]*)/readme", AquaModelReadmeHandler),
384-
("model/?([^/]*)/tokenizer", AquaModelTokenizerConfigHandler),
426+
("model/?([^/]*)/chat-template", AquaModelChatTemplateHandler),
385427
("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
386428
(
387429
"model/?([^/]*)/definedMetadata/?([^/]*)",

tests/unitary/with_extras/aqua/test_model_handler.py

Lines changed: 97 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55
from unicodedata import category
66
from unittest import TestCase
7-
from unittest.mock import MagicMock, patch
7+
from unittest.mock import MagicMock, patch, ANY
88

99
import pytest
1010
from huggingface_hub.hf_api import HfApi, ModelInfo
@@ -14,13 +14,13 @@
1414

1515
from ads.aqua.common.errors import AquaRuntimeError
1616
from ads.aqua.common.utils import get_hf_model_info
17-
from ads.aqua.constants import AQUA_TROUBLESHOOTING_LINK, STATUS_CODE_MESSAGES
17+
from ads.aqua.constants import AQUA_TROUBLESHOOTING_LINK, STATUS_CODE_MESSAGES, AQUA_CHAT_TEMPLATE_METADATA_KEY
1818
from ads.aqua.extension.errors import ReplyDetails
1919
from ads.aqua.extension.model_handler import (
2020
AquaHuggingFaceHandler,
2121
AquaModelHandler,
2222
AquaModelLicenseHandler,
23-
AquaModelTokenizerConfigHandler,
23+
AquaModelChatTemplateHandler
2424
)
2525
from ads.aqua.model import AquaModelApp
2626
from ads.aqua.model.entities import AquaModel, AquaModelSummary, HFModelSummary
@@ -254,39 +254,114 @@ def test_get(self, mock_load_license):
254254
mock_load_license.assert_called_with("test_model_id")
255255

256256

257-
class ModelTokenizerConfigHandlerTestCase(TestCase):
257+
class AquaModelChatTemplateHandlerTestCase(TestCase):
258258
@patch.object(IPythonHandler, "__init__")
259259
def setUp(self, ipython_init_mock) -> None:
260260
ipython_init_mock.return_value = None
261-
self.model_tokenizer_config_handler = AquaModelTokenizerConfigHandler(
261+
self.model_chat_template_handler = AquaModelChatTemplateHandler(
262262
MagicMock(), MagicMock()
263263
)
264-
self.model_tokenizer_config_handler.finish = MagicMock()
265-
self.model_tokenizer_config_handler.request = MagicMock()
264+
self.model_chat_template_handler.finish = MagicMock()
265+
self.model_chat_template_handler.request = MagicMock()
266+
self.model_chat_template_handler._headers = {}
266267

267-
@patch.object(AquaModelApp, "get_hf_tokenizer_config")
268+
@patch("ads.aqua.extension.model_handler.OCIDataScienceModel.from_id")
268269
@patch("ads.aqua.extension.model_handler.urlparse")
269-
def test_get(self, mock_urlparse, mock_get_hf_tokenizer_config):
270-
request_path = MagicMock(path="aqua/model/ocid1.xx./tokenizer")
270+
def test_get_valid_path(self, mock_urlparse, mock_from_id):
271+
request_path = MagicMock(path="/aqua/models/ocid1.xx./chat-template")
271272
mock_urlparse.return_value = request_path
272-
self.model_tokenizer_config_handler.get(model_id="test_model_id")
273-
self.model_tokenizer_config_handler.finish.assert_called_with(
274-
mock_get_hf_tokenizer_config.return_value
275-
)
276-
mock_get_hf_tokenizer_config.assert_called_with("test_model_id")
277273

278-
@patch.object(AquaModelApp, "get_hf_tokenizer_config")
274+
model_mock = MagicMock()
275+
model_mock.get_custom_metadata_artifact.return_value = "chat_template_string"
276+
mock_from_id.return_value = model_mock
277+
278+
self.model_chat_template_handler.get(model_id="test_model_id")
279+
self.model_chat_template_handler.finish.assert_called_with("chat_template_string")
280+
model_mock.get_custom_metadata_artifact.assert_called_with("chat_template")
281+
279282
@patch("ads.aqua.extension.model_handler.urlparse")
280-
def test_get_invalid_path(self, mock_urlparse, mock_get_hf_tokenizer_config):
281-
"""Test invalid request path should raise HTTPError(400)"""
282-
request_path = MagicMock(path="/invalid/path")
283+
def test_get_invalid_path(self, mock_urlparse):
284+
request_path = MagicMock(path="/wrong/path")
283285
mock_urlparse.return_value = request_path
284286

285287
with self.assertRaises(HTTPError) as context:
286-
self.model_tokenizer_config_handler.get(model_id="test_model_id")
288+
self.model_chat_template_handler.get("ocid1.test.chat")
287289
self.assertEqual(context.exception.status_code, 400)
288-
self.model_tokenizer_config_handler.finish.assert_not_called()
289-
mock_get_hf_tokenizer_config.assert_not_called()
290+
291+
@patch("ads.aqua.extension.model_handler.OCIDataScienceModel.from_id", side_effect=Exception("Not found"))
292+
@patch("ads.aqua.extension.model_handler.urlparse")
293+
def test_get_model_not_found(self, mock_urlparse, mock_from_id):
294+
request_path = MagicMock(path="/aqua/models/ocid1.invalid/chat-template")
295+
mock_urlparse.return_value = request_path
296+
297+
with self.assertRaises(HTTPError) as context:
298+
self.model_chat_template_handler.get("ocid1.invalid")
299+
self.assertEqual(context.exception.status_code, 404)
300+
301+
@patch("ads.aqua.extension.model_handler.DataScienceModel.from_id")
302+
def test_post_valid(self, mock_from_id):
303+
model_mock = MagicMock()
304+
model_mock.create_custom_metadata_artifact.return_value = {"result": "success"}
305+
mock_from_id.return_value = model_mock
306+
307+
self.model_chat_template_handler.get_json_body = MagicMock(return_value={"chat_template": "Hello <|user|>"})
308+
result = self.model_chat_template_handler.post("ocid1.valid")
309+
self.model_chat_template_handler.finish.assert_called_with({"result": "success"})
310+
311+
model_mock.create_custom_metadata_artifact.assert_called_with(
312+
metadata_key_name=AQUA_CHAT_TEMPLATE_METADATA_KEY,
313+
path_type=ANY,
314+
artifact_path_or_content=b"Hello <|user|>"
315+
)
316+
317+
@patch.object(AquaModelChatTemplateHandler, "write_error")
318+
def test_post_invalid_json(self, mock_write_error):
319+
self.model_chat_template_handler.get_json_body = MagicMock(side_effect=Exception("Invalid JSON"))
320+
self.model_chat_template_handler._headers = {}
321+
self.model_chat_template_handler.post("ocid1.test.invalidjson")
322+
323+
mock_write_error.assert_called_once()
324+
325+
kwargs = mock_write_error.call_args.kwargs
326+
exc_info = kwargs.get("exc_info")
327+
328+
assert exc_info is not None
329+
exc_type, exc_instance, _ = exc_info
330+
331+
assert isinstance(exc_instance, HTTPError)
332+
assert exc_instance.status_code == 400
333+
assert "Invalid JSON body" in str(exc_instance)
334+
335+
@patch.object(AquaModelChatTemplateHandler, "write_error")
336+
def test_post_missing_chat_template(self, mock_write_error):
337+
self.model_chat_template_handler.get_json_body = MagicMock(return_value={})
338+
self.model_chat_template_handler._headers = {}
339+
340+
self.model_chat_template_handler.post("ocid1.test.model")
341+
342+
mock_write_error.assert_called_once()
343+
exc_info = mock_write_error.call_args.kwargs.get("exc_info")
344+
assert exc_info is not None
345+
_, exc_instance, _ = exc_info
346+
assert isinstance(exc_instance, HTTPError)
347+
assert exc_instance.status_code == 400
348+
assert "Missing required field: 'chat_template'" in str(exc_instance)
349+
350+
@patch("ads.aqua.extension.model_handler.DataScienceModel.from_id", side_effect=Exception("Not found"))
351+
@patch.object(AquaModelChatTemplateHandler, "write_error")
352+
def test_post_model_not_found(self, mock_write_error, mock_from_id):
353+
self.model_chat_template_handler.get_json_body = MagicMock(return_value={"chat_template": "test template"})
354+
self.model_chat_template_handler._headers = {}
355+
356+
self.model_chat_template_handler.post("ocid1.invalid.model")
357+
358+
mock_write_error.assert_called_once()
359+
exc_info = mock_write_error.call_args.kwargs.get("exc_info")
360+
assert exc_info is not None
361+
_, exc_instance, _ = exc_info
362+
assert isinstance(exc_instance, HTTPError)
363+
assert exc_instance.status_code == 404
364+
assert "Model not found" in str(exc_instance)
290365

291366

292367
class TestAquaHuggingFaceHandler:

0 commit comments

Comments
 (0)