|
4 | 4 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5 | 5 | from unicodedata import category
|
6 | 6 | from unittest import TestCase
|
7 |
| -from unittest.mock import MagicMock, patch |
| 7 | +from unittest.mock import MagicMock, patch, ANY |
8 | 8 |
|
9 | 9 | import pytest
|
10 | 10 | from huggingface_hub.hf_api import HfApi, ModelInfo
|
|
14 | 14 |
|
15 | 15 | from ads.aqua.common.errors import AquaRuntimeError
|
16 | 16 | from ads.aqua.common.utils import get_hf_model_info
|
17 |
| -from ads.aqua.constants import AQUA_TROUBLESHOOTING_LINK, STATUS_CODE_MESSAGES |
| 17 | +from ads.aqua.constants import AQUA_TROUBLESHOOTING_LINK, STATUS_CODE_MESSAGES, AQUA_CHAT_TEMPLATE_METADATA_KEY |
18 | 18 | from ads.aqua.extension.errors import ReplyDetails
|
19 | 19 | from ads.aqua.extension.model_handler import (
|
20 | 20 | AquaHuggingFaceHandler,
|
21 | 21 | AquaModelHandler,
|
22 | 22 | AquaModelLicenseHandler,
|
23 |
| - AquaModelTokenizerConfigHandler, |
| 23 | + AquaModelChatTemplateHandler |
24 | 24 | )
|
25 | 25 | from ads.aqua.model import AquaModelApp
|
26 | 26 | from ads.aqua.model.entities import AquaModel, AquaModelSummary, HFModelSummary
|
@@ -254,39 +254,114 @@ def test_get(self, mock_load_license):
|
254 | 254 | mock_load_license.assert_called_with("test_model_id")
|
255 | 255 |
|
256 | 256 |
|
257 |
| -class ModelTokenizerConfigHandlerTestCase(TestCase): |
| 257 | +class AquaModelChatTemplateHandlerTestCase(TestCase): |
258 | 258 | @patch.object(IPythonHandler, "__init__")
|
259 | 259 | def setUp(self, ipython_init_mock) -> None:
|
260 | 260 | ipython_init_mock.return_value = None
|
261 |
| - self.model_tokenizer_config_handler = AquaModelTokenizerConfigHandler( |
| 261 | + self.model_chat_template_handler = AquaModelChatTemplateHandler( |
262 | 262 | MagicMock(), MagicMock()
|
263 | 263 | )
|
264 |
| - self.model_tokenizer_config_handler.finish = MagicMock() |
265 |
| - self.model_tokenizer_config_handler.request = MagicMock() |
| 264 | + self.model_chat_template_handler.finish = MagicMock() |
| 265 | + self.model_chat_template_handler.request = MagicMock() |
| 266 | + self.model_chat_template_handler._headers = {} |
266 | 267 |
|
267 |
| - @patch.object(AquaModelApp, "get_hf_tokenizer_config") |
| 268 | + @patch("ads.aqua.extension.model_handler.OCIDataScienceModel.from_id") |
268 | 269 | @patch("ads.aqua.extension.model_handler.urlparse")
|
269 |
| - def test_get(self, mock_urlparse, mock_get_hf_tokenizer_config): |
270 |
| - request_path = MagicMock(path="aqua/model/ocid1.xx./tokenizer") |
| 270 | + def test_get_valid_path(self, mock_urlparse, mock_from_id): |
| 271 | + request_path = MagicMock(path="/aqua/models/ocid1.xx./chat-template") |
271 | 272 | mock_urlparse.return_value = request_path
|
272 |
| - self.model_tokenizer_config_handler.get(model_id="test_model_id") |
273 |
| - self.model_tokenizer_config_handler.finish.assert_called_with( |
274 |
| - mock_get_hf_tokenizer_config.return_value |
275 |
| - ) |
276 |
| - mock_get_hf_tokenizer_config.assert_called_with("test_model_id") |
277 | 273 |
|
278 |
| - @patch.object(AquaModelApp, "get_hf_tokenizer_config") |
| 274 | + model_mock = MagicMock() |
| 275 | + model_mock.get_custom_metadata_artifact.return_value = "chat_template_string" |
| 276 | + mock_from_id.return_value = model_mock |
| 277 | + |
| 278 | + self.model_chat_template_handler.get(model_id="test_model_id") |
| 279 | + self.model_chat_template_handler.finish.assert_called_with("chat_template_string") |
| 280 | + model_mock.get_custom_metadata_artifact.assert_called_with("chat_template") |
| 281 | + |
279 | 282 | @patch("ads.aqua.extension.model_handler.urlparse")
|
280 |
| - def test_get_invalid_path(self, mock_urlparse, mock_get_hf_tokenizer_config): |
281 |
| - """Test invalid request path should raise HTTPError(400)""" |
282 |
| - request_path = MagicMock(path="/invalid/path") |
| 283 | + def test_get_invalid_path(self, mock_urlparse): |
| 284 | + request_path = MagicMock(path="/wrong/path") |
283 | 285 | mock_urlparse.return_value = request_path
|
284 | 286 |
|
285 | 287 | with self.assertRaises(HTTPError) as context:
|
286 |
| - self.model_tokenizer_config_handler.get(model_id="test_model_id") |
| 288 | + self.model_chat_template_handler.get("ocid1.test.chat") |
287 | 289 | self.assertEqual(context.exception.status_code, 400)
|
288 |
| - self.model_tokenizer_config_handler.finish.assert_not_called() |
289 |
| - mock_get_hf_tokenizer_config.assert_not_called() |
| 290 | + |
| 291 | + @patch("ads.aqua.extension.model_handler.OCIDataScienceModel.from_id", side_effect=Exception("Not found")) |
| 292 | + @patch("ads.aqua.extension.model_handler.urlparse") |
| 293 | + def test_get_model_not_found(self, mock_urlparse, mock_from_id): |
| 294 | + request_path = MagicMock(path="/aqua/models/ocid1.invalid/chat-template") |
| 295 | + mock_urlparse.return_value = request_path |
| 296 | + |
| 297 | + with self.assertRaises(HTTPError) as context: |
| 298 | + self.model_chat_template_handler.get("ocid1.invalid") |
| 299 | + self.assertEqual(context.exception.status_code, 404) |
| 300 | + |
| 301 | + @patch("ads.aqua.extension.model_handler.DataScienceModel.from_id") |
| 302 | + def test_post_valid(self, mock_from_id): |
| 303 | + model_mock = MagicMock() |
| 304 | + model_mock.create_custom_metadata_artifact.return_value = {"result": "success"} |
| 305 | + mock_from_id.return_value = model_mock |
| 306 | + |
| 307 | + self.model_chat_template_handler.get_json_body = MagicMock(return_value={"chat_template": "Hello <|user|>"}) |
| 308 | + result = self.model_chat_template_handler.post("ocid1.valid") |
| 309 | + self.model_chat_template_handler.finish.assert_called_with({"result": "success"}) |
| 310 | + |
| 311 | + model_mock.create_custom_metadata_artifact.assert_called_with( |
| 312 | + metadata_key_name=AQUA_CHAT_TEMPLATE_METADATA_KEY, |
| 313 | + path_type=ANY, |
| 314 | + artifact_path_or_content=b"Hello <|user|>" |
| 315 | + ) |
| 316 | + |
| 317 | + @patch.object(AquaModelChatTemplateHandler, "write_error") |
| 318 | + def test_post_invalid_json(self, mock_write_error): |
| 319 | + self.model_chat_template_handler.get_json_body = MagicMock(side_effect=Exception("Invalid JSON")) |
| 320 | + self.model_chat_template_handler._headers = {} |
| 321 | + self.model_chat_template_handler.post("ocid1.test.invalidjson") |
| 322 | + |
| 323 | + mock_write_error.assert_called_once() |
| 324 | + |
| 325 | + kwargs = mock_write_error.call_args.kwargs |
| 326 | + exc_info = kwargs.get("exc_info") |
| 327 | + |
| 328 | + assert exc_info is not None |
| 329 | + exc_type, exc_instance, _ = exc_info |
| 330 | + |
| 331 | + assert isinstance(exc_instance, HTTPError) |
| 332 | + assert exc_instance.status_code == 400 |
| 333 | + assert "Invalid JSON body" in str(exc_instance) |
| 334 | + |
| 335 | + @patch.object(AquaModelChatTemplateHandler, "write_error") |
| 336 | + def test_post_missing_chat_template(self, mock_write_error): |
| 337 | + self.model_chat_template_handler.get_json_body = MagicMock(return_value={}) |
| 338 | + self.model_chat_template_handler._headers = {} |
| 339 | + |
| 340 | + self.model_chat_template_handler.post("ocid1.test.model") |
| 341 | + |
| 342 | + mock_write_error.assert_called_once() |
| 343 | + exc_info = mock_write_error.call_args.kwargs.get("exc_info") |
| 344 | + assert exc_info is not None |
| 345 | + _, exc_instance, _ = exc_info |
| 346 | + assert isinstance(exc_instance, HTTPError) |
| 347 | + assert exc_instance.status_code == 400 |
| 348 | + assert "Missing required field: 'chat_template'" in str(exc_instance) |
| 349 | + |
| 350 | + @patch("ads.aqua.extension.model_handler.DataScienceModel.from_id", side_effect=Exception("Not found")) |
| 351 | + @patch.object(AquaModelChatTemplateHandler, "write_error") |
| 352 | + def test_post_model_not_found(self, mock_write_error, mock_from_id): |
| 353 | + self.model_chat_template_handler.get_json_body = MagicMock(return_value={"chat_template": "test template"}) |
| 354 | + self.model_chat_template_handler._headers = {} |
| 355 | + |
| 356 | + self.model_chat_template_handler.post("ocid1.invalid.model") |
| 357 | + |
| 358 | + mock_write_error.assert_called_once() |
| 359 | + exc_info = mock_write_error.call_args.kwargs.get("exc_info") |
| 360 | + assert exc_info is not None |
| 361 | + _, exc_instance, _ = exc_info |
| 362 | + assert isinstance(exc_instance, HTTPError) |
| 363 | + assert exc_instance.status_code == 404 |
| 364 | + assert "Model not found" in str(exc_instance) |
290 | 365 |
|
291 | 366 |
|
292 | 367 | class TestAquaHuggingFaceHandler:
|
|
0 commit comments