8 | 8 | import pytest_asyncio
9 | 9 | from fastapi.testclient import TestClient
10 | 10 | from openai.types import CreateEmbeddingResponse, Embedding
| 11 | +from openai.types.chat import ChatCompletion, ChatCompletionChunk |
| 12 | +from openai.types.chat.chat_completion import ( |
| 13 | + ChatCompletionMessage, |
| 14 | + Choice, |
| 15 | +) |
11 | 16 | from openai.types.create_embedding_response import Usage
12 | 17 | from sqlalchemy.ext.asyncio import async_sessionmaker
13 | 18 |
@@ -107,10 +112,122 @@ def patch():
107 | 112 | return patch
108 | 113 |
109 | 114 |
| 115 | +@pytest.fixture |
| 116 | +def mock_openai_chatcompletion(monkeypatch): |
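|  | +    # Fake streaming iterator: yields ChatCompletionChunk objects one at a time, like the SDK's async stream. |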
| 117 | + class AsyncChatCompletionIterator: |
| 118 | + def __init__(self, answer: str): |
| 119 | + chunk_id = "test-id" |
| 120 | + model = "gpt-35-turbo" |
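|  | +            # Preamble mirrors a real stream: an empty-choices chunk, then a role-only delta, before any content. |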
| 121 | + self.responses = [ |
| 122 | + {"object": "chat.completion.chunk", "choices": [], "id": chunk_id, "model": model, "created": 1}, |
| 123 | + { |
| 124 | + "object": "chat.completion.chunk", |
| 125 | + "choices": [{"delta": {"role": "assistant"}, "index": 0, "finish_reason": None}], |
| 126 | + "id": chunk_id, |
| 127 | + "model": model, |
| 128 | + "created": 1, |
| 129 | + }, |
| 130 | + ] |
| 131 | +            # Split at "<<" (the follow-up-question marker) to simulate a multi-chunk streamed answer |
| 132 | + if answer.find("<<") > -1: |
| 133 | + parts = answer.split("<<") |
| 134 | + self.responses.append( |
| 135 | + { |
| 136 | + "object": "chat.completion.chunk", |
| 137 | + "choices": [ |
| 138 | + { |
| 139 | + "delta": {"role": "assistant", "content": parts[0] + "<<"}, |
| 140 | + "index": 0, |
| 141 | + "finish_reason": None, |
| 142 | + } |
| 143 | + ], |
| 144 | + "id": chunk_id, |
| 145 | + "model": model, |
| 146 | + "created": 1, |
| 147 | + } |
| 148 | + ) |
| 149 | + self.responses.append( |
| 150 | + { |
| 151 | + "object": "chat.completion.chunk", |
| 152 | + "choices": [ |
| 153 | + {"delta": {"role": "assistant", "content": parts[1]}, "index": 0, "finish_reason": None} |
| 154 | + ], |
| 155 | + "id": chunk_id, |
| 156 | + "model": model, |
| 157 | + "created": 1, |
| 158 | + } |
| 159 | + ) |
| 160 | + self.responses.append( |
| 161 | + { |
| 162 | + "object": "chat.completion.chunk", |
| 163 | + "choices": [{"delta": {"role": None, "content": None}, "index": 0, "finish_reason": "stop"}], |
| 164 | + "id": chunk_id, |
| 165 | + "model": model, |
| 166 | + "created": 1, |
| 167 | + } |
| 168 | + ) |
| 169 | + else: |
| 170 | + self.responses.append( |
| 171 | + { |
| 172 | + "object": "chat.completion.chunk", |
| 173 | + "choices": [{"delta": {"content": answer}, "index": 0, "finish_reason": None}], |
| 174 | + "id": chunk_id, |
| 175 | + "model": model, |
| 176 | + "created": 1, |
| 177 | + } |
| 178 | + ) |
| 179 | + |
| 180 | + def __aiter__(self): |
| 181 | + return self |
| 182 | + |
| 183 | + async def __anext__(self): |
| 184 | + if self.responses: |
| 185 | + return ChatCompletionChunk.model_validate(self.responses.pop(0)) |
| 186 | + else: |
| 187 | + raise StopAsyncIteration |
| 188 | + |
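|  | +    # Pick a canned answer based on the incoming messages so tests get deterministic completions. |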
| 189 | + async def mock_acreate(*args, **kwargs): |
| 190 | + messages = kwargs["messages"] |
| 191 | + last_question = messages[-1]["content"] |
| 192 | + if last_question == "Generate search query for: What is the capital of France?": |
| 193 | + answer = "capital of France" |
| 194 | + elif last_question == "Generate search query for: Are interest rates high?": |
| 195 | + answer = "interest rates" |
| 196 | + elif isinstance(last_question, list) and last_question[2].get("image_url"): |
| 197 | +            answer = ("From the provided sources, the impact of interest rates and GDP growth on " |
| 198 | +            "financial markets can be observed through the line graph. [Financial Market Analysis Report 2023-7.png]") |
| 199 | + else: |
| 200 | + answer = "The capital of France is Paris. [Benefit_Options-2.pdf]." |
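|  | +        # "<<...>>" marks follow-up questions; the streaming iterator above splits the answer at "<<". |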
| 201 | + if messages[0]["content"].find("Generate 3 very brief follow-up questions") > -1: |
| 202 | + answer = "The capital of France is Paris. [Benefit_Options-2.pdf]. <<What is the capital of Spain?>>" |
| 203 | + if "stream" in kwargs and kwargs["stream"] is True: |
| 204 | + return AsyncChatCompletionIterator(answer) |
| 205 | + else: |
| 206 | + return ChatCompletion( |
| 207 | + object="chat.completion", |
| 208 | + choices=[ |
| 209 | + Choice( |
| 210 | + message=ChatCompletionMessage(role="assistant", content=answer), finish_reason="stop", index=0 |
| 211 | + ) |
| 212 | + ], |
| 213 | + id="test-123", |
| 214 | + created=0, |
| 215 | + model="test-model", |
| 216 | + ) |
| 217 | + |
| 218 | + def patch(): |
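|  | +        # Deferred patch: each test opts in by calling the returned function (see test_client below). |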
| 219 | + monkeypatch.setattr(openai.resources.chat.completions.AsyncCompletions, "create", mock_acreate) |
| 220 | + |
| 221 | + return patch |
| 222 | + |
| 223 | + |
110 | 224 | @pytest_asyncio.fixture(scope="function")
111 | | -async def test_client(monkeypatch, app, mock_default_azure_credential, mock_openai_embedding): |
| 225 | +async def test_client( |
| 226 | + monkeypatch, app, mock_default_azure_credential, mock_openai_embedding, mock_openai_chatcompletion |
| 227 | +): |
112 | 228 | """Create a test client."""
113 | 229 | mock_openai_embedding()
| 230 | + mock_openai_chatcompletion() |
114 | 231 | with TestClient(app) as test_client:
115 | 232 | yield test_client
116 | 233 |
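For context, here is a minimal sketch of a test that these fixtures could support. The /chat endpoint, payload shape, and assertion are illustrative assumptions, not taken from this diff:

def test_chat_returns_mock_answer(test_client):
    # Endpoint path and payload shape are assumptions; adjust to the app under test.
    response = test_client.post(
        "/chat",
        json={"messages": [{"role": "user", "content": "What is the capital of France?"}]},
    )
    assert response.status_code == 200
    # mock_acreate's default branch produces this canned answer.
    assert "The capital of France is Paris" in response.text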