|
1 | 1 | from typing import Literal
|
2 | 2 | from unittest import mock
|
3 | 3 |
|
| 4 | +import pydantic |
4 | 5 | import pytest
|
5 | 6 |
|
6 | 7 | import dspy
|
@@ -97,6 +98,106 @@ async def test_chat_adapter_async_call():
|
97 | 98 | assert result == [{"answer": "Paris"}]
|
98 | 99 |
|
99 | 100 |
|
| 101 | +def test_chat_adapter_with_pydantic_models(): |
| 102 | + """ |
| 103 | + This test verifies that ChatAdapter can handle different input and output field types, both basic and nested. |
| 104 | + """ |
| 105 | + |
| 106 | + class DogClass(pydantic.BaseModel): |
| 107 | + dog_breeds: list[str] = pydantic.Field(description="List of the breeds of dogs") |
| 108 | + num_dogs: int = pydantic.Field(description="Number of dogs the owner has", ge=0, le=10) |
| 109 | + |
| 110 | + class PetOwner(pydantic.BaseModel): |
| 111 | + name: str = pydantic.Field(description="Name of the owner") |
| 112 | + num_pets: int = pydantic.Field(description="Amount of pets the owner has", ge=0, le=100) |
| 113 | + dogs: DogClass = pydantic.Field(description="Nested Pydantic class with dog specific information ") |
| 114 | + |
| 115 | + class Answer(pydantic.BaseModel): |
| 116 | + result: str |
| 117 | + analysis: str |
| 118 | + |
| 119 | + class TestSignature(dspy.Signature): |
| 120 | + owner: PetOwner = dspy.InputField() |
| 121 | + question: str = dspy.InputField() |
| 122 | + output: Answer = dspy.OutputField() |
| 123 | + |
| 124 | + dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.ChatAdapter()) |
| 125 | + program = dspy.Predict(TestSignature) |
| 126 | + |
| 127 | + with mock.patch("litellm.completion") as mock_completion: |
| 128 | + program( |
| 129 | + owner=PetOwner(name="John", num_pets=5, dogs=DogClass(dog_breeds=["labrador", "chihuahua"], num_dogs=2)), |
| 130 | + question="How many non-dog pets does John have?", |
| 131 | + ) |
| 132 | + |
| 133 | + mock_completion.assert_called_once() |
| 134 | + _, call_kwargs = mock_completion.call_args |
| 135 | + |
| 136 | + system_content = call_kwargs["messages"][0]["content"] |
| 137 | + user_content = call_kwargs["messages"][1]["content"] |
| 138 | + assert "1. `owner` (PetOwner)" in system_content |
| 139 | + assert "2. `question` (str)" in system_content |
| 140 | + assert "1. `output` (Answer)" in system_content |
| 141 | + |
| 142 | + assert "name" in user_content |
| 143 | + assert "num_pets" in user_content |
| 144 | + assert "dogs" in user_content |
| 145 | + assert "dog_breeds" in user_content |
| 146 | + assert "num_dogs" in user_content |
| 147 | + assert "How many non-dog pets does John have?" in user_content |
| 148 | + |
| 149 | + |
| 150 | +def test_chat_adapter_signature_information(): |
| 151 | + """ |
| 152 | + This test ensures that the signature information sent to the LM follows an expected format. |
| 153 | + """ |
| 154 | + |
| 155 | + class TestSignature(dspy.Signature): |
| 156 | + input1: str = dspy.InputField(desc="String Input") |
| 157 | + input2: int = dspy.InputField(desc="Integer Input") |
| 158 | + output: str = dspy.OutputField(desc="String Output") |
| 159 | + |
| 160 | + dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.ChatAdapter()) |
| 161 | + program = dspy.Predict(TestSignature) |
| 162 | + |
| 163 | + with mock.patch("litellm.completion") as mock_completion: |
| 164 | + program(input1="Test", input2=11) |
| 165 | + |
| 166 | + mock_completion.assert_called_once() |
| 167 | + _, call_kwargs = mock_completion.call_args |
| 168 | + |
| 169 | + assert len(call_kwargs["messages"]) == 2 |
| 170 | + assert call_kwargs["messages"][0]["role"] == "system" |
| 171 | + assert call_kwargs["messages"][1]["role"] == "user" |
| 172 | + |
| 173 | + system_content = call_kwargs["messages"][0]["content"] |
| 174 | + user_content = call_kwargs["messages"][1]["content"] |
| 175 | + |
| 176 | + assert "1. `input1` (str)" in system_content |
| 177 | + assert "2. `input2` (int)" in system_content |
| 178 | + assert "1. `output` (str)" in system_content |
| 179 | + assert "[[ ## input1 ## ]]\n{input1}" in system_content |
| 180 | + assert "[[ ## input2 ## ]]\n{input2}" in system_content |
| 181 | + assert "[[ ## output ## ]]\n{output}" in system_content |
| 182 | + assert "[[ ## completed ## ]]" in system_content |
| 183 | + |
| 184 | + assert "[[ ## input1 ## ]]" in user_content |
| 185 | + assert "[[ ## input2 ## ]]" in user_content |
| 186 | + assert "[[ ## output ## ]]" in user_content |
| 187 | + assert "[[ ## completed ## ]]" in user_content |
| 188 | + |
| 189 | + |
| 190 | +def test_chat_adapter_exception_raised_on_failure(): |
| 191 | + """ |
| 192 | + This test ensures that on an error, ChatAdapter raises an explicit exception. |
| 193 | + """ |
| 194 | + signature = dspy.make_signature("question->answer") |
| 195 | + adapter = dspy.ChatAdapter() |
| 196 | + invalid_completion = "{'output':'mismatched value'}" |
| 197 | + with pytest.raises(dspy.utils.exceptions.AdapterParseError, match="Adapter ChatAdapter failed to parse*"): |
| 198 | + adapter.parse(signature, invalid_completion) |
| 199 | + |
| 200 | + |
100 | 201 | def test_chat_adapter_formats_image():
|
101 | 202 | # Test basic image formatting
|
102 | 203 | image = dspy.Image(url="https://example.com/image.jpg")
|
|
0 commit comments