Commit b4d1a7e

Run unit tests with real LLM calls (#8486)
* use real LLM for unit tests
* use ollama
* use Llama 3.2 3b
* add verbose option
* split test into a separate job
* remove LLM pulling
* fix option name
* rename env var
1 parent f5a75e5 commit b4d1a7e

5 files changed: +78 −39 lines changed

.github/workflows/run_tests.yml

Lines changed: 39 additions & 1 deletion
Lines changed: 39 additions & 1 deletion

@@ -82,11 +82,49 @@ jobs:
         with:
           args: check --fix-only
       - name: Run tests with pytest
-        run: uv run -p .venv pytest tests/
+        run: uv run -p .venv pytest -vv tests/
       - name: Install optional dependencies
        run: uv sync -p .venv --extra dev --extra test_extras
      - name: Run extra tests
        run: uv run -p .venv pytest tests/ -m extra --extra
+
+  llm_call_test:
+    name: Run Tests with Real LM
+    runs-on: ubuntu-latest
+    services:
+      ollama:
+        image: ollama/ollama:latest
+        ports:
+          - 11434:11434
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: 3.11
+      - name: Install uv with caching
+        uses: astral-sh/setup-uv@v5
+        with:
+          enable-cache: true
+          cache-dependency-glob: |
+            **/pyproject.toml
+            **/uv.lock
+      - name: Create and activate virtual environment
+        run: |
+          uv venv .venv
+          echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
+      - name: Install dependencies
+        run: |
+          uv sync --dev -p .venv --extra dev
+          uv pip list
+      - name: Pull LLM
+        run: |
+          timeout 60 bash -c 'until curl -f http://localhost:11434/api/version; do sleep 2; done'
+          curl -X POST http://localhost:11434/api/pull \
+            -H "Content-Type: application/json" \
+            -d '{"name": "llama3.2:3b"}'
+          echo "LM_FOR_TEST=ollama/llama3.2:3b" >> $GITHUB_ENV
+      - name: Run tests
+        run: uv run -p .venv pytest -m llm_call --llm_call -vv --durations=5 tests/
 
   build_package:
     name: Build Package
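
The new `llm_call_test` job talks to the Ollama service container with plain curl. For local debugging, the same two endpoints (`GET /api/version` and `POST /api/pull`) can be exercised from Python; the sketch below is a hypothetical helper that only mirrors the "Pull LLM" step above and assumes an Ollama server is already listening on localhost:11434.

import json
import time
import urllib.request

OLLAMA_URL = "http://localhost:11434"  # port published by the ollama service container


def wait_for_ollama(timeout_s: int = 60) -> None:
    """Poll /api/version until Ollama answers, mirroring the curl retry loop in the workflow."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(f"{OLLAMA_URL}/api/version") as resp:
                if resp.status == 200:
                    return
        except OSError:
            pass
        time.sleep(2)
    raise TimeoutError("Ollama did not become ready within the timeout")


def pull_model(name: str = "llama3.2:3b") -> None:
    """POST /api/pull with the model name, as the workflow does with curl."""
    req = urllib.request.Request(
        f"{OLLAMA_URL}/api/pull",
        data=json.dumps({"name": name}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    # Reading the response to the end blocks until the pull has finished.
    with urllib.request.urlopen(req) as resp:
        resp.read()


if __name__ == "__main__":
    wait_for_ollama()
    pull_model()
    # Then, as in the workflow: export LM_FOR_TEST=ollama/llama3.2:3b and run
    #   uv run -p .venv pytest -m llm_call --llm_call -vv --durations=5 tests/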

tests/conftest.py

Lines changed: 10 additions & 1 deletion
@@ -1,10 +1,11 @@
 import copy
+import os
 
 import pytest
 
 from tests.test_utils.server import litellm_test_server, read_litellm_test_server_request_logs  # noqa: F401
 
-SKIP_DEFAULT_FLAGS = ["reliability", "extra"]
+SKIP_DEFAULT_FLAGS = ["reliability", "extra", "llm_call"]
 
 
 @pytest.fixture(autouse=True)
@@ -49,3 +50,11 @@ def pytest_collection_modifyitems(config, items):
         for item in items:
             if flag in item.keywords:
                 item.add_marker(skip_mark)
+
+
+@pytest.fixture
+def lm_for_test():
+    model = os.environ.get("LM_FOR_TEST", None)
+    if model is None:
+        pytest.skip("LM_FOR_TEST is not set in the environment variables")
+    return model
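
Taken together, the marker and the fixture gate the real-LM tests twice: `llm_call` is added to `SKIP_DEFAULT_FLAGS`, so these tests are skipped unless explicitly selected, and `lm_for_test` skips again if `LM_FOR_TEST` is unset. A minimal sketch of how a test opts in, following the pattern used in the test files below (the test name here is hypothetical):

import pytest

import dspy


@pytest.mark.llm_call  # skipped by default because "llm_call" is in SKIP_DEFAULT_FLAGS
def test_example_with_real_lm(lm_for_test):  # fixture skips unless LM_FOR_TEST is set
    # LM_FOR_TEST holds a LiteLLM-style model id, e.g. "ollama/llama3.2:3b" in CI.
    with dspy.context(lm=dspy.LM(lm_for_test, cache=False)):
        prediction = dspy.Predict("question -> answer")(question="What is the capital of France?")
    assert prediction.answer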

tests/primitives/test_base_module.py

Lines changed: 15 additions & 14 deletions
@@ -230,30 +230,30 @@ def emit(self, record):
     logger.removeHandler(handler)
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Skip the test if OPENAI_API_KEY is not set.")
-def test_single_module_call_with_usage_tracker():
-    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True)
+@pytest.mark.llm_call
+def test_single_module_call_with_usage_tracker(lm_for_test):
+    dspy.settings.configure(lm=dspy.LM(lm_for_test, cache=False), track_usage=True)
 
     predict = dspy.ChainOfThought("question -> answer")
     output = predict(question="What is the capital of France?")
 
     lm_usage = output.get_lm_usage()
     assert len(lm_usage) == 1
-    assert lm_usage["openai/gpt-4o-mini"]["prompt_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["completion_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["total_tokens"] > 0
+    assert lm_usage[lm_for_test]["prompt_tokens"] > 0
+    assert lm_usage[lm_for_test]["completion_tokens"] > 0
+    assert lm_usage[lm_for_test]["total_tokens"] > 0
 
     # Test no usage being tracked when cache is enabled
-    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=True), track_usage=True)
+    dspy.settings.configure(lm=dspy.LM(lm_for_test, cache=True), track_usage=True)
     for _ in range(2):
         output = predict(question="What is the capital of France?")
 
     assert len(output.get_lm_usage()) == 0
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Skip the test if OPENAI_API_KEY is not set.")
-def test_multi_module_call_with_usage_tracker():
-    dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), track_usage=True)
+@pytest.mark.llm_call
+def test_multi_module_call_with_usage_tracker(lm_for_test):
+    dspy.settings.configure(lm=dspy.LM(lm_for_test, cache=False), track_usage=True)
 
     class MyProgram(dspy.Module):
         def __init__(self):
@@ -270,12 +270,13 @@ def __call__(self, question: str) -> str:
 
     lm_usage = output.get_lm_usage()
     assert len(lm_usage) == 1
-    assert lm_usage["openai/gpt-4o-mini"]["prompt_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["prompt_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["completion_tokens"] > 0
-    assert lm_usage["openai/gpt-4o-mini"]["total_tokens"] > 0
+    assert lm_usage[lm_for_test]["prompt_tokens"] > 0
+    assert lm_usage[lm_for_test]["prompt_tokens"] > 0
+    assert lm_usage[lm_for_test]["completion_tokens"] > 0
+    assert lm_usage[lm_for_test]["total_tokens"] > 0
 
 
+# TODO: prepare second model for testing this unit test in ci
 @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Skip the test if OPENAI_API_KEY is not set.")
 def test_usage_tracker_in_parallel():
     class MyProgram(dspy.Module):

tests/streaming/test_streaming.py

Lines changed: 9 additions & 10 deletions
@@ -1,5 +1,4 @@
 import asyncio
-import os
 import time
 from unittest import mock
 from unittest.mock import AsyncMock
@@ -131,9 +130,9 @@ def module_start_status_message(self, instance, inputs):
     assert status_messages[2].message == "Predict starting!"
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OpenAI API key not found in environment variables")
+@pytest.mark.llm_call
 @pytest.mark.anyio
-async def test_stream_listener_chat_adapter():
+async def test_stream_listener_chat_adapter(lm_for_test):
     class MyProgram(dspy.Module):
         def __init__(self):
             self.predict1 = dspy.Predict("question->answer")
@@ -154,7 +153,7 @@ def __call__(self, x: str, **kwargs):
         include_final_prediction_in_output_stream=False,
     )
     # Turn off the cache to ensure the stream is produced.
-    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
+    with dspy.context(lm=dspy.LM(lm_for_test, cache=False)):
         output = program(x="why did a chicken cross the kitchen?")
     all_chunks = []
     async for value in output:
@@ -194,9 +193,9 @@ async def acall(self, x: str):
     assert status_messages[1].message == "Tool calling finished! Querying the LLM with tool calling results..."
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OpenAI API key not found in environment variables")
+@pytest.mark.llm_call
 @pytest.mark.anyio
-async def test_stream_listener_json_adapter():
+async def test_stream_listener_json_adapter(lm_for_test):
     class MyProgram(dspy.Module):
         def __init__(self):
             self.predict1 = dspy.Predict("question->answer")
@@ -217,7 +216,7 @@ def __call__(self, x: str, **kwargs):
         include_final_prediction_in_output_stream=False,
     )
     # Turn off the cache to ensure the stream is produced.
-    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()):
+    with dspy.context(lm=dspy.LM(lm_for_test, cache=False), adapter=dspy.JSONAdapter()):
         output = program(x="why did a chicken cross the kitchen?")
     all_chunks = []
     async for value in output:
@@ -261,8 +260,8 @@ async def gpt_4o_mini_stream(*args, **kwargs):
     assert all_chunks[0].chunk == "How are you doing?"
 
 
-@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OpenAI API key not found in environment variables")
-def test_sync_streaming():
+@pytest.mark.llm_call
+def test_sync_streaming(lm_for_test):
     class MyProgram(dspy.Module):
         def __init__(self):
             self.predict1 = dspy.Predict("question->answer")
@@ -284,7 +283,7 @@ def __call__(self, x: str, **kwargs):
         async_streaming=False,
     )
     # Turn off the cache to ensure the stream is produced.
-    with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False)):
+    with dspy.context(lm=dspy.LM(lm_for_test, cache=False)):
         output = program(x="why did a chicken cross the kitchen?")
     all_chunks = []
     for value in output:

tests/utils/test_usage_tracker.py

Lines changed: 5 additions & 13 deletions
@@ -1,7 +1,3 @@
-import os
-
-import pytest
-
 import dspy
 from dspy.utils.usage_tracker import UsageTracker, track_usage
 
@@ -137,12 +133,8 @@ def test_track_usage_with_multiple_models():
     assert total_usage["gpt-3.5-turbo"]["total_tokens"] == 900
 
 
-@pytest.mark.skipif(
-    not os.getenv("OPENAI_API_KEY"),
-    reason="Skip the test if OPENAI_API_KEY is not set.",
-)
-def test_track_usage_context_manager():
-    lm = dspy.LM("openai/gpt-4o-mini", cache=False)
+def test_track_usage_context_manager(lm_for_test):
+    lm = dspy.LM(lm_for_test, cache=False)
     dspy.settings.configure(lm=lm)
 
     predict = dspy.ChainOfThought("question -> answer")
@@ -151,12 +143,12 @@ def test_track_usage_context_manager():
         predict(question="What is the capital of Italy?")
 
     assert len(tracker.usage_data) > 0
-    assert len(tracker.usage_data["openai/gpt-4o-mini"]) == 2
+    assert len(tracker.usage_data[lm_for_test]) == 2
 
     total_usage = tracker.get_total_tokens()
-    assert "openai/gpt-4o-mini" in total_usage
+    assert lm_for_test in total_usage
     assert len(total_usage.keys()) == 1
-    assert isinstance(total_usage["openai/gpt-4o-mini"], dict)
+    assert isinstance(total_usage[lm_for_test], dict)
 
 
 def test_merge_usage_entries_with_new_keys():
