Commit ce6d4e5

Update test for LangChain LLMs.
1 parent: 845c5e4

2 files changed: +148, -65 lines
Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
"""Test OCI Data Science Model Deployment Endpoint."""

from unittest import mock
import pytest
from requests.exceptions import HTTPError
from ads.llm import OCIModelDeploymentTGI, OCIModelDeploymentVLLM


CONST_MODEL_NAME = "odsc-vllm"
CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
CONST_PROMPT_FOR_COMPLETION = "This is a prompt."
CONST_COMPLETION = "This is a completion."
CONST_COMPLETION_RESPONSE = {
    "choices": [{"index": 0, "text": CONST_COMPLETION}],
}
CONST_COMPLETION_RESPONSE_TGI = {"generated_text": CONST_COMPLETION}
CONST_STREAM_TEMPLATE = (
    'data: {"id":"","object":"text_completion","created":123456,'
    + '"choices":[{"index":0,"text":"<TOKEN>","finish_reason":""}]}'
)
CONST_STREAM_RESPONSE = (
    CONST_STREAM_TEMPLATE.replace("<TOKEN>", " " + word).encode()
    for word in CONST_COMPLETION.split(" ")
)

CONST_ASYNC_STREAM_TEMPLATE = (
    '{"id":"","object":"text_completion","created":123456,'
    + '"choices":[{"index":0,"text":"<TOKEN>","finish_reason":""}]}'
)
CONST_ASYNC_STREAM_RESPONSE = (
    CONST_ASYNC_STREAM_TEMPLATE.replace("<TOKEN>", " " + word).encode()
    for word in CONST_COMPLETION.split(" ")
)


def mocked_requests_post(self, **kwargs):
    """Method to mock post requests"""

    class MockResponse:
        """Represents a mocked response."""

        def __init__(self, json_data, status_code=200):
            self.json_data = json_data
            self.status_code = status_code

        def raise_for_status(self):
            """Mocked raise for status."""
            if 400 <= self.status_code < 600:
                raise HTTPError("", response=self)

        def json(self):
            """Returns mocked json data."""
            return self.json_data

        def iter_lines(self, chunk_size=4096):
            """Returns a generator of mocked streaming response."""
            return CONST_STREAM_RESPONSE

        @property
        def text(self):
            return ""

    payload = kwargs.get("json")
    if "inputs" in payload:
        prompt = payload.get("inputs")
        is_tgi = True
    else:
        prompt = payload.get("prompt")
        is_tgi = False

    if prompt == CONST_PROMPT_FOR_COMPLETION:
        if is_tgi:
            return MockResponse(json_data=CONST_COMPLETION_RESPONSE_TGI)
        return MockResponse(json_data=CONST_COMPLETION_RESPONSE)

    return MockResponse(
        json_data={},
        status_code=404,
    )


async def mocked_async_streaming_response(*args, **kwargs):
    """Returns mocked response for async streaming."""
    for item in CONST_ASYNC_STREAM_RESPONSE:
        yield item


@pytest.mark.requires("ads")
@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
@mock.patch("requests.post", side_effect=mocked_requests_post)
def test_invoke_vllm(mock_post, mock_auth) -> None:
    """Tests invoking vLLM endpoint."""
    llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
    output = llm.invoke(CONST_PROMPT_FOR_COMPLETION)
    assert output == CONST_COMPLETION


@pytest.mark.requires("ads")
@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
@mock.patch("requests.post", side_effect=mocked_requests_post)
def test_stream_tgi(mock_post, mock_auth) -> None:
    """Tests streaming with TGI endpoint using OpenAI spec."""
    llm = OCIModelDeploymentTGI(
        endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
    )
    output = ""
    count = 0
    for chunk in llm.stream(CONST_PROMPT_FOR_COMPLETION):
        output += chunk
        count += 1
    assert count == 4
    assert output.strip() == CONST_COMPLETION


@pytest.mark.requires("ads")
@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
@mock.patch("requests.post", side_effect=mocked_requests_post)
def test_generate_tgi(mock_post, mock_auth) -> None:
    """Tests invoking TGI endpoint using TGI generate spec."""
    llm = OCIModelDeploymentTGI(
        endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
    )
    output = llm.invoke(CONST_PROMPT_FOR_COMPLETION)
    assert output == CONST_COMPLETION


@pytest.mark.asyncio
@pytest.mark.requires("ads")
@mock.patch(
    "ads.common.auth.default_signer", return_value=dict(signer=mock.MagicMock())
)
@mock.patch(
    "langchain_community.utilities.requests.Requests.apost",
    mock.MagicMock(),
)
async def test_stream_async(mock_auth):
    """Tests async streaming."""
    llm = OCIModelDeploymentTGI(
        endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
    )
    with mock.patch.object(
        llm,
        "_aiter_sse",
        mock.MagicMock(return_value=mocked_async_streaming_response()),
    ):
        chunks = [chunk async for chunk in llm.astream(CONST_PROMPT_FOR_COMPLETION)]
    assert "".join(chunks).strip() == CONST_COMPLETION
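Note on the stream mocks above: the synchronous mock feeds OpenAI-style server-sent-event lines through iter_lines(), one token per "data: ..." line. Below is a minimal standalone sketch of how such lines decode back into the full completion. This snippet is illustrative only and not part of the commit; the actual parsing lives inside the ads.llm plugin classes.

import json

TEMPLATE = (
    'data: {"id":"","object":"text_completion","created":123456,'
    + '"choices":[{"index":0,"text":"<TOKEN>","finish_reason":""}]}'
)
COMPLETION = "This is a completion."

# Rebuild the mocked byte stream: one SSE line per whitespace-separated token.
stream = (
    TEMPLATE.replace("<TOKEN>", " " + word).encode()
    for word in COMPLETION.split(" ")
)

# Strip the "data: " SSE prefix, parse the JSON payload, and collect tokens.
tokens = [
    json.loads(line.decode()[len("data: "):])["choices"][0]["text"]
    for line in stream
]
assert "".join(tokens).strip() == COMPLETION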

tests/unitary/with_extras/langchain/test_llm_plugins.py

Lines changed: 0 additions & 65 deletions
This file was deleted.
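For context, here is a minimal usage sketch of the plugins exercised by the new tests, assuming a reachable OCI Data Science model deployment and ADS authentication already configured. The endpoint URL and model name below are placeholders, not values from this commit.

import ads
from ads.llm import OCIModelDeploymentTGI

# Assumption: resource-principal auth is available; other ads.set_auth()
# modes (e.g. API key) work as well.
ads.set_auth("resource_principal")

llm = OCIModelDeploymentTGI(
    endpoint="https://<model-deployment-url>/predict",  # placeholder URL
    model="odsc-llm",  # placeholder model name
    streaming=True,
)

print(llm.invoke("This is a prompt."))  # single completion

for chunk in llm.stream("This is a prompt."):  # token-by-token streaming
    print(chunk, end="", flush=True)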

0 commit comments
