Skip to content

Commit 1699000

Browse files
Increased test coverage for ChatAdapter (#8168)
* test(ChatAdapter): Increased test coverage for ChatAdapter * test(ChatAdapter): Modified input for pydantic test * test(ChatAdapter): Slight change to exception test * test(ChatAdapter): Modified pydantic test case to a more user-friendly format * test(ChatAdapter): removed some unneeded print statements * test(ChatAdapter): Added descriptive comments * fix test --------- Co-authored-by: chenmoneygithub <chen.qian@databricks.com>
1 parent 39387cb commit 1699000

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed

tests/adapters/test_chat_adapter.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Literal
22
from unittest import mock
33

4+
import pydantic
45
import pytest
56

67
import dspy
@@ -97,6 +98,106 @@ async def test_chat_adapter_async_call():
9798
assert result == [{"answer": "Paris"}]
9899

99100

101+
def test_chat_adapter_with_pydantic_models():
102+
"""
103+
This test verifies that ChatAdapter can handle different input and output field types, both basic and nested.
104+
"""
105+
106+
class DogClass(pydantic.BaseModel):
107+
dog_breeds: list[str] = pydantic.Field(description="List of the breeds of dogs")
108+
num_dogs: int = pydantic.Field(description="Number of dogs the owner has", ge=0, le=10)
109+
110+
class PetOwner(pydantic.BaseModel):
111+
name: str = pydantic.Field(description="Name of the owner")
112+
num_pets: int = pydantic.Field(description="Amount of pets the owner has", ge=0, le=100)
113+
dogs: DogClass = pydantic.Field(description="Nested Pydantic class with dog specific information ")
114+
115+
class Answer(pydantic.BaseModel):
116+
result: str
117+
analysis: str
118+
119+
class TestSignature(dspy.Signature):
120+
owner: PetOwner = dspy.InputField()
121+
question: str = dspy.InputField()
122+
output: Answer = dspy.OutputField()
123+
124+
dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.ChatAdapter())
125+
program = dspy.Predict(TestSignature)
126+
127+
with mock.patch("litellm.completion") as mock_completion:
128+
program(
129+
owner=PetOwner(name="John", num_pets=5, dogs=DogClass(dog_breeds=["labrador", "chihuahua"], num_dogs=2)),
130+
question="How many non-dog pets does John have?",
131+
)
132+
133+
mock_completion.assert_called_once()
134+
_, call_kwargs = mock_completion.call_args
135+
136+
system_content = call_kwargs["messages"][0]["content"]
137+
user_content = call_kwargs["messages"][1]["content"]
138+
assert "1. `owner` (PetOwner)" in system_content
139+
assert "2. `question` (str)" in system_content
140+
assert "1. `output` (Answer)" in system_content
141+
142+
assert "name" in user_content
143+
assert "num_pets" in user_content
144+
assert "dogs" in user_content
145+
assert "dog_breeds" in user_content
146+
assert "num_dogs" in user_content
147+
assert "How many non-dog pets does John have?" in user_content
148+
149+
150+
def test_chat_adapter_signature_information():
151+
"""
152+
This test ensures that the signature information sent to the LM follows an expected format.
153+
"""
154+
155+
class TestSignature(dspy.Signature):
156+
input1: str = dspy.InputField(desc="String Input")
157+
input2: int = dspy.InputField(desc="Integer Input")
158+
output: str = dspy.OutputField(desc="String Output")
159+
160+
dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.ChatAdapter())
161+
program = dspy.Predict(TestSignature)
162+
163+
with mock.patch("litellm.completion") as mock_completion:
164+
program(input1="Test", input2=11)
165+
166+
mock_completion.assert_called_once()
167+
_, call_kwargs = mock_completion.call_args
168+
169+
assert len(call_kwargs["messages"]) == 2
170+
assert call_kwargs["messages"][0]["role"] == "system"
171+
assert call_kwargs["messages"][1]["role"] == "user"
172+
173+
system_content = call_kwargs["messages"][0]["content"]
174+
user_content = call_kwargs["messages"][1]["content"]
175+
176+
assert "1. `input1` (str)" in system_content
177+
assert "2. `input2` (int)" in system_content
178+
assert "1. `output` (str)" in system_content
179+
assert "[[ ## input1 ## ]]\n{input1}" in system_content
180+
assert "[[ ## input2 ## ]]\n{input2}" in system_content
181+
assert "[[ ## output ## ]]\n{output}" in system_content
182+
assert "[[ ## completed ## ]]" in system_content
183+
184+
assert "[[ ## input1 ## ]]" in user_content
185+
assert "[[ ## input2 ## ]]" in user_content
186+
assert "[[ ## output ## ]]" in user_content
187+
assert "[[ ## completed ## ]]" in user_content
188+
189+
190+
def test_chat_adapter_exception_raised_on_failure():
191+
"""
192+
This test ensures that on an error, ChatAdapter raises an explicit exception.
193+
"""
194+
signature = dspy.make_signature("question->answer")
195+
adapter = dspy.ChatAdapter()
196+
invalid_completion = "{'output':'mismatched value'}"
197+
with pytest.raises(dspy.utils.exceptions.AdapterParseError, match="Adapter ChatAdapter failed to parse*"):
198+
adapter.parse(signature, invalid_completion)
199+
200+
100201
def test_chat_adapter_formats_image():
101202
# Test basic image formatting
102203
image = dspy.Image(url="https://example.com/image.jpg")

0 commit comments

Comments
 (0)