
Commit d3a8256

Style fixing for tests/ (#8225)
* enable style check
* update rules
* allow wildcard imports
* fix tests style
1 parent 0dec440 commit d3a8256
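
The "allow wildcard imports" bullet most plausibly refers to relaxing the lint rules that flag "from module import *" (F403/F405 in flake8/ruff terms — an assumption, since the config change itself is not part of this diff). A minimal sketch of the pattern being permitted:

from math import *  # wildcard import: normally flagged (F403, assumed rule), allowed by the updated config

print(sqrt(16.0))  # prints 4.0; sqrt arrives via the wildcard, which is what F405 would flag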

10 files changed: +93 −78 lines changed

tests/adapters/test_json_adapter.py

Lines changed: 4 additions & 5 deletions
@@ -2,8 +2,7 @@
 
 import pydantic
 import pytest
-from pydantic import create_model
-from litellm.utils import ModelResponse, Message, Choices
+from litellm.utils import Choices, Message, ModelResponse
 
 import dspy
 
@@ -109,7 +108,7 @@ async def test_json_adapter_async_call():
 
 
 def test_json_adapter_on_pydantic_model():
-    from litellm.utils import ModelResponse, Message, Choices
+    from litellm.utils import Choices, Message, ModelResponse
 
     class User(pydantic.BaseModel):
         id: int
@@ -170,7 +169,7 @@ class TestSignature(dspy.Signature):
     assert expected_input_structure in content
 
     # Assert that system prompt includes output formatting structure
-    expected_output_structure = (  # noqa: Q000
+    expected_output_structure = (
         "Outputs will be a JSON object with the following fields.\n\n{\n  "
         '"answer": "{answer}  # note: the value you produce must adhere to the JSON schema: '
         '{\\"type\\": \\"object\\", \\"properties\\": {\\"analysis\\": {\\"type\\": \\"string\\", \\"title\\": '
@@ -184,7 +183,7 @@ class TestSignature(dspy.Signature):
     assert user_message_content is not None
 
     # Assert that the user input data is formatted correctly
-    expected_input_data = (  # noqa: Q000
+    expected_input_data = (
         '[[ ## user ## ]]\n{"id": 5, "name": "name_test", "email": "email_test"}\n\n[[ ## question ## ]]\n'
         "What is the capital of France?\n\n"
     )
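
The import fixes above follow the standard isort-style layout (assumed to be enforced here by a rule such as ruff's I001, since only its effects are visible): imports grouped as stdlib, third-party, then first-party, with names inside a single "from" import alphabetized; the dropped "from pydantic import create_model" appears to have been simply unused. A minimal sketch of the layout:

import os  # group 1: standard library

import pytest  # group 2: third-party packages
from litellm.utils import Choices, Message, ModelResponse  # names alphabetized within the import

import dspy  # group 3: first-party / project code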

tests/adapters/test_two_step_adapter.py

Lines changed: 24 additions & 19 deletions
@@ -1,4 +1,5 @@
 from unittest import mock
+
 import pytest
 
 import dspy
@@ -9,7 +10,7 @@ class TestSignature(dspy.Signature):
         question: str = dspy.InputField(desc="The math question to solve")
         solution: str = dspy.OutputField(desc="Step by step solution")
         answer: float = dspy.OutputField(desc="The final numerical answer")
-
+
     program = dspy.Predict(TestSignature)
 
     mock_main_lm = mock.MagicMock(spec=dspy.LM)
@@ -30,7 +31,7 @@ class TestSignature(dspy.Signature):
     dspy.configure(lm=mock_main_lm, adapter=dspy.TwoStepAdapter(extraction_model=mock_extraction_lm))
 
     result = program(question="What is 5 + 7?")
-
+
     assert result.answer == 12
 
     # main LM call
@@ -59,22 +60,23 @@ class TestSignature(dspy.Signature):
     # assert first message
     assert call_kwargs["messages"][0]["role"] == "system"
     content = call_kwargs["messages"][0]["content"]
-    assert "`text` (str)" in content
-    assert "`solution` (str)" in content
-    assert "`answer` (float)" in content
+    assert "`text` (str)" in content
+    assert "`solution` (str)" in content
+    assert "`answer` (float)" in content
 
     # assert second message
     assert call_kwargs["messages"][1]["role"] == "user"
     content = call_kwargs["messages"][1]["content"]
     assert "text from main LM" in content
 
+
 @pytest.mark.asyncio
 async def test_two_step_adapter_async_call():
     class TestSignature(dspy.Signature):
         question: str = dspy.InputField(desc="The math question to solve")
         solution: str = dspy.OutputField(desc="Step by step solution")
         answer: float = dspy.OutputField(desc="The final numerical answer")
-
+
     program = dspy.Predict(TestSignature)
 
     mock_main_lm = mock.MagicMock(spec=dspy.LM)
@@ -95,7 +97,7 @@ class TestSignature(dspy.Signature):
     dspy.configure(lm=mock_main_lm, adapter=dspy.TwoStepAdapter(extraction_model=mock_extraction_lm))
 
     result = await program.acall(question="What is 5 + 7?")
-
+
     assert result.answer == 12
 
     # main LM call
@@ -124,37 +126,40 @@ class TestSignature(dspy.Signature):
     # assert first message
     assert call_kwargs["messages"][0]["role"] == "system"
     content = call_kwargs["messages"][0]["content"]
-    assert "`text` (str)" in content
-    assert "`solution` (str)" in content
-    assert "`answer` (float)" in content
+    assert "`text` (str)" in content
+    assert "`solution` (str)" in content
+    assert "`answer` (float)" in content
 
     # assert second message
     assert call_kwargs["messages"][1]["role"] == "user"
     content = call_kwargs["messages"][1]["content"]
     assert "text from main LM" in content
 
+
 def test_two_step_adapter_parse():
     class ComplexSignature(dspy.Signature):
         input_text: str = dspy.InputField()
         tags: list[str] = dspy.OutputField(desc="List of relevant tags")
         confidence: float = dspy.OutputField(desc="Confidence score")
-
+
     first_response = "main LM response"
-
+
     mock_extraction_lm = mock.MagicMock(spec=dspy.LM)
-    mock_extraction_lm.return_value = ["""
+    mock_extraction_lm.return_value = [
+        """
     {
         "tags": ["AI", "deep learning", "neural networks"],
        "confidence": 0.87
-    }
-    """]
+    }
+    """
+    ]
     mock_extraction_lm.kwargs = {"temperature": 1.0}
     mock_extraction_lm.model = "openai/gpt-4o"
     adapter = dspy.TwoStepAdapter(mock_extraction_lm)
     dspy.configure(adapter=adapter, lm=mock_extraction_lm)
 
     result = adapter.parse(ComplexSignature, first_response)
-
+
     assert result["tags"] == ["AI", "deep learning", "neural networks"]
     assert result["confidence"] == 0.87
@@ -163,7 +168,7 @@ def test_two_step_adapter_parse_errors():
     class TestSignature(dspy.Signature):
         question: str = dspy.InputField()
         answer: str = dspy.OutputField()
-
+
     first_response = "main LM response"
 
     mock_extraction_lm = mock.MagicMock(spec=dspy.LM)
@@ -172,6 +177,6 @@ class TestSignature(dspy.Signature):
     mock_extraction_lm.model = "openai/gpt-4o"
 
     adapter = dspy.TwoStepAdapter(mock_extraction_lm)
-
+
     with pytest.raises(ValueError, match="Failed to parse response"):
-        adapter.parse(TestSignature, first_response)
+        adapter.parse(TestSignature, first_response)
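
Several - / + pairs above render identically because the hunks are whitespace-only: trailing spaces on blank lines are stripped (W291/W293 in pycodestyle terms — an assumption) and top-level definitions get the PEP 8 two-blank-line separation (E302). A minimal sketch of that spacing rule:

def test_first():
    assert 1 + 1 == 2


def test_second():  # exactly two blank lines above, per PEP 8 (E302)
    assert "a" in "abc"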

tests/caching/test_litellm_cache.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@
 import pytest
 
 import dspy
-from tests.test_utils.server import litellm_test_server, read_litellm_test_server_request_logs
+from tests.test_utils.server import read_litellm_test_server_request_logs
 
 
 @pytest.fixture()
@@ -158,8 +158,8 @@ class NonJsonSerializable:
 
 
 def test_lms_called_expected_number_of_times_for_cache_key_generation_failures():
-    with pytest.raises(Exception), patch("litellm.completion") as mock_completion:
-        mock_completion.side_effect = Exception("Mocked exception")
+    with pytest.raises(RuntimeError), patch("litellm.completion") as mock_completion:
+        mock_completion.side_effect = RuntimeError("Mocked exception")
         lm = dspy.LM(
             model="openai/dspy-test-model",
             api_base="fakebase",
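
The Exception → RuntimeError change narrows the expected exception type. pytest.raises(Exception) also succeeds when an unrelated error escapes the block, which can silently mask bugs; linters commonly flag it (flake8-bugbear's B017 is the usual rule, though whether this repo enables it is an assumption). A sketch of the safer pattern:

import pytest


def flaky_operation():
    raise RuntimeError("mocked failure")


def test_flaky_operation_raises():
    # Expecting the specific type means a stray TypeError or NameError inside
    # the block fails the test instead of accidentally satisfying it.
    with pytest.raises(RuntimeError, match="mocked failure"):
        flaky_operation()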

tests/clients/test_cache.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import os
2-
import shutil
1+
from dataclasses import dataclass
32
from unittest.mock import patch
43

54
import pydantic
@@ -8,7 +7,6 @@
87
from diskcache import FanoutCache
98

109
from dspy.clients.cache import Cache
11-
from dataclasses import dataclass
1210

1311

1412
@dataclass
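
Here os and shutil are deleted as apparently unused imports (F401 in flake8/ruff terms — an assumption), while the dataclass import moves up into the stdlib group where the sorter expects it. The resulting shape:

from dataclasses import dataclass  # stdlib imports belong in the first group


@dataclass
class CacheEntry:  # hypothetical class, only to show the import being used
    key: str
    value: str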

tests/clients/test_databricks.py

Lines changed: 5 additions & 5 deletions
@@ -4,15 +4,15 @@
 manual testing.
 """
 
+import pytest
+
+import dspy
 from dspy.clients.databricks import (
     DatabricksProvider,
-    _create_directory_in_databricks_unity_catalog,
     TrainingJobDatabricks,
+    _create_directory_in_databricks_unity_catalog,
 )
 
-import pytest
-import dspy
-
 try:
     from databricks.sdk import WorkspaceClient
 
@@ -66,7 +66,7 @@ def test_create_finetuning_job():
 
     job = TrainingJobDatabricks()
 
-    finetuned_model = DatabricksProvider.finetune(
+    DatabricksProvider.finetune(
         job=job,
         model="meta-llama/Llama-3.2-1B",
         train_data=fake_training_data,
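
Dropping the finetuned_model = binding is the standard fix for a variable that is assigned but never read (F841 — again an assumed rule name). A minimal sketch:

def start_job():
    return {"status": "queued"}


def test_start_job_runs():
    # Before: job_info = start_job() with job_info never read — flagged as unused.
    # After: call the function for its side effects without binding the result.
    start_job()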

tests/clients/test_embedding.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
import pytest
2-
from unittest.mock import Mock, patch
1+
from unittest.mock import patch
2+
33
import numpy as np
4+
import pytest
45

5-
from dspy.clients.embedding import Embedder
66
import dspy
7+
from dspy.clients.embedding import Embedder
78

89

910
# Mock response format similar to litellm's embedding response.

tests/clients/test_inspect_global_history.py

Lines changed: 16 additions & 10 deletions
@@ -1,23 +1,26 @@
 import pytest
-from dspy.utils.dummies import DummyLM
-from dspy.clients.base_lm import GLOBAL_HISTORY
+
 import dspy
+from dspy.clients.base_lm import GLOBAL_HISTORY
+from dspy.utils.dummies import DummyLM
+
 
 @pytest.fixture(autouse=True)
 def clear_history():
     GLOBAL_HISTORY.clear()
     yield
 
+
 def test_inspect_history_basic(capsys):
     # Configure a DummyLM with some predefined responses
     lm = DummyLM([{"response": "Hello"}, {"response": "How are you?"}])
     dspy.settings.configure(lm=lm)
-
+
     # Make some calls to generate history
     predictor = dspy.Predict("query: str -> response: str")
     predictor(query="Hi")
     predictor(query="What's up?")
-
+
     # Test inspecting all history
     history = GLOBAL_HISTORY
     print(capsys)
@@ -26,45 +29,48 @@ def test_inspect_history_basic(capsys):
     assert all(isinstance(entry, dict) for entry in history)
     assert all("messages" in entry for entry in history)
 
+
 def test_inspect_history_with_n(capsys):
     """Test that inspect_history works with n
 
     Random failures in this test most likely mean you are printing messages somewhere
     """
     lm = DummyLM([{"response": "One"}, {"response": "Two"}, {"response": "Three"}])
     dspy.settings.configure(lm=lm)
-
+
     # Generate some history
     predictor = dspy.Predict("query: str -> response: str")
     predictor(query="First")
     predictor(query="Second")
     predictor(query="Third")
-
+
     dspy.inspect_history(n=2)
     # Test getting last 2 entries
     out, err = capsys.readouterr()
-    assert not "First" in out
+    assert "First" not in out
     assert "Second" in out
     assert "Third" in out
 
+
 def test_inspect_empty_history(capsys):
     # Configure fresh DummyLM
     lm = DummyLM([])
     dspy.settings.configure(lm=lm)
-
+
     # Test inspecting empty history
     dspy.inspect_history()
     history = GLOBAL_HISTORY
     assert len(history) == 0
     assert isinstance(history, list)
 
+
 def test_inspect_history_n_larger_than_history(capsys):
     lm = DummyLM([{"response": "First"}, {"response": "Second"}])
     dspy.settings.configure(lm=lm)
-
+
     predictor = dspy.Predict("query: str -> response: str")
     predictor(query="Query 1")
     predictor(query="Query 2")
-
+
     # Request more entries than exist
     dspy.inspect_history(n=5)
     history = GLOBAL_HISTORY
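
The assert fix above is pycodestyle's E713 (assuming that is the rule that fired). Both spellings compute the same thing, because not x in y parses as not (x in y); the dedicated operator simply reads as one unit:

out = "Second\nThird"

assert not ("First" in out)  # what the old spelling means after parsing
assert "First" not in out    # equivalent, and the PEP 8-preferred form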

tests/clients/test_lm.py

Lines changed: 8 additions & 6 deletions
@@ -1,16 +1,15 @@
+import time
 from unittest import mock
 from unittest.mock import patch
 
-import time
 import litellm
 import pydantic
 import pytest
+from litellm.utils import Choices, Message, ModelResponse
 from openai import RateLimitError
 
 import dspy
 from dspy.utils.usage_tracker import track_usage
-from tests.test_utils.server import litellm_test_server, read_litellm_test_server_request_logs
-from litellm.utils import ModelResponse, Message, Choices
 
 
 def test_chat_lms_can_be_queried(litellm_test_server):
@@ -127,7 +126,10 @@ def test_lm_calls_support_callables(litellm_test_server):
     api_base, _ = litellm_test_server
 
     with mock.patch("litellm.completion", autospec=True, wraps=litellm.completion) as spy_completion:
-        azure_ad_token_provider = lambda *args, **kwargs: None
+
+        def azure_ad_token_provider(*args, **kwargs):
+            return None
+
         lm_with_callable = dspy.LM(
             model="openai/dspy-test-model",
             api_base=api_base,
@@ -315,12 +317,12 @@ def test_logprobs_included_when_requested():
             },
         ]
     }
-    assert mock_completion.call_args.kwargs["logprobs"] == True
+    assert mock_completion.call_args.kwargs["logprobs"]
 
 
 @pytest.mark.asyncio
 async def test_async_lm_call():
-    from litellm.utils import ModelResponse, Message, Choices
+    from litellm.utils import Choices, Message, ModelResponse
 
     mock_response = ModelResponse(choices=[Choices(message=Message(content="answer"))], model="openai/gpt-4o-mini")
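
The last two hunks in this file show two more routine rules (E712 and E731 in pycodestyle terms — an assumption about the configured set): comparing a value to True with == is redundant inside an assert, and binding a lambda to a name loses the function's real name in tracebacks. A sketch of both fixes:

kwargs = {"logprobs": True}

# E712: assert the value directly instead of comparing it to True.
assert kwargs["logprobs"]


# E731: a def gives the callable a real __name__, unlike f = lambda: None.
def azure_ad_token_provider(*args, **kwargs):
    return None


assert azure_ad_token_provider.__name__ == "azure_ad_token_provider"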
