| 1 | +"""Test OCI Data Science Model Deployment Endpoint.""" |
| 2 | + |
| 3 | +from unittest import mock |
| 4 | +import pytest |
| 5 | +from requests.exceptions import HTTPError |
| 6 | +from ads.llm import OCIModelDeploymentTGI, OCIModelDeploymentVLLM |
| 7 | + |
| 8 | + |
| 9 | +CONST_MODEL_NAME = "odsc-vllm" |
| 10 | +CONST_ENDPOINT = "https://oci.endpoint/ocid/predict" |
| 11 | +CONST_PROMPT_FOR_COMPLETION = "This is a prompt." |
| 12 | +CONST_COMPLETION = "This is a completion." |
| 13 | +CONST_COMPLETION_RESPONSE = { |
| 14 | + "choices": [{"index": 0, "text": CONST_COMPLETION}], |
| 15 | +} |
| 16 | +CONST_COMPLETION_RESPONSE_TGI = {"generated_text": CONST_COMPLETION} |
| 17 | +CONST_STREAM_TEMPLATE = ( |
| 18 | + 'data: {"id":"","object":"text_completion","created":123456,' |
| 19 | + + '"choices":[{"index":0,"text":"<TOKEN>","finish_reason":""}]}' |
| 20 | +) |
| 21 | +CONST_STREAM_RESPONSE = ( |
| 22 | + CONST_STREAM_TEMPLATE.replace("<TOKEN>", " " + word).encode() |
| 23 | + for word in CONST_COMPLETION.split(" ") |
| 24 | +) |
| 25 | + |
| 26 | +CONST_ASYNC_STREAM_TEMPLATE = ( |
| 27 | + '{"id":"","object":"text_completion","created":123456,' |
| 28 | + + '"choices":[{"index":0,"text":"<TOKEN>","finish_reason":""}]}' |
| 29 | +) |
| 30 | +CONST_ASYNC_STREAM_RESPONSE = ( |
| 31 | + CONST_ASYNC_STREAM_TEMPLATE.replace("<TOKEN>", " " + word).encode() |
| 32 | + for word in CONST_COMPLETION.split(" ") |
| 33 | +) |
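# Note: the two *_STREAM_RESPONSE constants above are generator expressions, so each
# can only be consumed once per test session; each is used by exactly one streaming
# test below. The async payloads differ from CONST_STREAM_RESPONSE only in omitting
# the "data: " SSE prefix.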


def mocked_requests_post(url, **kwargs):
    """Mocks requests.post; the first positional argument receives the request URL."""

    class MockResponse:
        """Represents a mocked response."""

        def __init__(self, json_data, status_code=200):
            self.json_data = json_data
            self.status_code = status_code

        def raise_for_status(self):
            """Mocked raise for status."""
            if 400 <= self.status_code < 600:
                raise HTTPError("", response=self)

        def json(self):
            """Returns mocked json data."""
            return self.json_data

        def iter_lines(self, chunk_size=4096):
            """Returns a generator of mocked streaming response."""
            return CONST_STREAM_RESPONSE

        @property
        def text(self):
            """Returns mocked response text."""
            return ""

    payload = kwargs.get("json")
    if "inputs" in payload:
        # TGI generate spec carries the prompt under "inputs".
        prompt = payload.get("inputs")
        is_tgi = True
    else:
        # OpenAI-style completion spec carries the prompt under "prompt".
        prompt = payload.get("prompt")
        is_tgi = False

    if prompt == CONST_PROMPT_FOR_COMPLETION:
        if is_tgi:
            return MockResponse(json_data=CONST_COMPLETION_RESPONSE_TGI)
        return MockResponse(json_data=CONST_COMPLETION_RESPONSE)

    # Any other prompt is answered with a 404 so tests fail loudly.
    return MockResponse(
        json_data={},
        status_code=404,
    )


async def mocked_async_streaming_response(*args, **kwargs):
    """Yields a mocked response for async streaming."""
    for item in CONST_ASYNC_STREAM_RESPONSE:
        yield item


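# Illustrative sanity check (not from the original suite): it exercises the mock
# helper above directly, without going through the langchain classes, to show both
# branches it implements -- the canned completion for the known prompt and the 404
# fallback whose raise_for_status() raises HTTPError. The test name and the unknown
# prompt string are arbitrary.
def test_mocked_requests_post_branches() -> None:
    """Sketch: the mock returns the canned completion or a 404 response."""
    ok = mocked_requests_post(None, json={"prompt": CONST_PROMPT_FOR_COMPLETION})
    assert ok.status_code == 200
    assert ok.json() == CONST_COMPLETION_RESPONSE

    not_found = mocked_requests_post(None, json={"prompt": "unknown prompt"})
    assert not_found.status_code == 404
    with pytest.raises(HTTPError):
        not_found.raise_for_status()

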
@pytest.mark.requires("ads")
@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
@mock.patch("requests.post", side_effect=mocked_requests_post)
def test_invoke_vllm(mock_post, mock_auth) -> None:
    """Tests invoking vLLM endpoint."""
    llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
    output = llm.invoke(CONST_PROMPT_FOR_COMPLETION)
    assert output == CONST_COMPLETION


@pytest.mark.requires("ads")
@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
@mock.patch("requests.post", side_effect=mocked_requests_post)
def test_stream_tgi(mock_post, mock_auth) -> None:
    """Tests streaming with TGI endpoint using OpenAI spec."""
    llm = OCIModelDeploymentTGI(
        endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
    )
    output = ""
    count = 0
    for chunk in llm.stream(CONST_PROMPT_FOR_COMPLETION):
        output += chunk
        count += 1
    # One chunk per word of CONST_COMPLETION.
    assert count == 4
    assert output.strip() == CONST_COMPLETION


@pytest.mark.requires("ads")
@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
@mock.patch("requests.post", side_effect=mocked_requests_post)
def test_generate_tgi(mock_post, mock_auth) -> None:
    """Tests invoking TGI endpoint using TGI generate spec."""
    llm = OCIModelDeploymentTGI(
        endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
    )
    output = llm.invoke(CONST_PROMPT_FOR_COMPLETION)
    assert output == CONST_COMPLETION


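# A hedged sketch (not in the original suite): with the mock above, any prompt other
# than CONST_PROMPT_FOR_COMPLETION gets a 404 response. It assumes the ads.llm client
# surfaces that failure as an exception (e.g. via raise_for_status()); the exact
# exception type the wrapper raises is not confirmed here, hence the broad
# pytest.raises(Exception). Test name and prompt string are illustrative only.
@pytest.mark.requires("ads")
@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
@mock.patch("requests.post", side_effect=mocked_requests_post)
def test_invoke_unknown_prompt_raises(mock_post, mock_auth) -> None:
    """Sketch: an unrecognized prompt should surface an error from the endpoint."""
    llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
    with pytest.raises(Exception):
        llm.invoke("This prompt is not recognized by the mocked endpoint.")

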
@pytest.mark.asyncio
@pytest.mark.requires("ads")
@mock.patch(
    "ads.common.auth.default_signer", return_value=dict(signer=mock.MagicMock())
)
@mock.patch(
    "langchain_community.utilities.requests.Requests.apost",
    mock.MagicMock(),
)
async def test_stream_async(mock_auth):
    """Tests async streaming."""
    llm = OCIModelDeploymentTGI(
        endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
    )
    with mock.patch.object(
        llm,
        "_aiter_sse",
        mock.MagicMock(return_value=mocked_async_streaming_response()),
    ):
        chunks = [chunk async for chunk in llm.astream(CONST_PROMPT_FOR_COMPLETION)]
        assert "".join(chunks).strip() == CONST_COMPLETION