Skip to content

Commit 610a418

Browse files
committed
add chat tests
1 parent 86b7986 commit 610a418

File tree

2 files changed

+403
-2
lines changed

2 files changed

+403
-2
lines changed

tests/conftest.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
import pytest_asyncio
99
from fastapi.testclient import TestClient
1010
from openai.types import CreateEmbeddingResponse, Embedding
11+
from openai.types.chat import ChatCompletion, ChatCompletionChunk
12+
from openai.types.chat.chat_completion import (
13+
ChatCompletionMessage,
14+
Choice,
15+
)
1116
from openai.types.create_embedding_response import Usage
1217
from sqlalchemy.ext.asyncio import async_sessionmaker
1318

@@ -107,10 +112,122 @@ def patch():
107112
return patch
108113

109114

115+
@pytest.fixture
116+
def mock_openai_chatcompletion(monkeypatch):
117+
class AsyncChatCompletionIterator:
118+
def __init__(self, answer: str):
119+
chunk_id = "test-id"
120+
model = "gpt-35-turbo"
121+
self.responses = [
122+
{"object": "chat.completion.chunk", "choices": [], "id": chunk_id, "model": model, "created": 1},
123+
{
124+
"object": "chat.completion.chunk",
125+
"choices": [{"delta": {"role": "assistant"}, "index": 0, "finish_reason": None}],
126+
"id": chunk_id,
127+
"model": model,
128+
"created": 1,
129+
},
130+
]
131+
# Split at << to simulate chunked responses
132+
if answer.find("<<") > -1:
133+
parts = answer.split("<<")
134+
self.responses.append(
135+
{
136+
"object": "chat.completion.chunk",
137+
"choices": [
138+
{
139+
"delta": {"role": "assistant", "content": parts[0] + "<<"},
140+
"index": 0,
141+
"finish_reason": None,
142+
}
143+
],
144+
"id": chunk_id,
145+
"model": model,
146+
"created": 1,
147+
}
148+
)
149+
self.responses.append(
150+
{
151+
"object": "chat.completion.chunk",
152+
"choices": [
153+
{"delta": {"role": "assistant", "content": parts[1]}, "index": 0, "finish_reason": None}
154+
],
155+
"id": chunk_id,
156+
"model": model,
157+
"created": 1,
158+
}
159+
)
160+
self.responses.append(
161+
{
162+
"object": "chat.completion.chunk",
163+
"choices": [{"delta": {"role": None, "content": None}, "index": 0, "finish_reason": "stop"}],
164+
"id": chunk_id,
165+
"model": model,
166+
"created": 1,
167+
}
168+
)
169+
else:
170+
self.responses.append(
171+
{
172+
"object": "chat.completion.chunk",
173+
"choices": [{"delta": {"content": answer}, "index": 0, "finish_reason": None}],
174+
"id": chunk_id,
175+
"model": model,
176+
"created": 1,
177+
}
178+
)
179+
180+
def __aiter__(self):
181+
return self
182+
183+
async def __anext__(self):
184+
if self.responses:
185+
return ChatCompletionChunk.model_validate(self.responses.pop(0))
186+
else:
187+
raise StopAsyncIteration
188+
189+
async def mock_acreate(*args, **kwargs):
190+
messages = kwargs["messages"]
191+
last_question = messages[-1]["content"]
192+
if last_question == "Generate search query for: What is the capital of France?":
193+
answer = "capital of France"
194+
elif last_question == "Generate search query for: Are interest rates high?":
195+
answer = "interest rates"
196+
elif isinstance(last_question, list) and last_question[2].get("image_url"):
197+
answer = "From the provided sources, the impact of interest rates and GDP growth on "
198+
"financial markets can be observed through the line graph. [Financial Market Analysis Report 2023-7.png]"
199+
else:
200+
answer = "The capital of France is Paris. [Benefit_Options-2.pdf]."
201+
if messages[0]["content"].find("Generate 3 very brief follow-up questions") > -1:
202+
answer = "The capital of France is Paris. [Benefit_Options-2.pdf]. <<What is the capital of Spain?>>"
203+
if "stream" in kwargs and kwargs["stream"] is True:
204+
return AsyncChatCompletionIterator(answer)
205+
else:
206+
return ChatCompletion(
207+
object="chat.completion",
208+
choices=[
209+
Choice(
210+
message=ChatCompletionMessage(role="assistant", content=answer), finish_reason="stop", index=0
211+
)
212+
],
213+
id="test-123",
214+
created=0,
215+
model="test-model",
216+
)
217+
218+
def patch():
219+
monkeypatch.setattr(openai.resources.chat.completions.AsyncCompletions, "create", mock_acreate)
220+
221+
return patch
222+
223+
110224
@pytest_asyncio.fixture(scope="function")
111-
async def test_client(monkeypatch, app, mock_default_azure_credential, mock_openai_embedding):
225+
async def test_client(
226+
monkeypatch, app, mock_default_azure_credential, mock_openai_embedding, mock_openai_chatcompletion
227+
):
112228
"""Create a test client."""
113229
mock_openai_embedding()
230+
mock_openai_chatcompletion()
114231
with TestClient(app) as test_client:
115232
yield test_client
116233

0 commit comments

Comments
 (0)