
Commit bee6a92

Add tests for chat model.
1 parent 3d9ecfd commit bee6a92

1 file changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
"""Test OCI Data Science Model Deployment Endpoint."""

from unittest import mock
import pytest
from langchain_core.messages import AIMessage, AIMessageChunk
from requests.exceptions import HTTPError
from ads.llm import ChatOCIModelDeploymentVLLM, ChatOCIModelDeploymentTGI


CONST_MODEL_NAME = "odsc-vllm"
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
CONST_PROMPT = "This is a prompt."
CONST_COMPLETION = "This is a completion."
CONST_COMPLETION_RESPONSE = {
    "id": "chat-123456789",
    "object": "chat.completion",
    "created": 123456789,
    "model": "mistral",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": CONST_COMPLETION,
                "tool_calls": [],
            },
            "logprobs": None,
            "finish_reason": "length",
            "stop_reason": None,
        }
    ],
    "usage": {"prompt_tokens": 115, "total_tokens": 371, "completion_tokens": 256},
    "prompt_logprobs": None,
}
CONST_COMPLETION_RESPONSE_TGI = {"generated_text": CONST_COMPLETION}
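# The canned vLLM payload above mirrors the OpenAI chat-completion response
# schema; the TGI payload uses Text Generation Inference's plain
# {"generated_text": ...} shape.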
CONST_STREAM_TEMPLATE = (
    'data: {"id":"chat-123456","object":"chat.completion.chunk","created":123456789,'
    '"model":"odsc-llm","choices":[{"index":0,"delta":<DELTA>,"finish_reason":null}]}'
)
CONST_STREAM_DELTAS = ['{"role":"assistant","content":""}'] + [
    '{"content":" ' + word + '"}' for word in CONST_COMPLETION.split(" ")
]
CONST_STREAM_RESPONSE = (
    content
    for content in [
        CONST_STREAM_TEMPLATE.replace("<DELTA>", delta).encode()
        for delta in CONST_STREAM_DELTAS
    ]
    + [b"data: [DONE]"]
)
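# The simulated SSE stream yields five deltas (an empty assistant role delta
# plus one per word of CONST_COMPLETION) followed by "data: [DONE]". Since
# CONST_STREAM_RESPONSE is a generator, it is exhausted after a single pass.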

CONST_ASYNC_STREAM_TEMPLATE = (
    '{"id":"chat-123456","object":"chat.completion.chunk","created":123456789,'
    '"model":"odsc-llm","choices":[{"index":0,"delta":<DELTA>,"finish_reason":null}]}'
)
CONST_ASYNC_STREAM_RESPONSE = (
    CONST_ASYNC_STREAM_TEMPLATE.replace("<DELTA>", delta).encode()
    for delta in CONST_STREAM_DELTAS
)
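# Unlike the sync fixture, these chunks carry no "data: " prefix: the async
# test patches the model's `_aiter_sse` helper directly, so the lines are
# consumed as if the SSE framing had already been stripped.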


def mocked_requests_post(url, **kwargs):
    """Mock `requests.post`; the first positional argument is the request URL."""

    class MockResponse:
        """Represents a mocked response."""

        def __init__(self, json_data, status_code=200):
            self.json_data = json_data
            self.status_code = status_code

        def raise_for_status(self):
            """Mocked raise for status."""
            if 400 <= self.status_code < 600:
                raise HTTPError("", response=self)

        def json(self):
            """Returns mocked json data."""
            return self.json_data

        def iter_lines(self, chunk_size=4096):
            """Returns a generator of mocked streaming response."""
            return CONST_STREAM_RESPONSE

        @property
        def text(self):
            return ""

    payload = kwargs.get("json")
    messages = payload.get("messages")
    prompt = messages[0].get("content")

    if prompt == CONST_PROMPT:
        return MockResponse(json_data=CONST_COMPLETION_RESPONSE)

    return MockResponse(
        json_data={},
        status_code=404,
    )


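# Each test patches `ads.common.auth.default_signer` so that no real OCI
# credentials are required, and swaps `requests.post` for the mock above;
# any prompt other than CONST_PROMPT receives a 404 response.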
@pytest.mark.requires("ads")
@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
@mock.patch("requests.post", side_effect=mocked_requests_post)
def test_invoke_vllm(mock_post, mock_auth) -> None:
    """Tests invoking vLLM endpoint."""
    llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
    output = llm.invoke(CONST_PROMPT)
    assert isinstance(output, AIMessage)
    assert output.content == CONST_COMPLETION


@pytest.mark.requires("ads")
@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
@mock.patch("requests.post", side_effect=mocked_requests_post)
def test_invoke_tgi(mock_post, mock_auth) -> None:
    """Tests invoking TGI endpoint using OpenAI Spec."""
    llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
    output = llm.invoke(CONST_PROMPT)
    assert isinstance(output, AIMessage)
    assert output.content == CONST_COMPLETION


@pytest.mark.requires("ads")
@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
@mock.patch("requests.post", side_effect=mocked_requests_post)
def test_stream_vllm(mock_post, mock_auth) -> None:
    """Tests streaming with vLLM endpoint using OpenAI spec."""
    llm = ChatOCIModelDeploymentVLLM(
        endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
    )
    output = AIMessageChunk("")
    count = 0
    for chunk in llm.stream(CONST_PROMPT):
        assert isinstance(chunk, AIMessageChunk)
        output += chunk
        count += 1
    # One chunk for the role delta plus one per word of CONST_COMPLETION.
    assert count == 5
    assert output.content.strip() == CONST_COMPLETION


async def mocked_async_streaming_response(*args, **kwargs):
    """Returns mocked response for async streaming."""
    for item in CONST_ASYNC_STREAM_RESPONSE:
        yield item


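# The async test below bypasses the HTTP layer entirely: `Requests.apost` is
# replaced with a MagicMock and `_aiter_sse` is patched to yield the canned
# chunks from mocked_async_streaming_response().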
@pytest.mark.asyncio
@pytest.mark.requires("ads")
@mock.patch(
    "ads.common.auth.default_signer", return_value=dict(signer=mock.MagicMock())
)
@mock.patch(
    "langchain_community.utilities.requests.Requests.apost",
    mock.MagicMock(),
)
async def test_stream_async(mock_auth):
    """Tests async streaming."""
    llm = ChatOCIModelDeploymentVLLM(
        endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
    )
    with mock.patch.object(
        llm,
        "_aiter_sse",
        mock.MagicMock(return_value=mocked_async_streaming_response()),
    ):
        chunks = [chunk.content async for chunk in llm.astream(CONST_PROMPT)]
        assert "".join(chunks).strip() == CONST_COMPLETION
