|
| 1 | +import pydantic |
| 2 | +import pytest |
| 3 | + |
| 4 | +import dspy |
| 5 | +from dspy.adapters.chat_adapter import FieldInfoWithName |
| 6 | +from dspy.adapters.xml_adapter import XMLAdapter |
| 7 | + |
| 8 | + |
| 9 | +def test_xml_adapter_format_and_parse_basic(): |
| 10 | + class TestSignature(dspy.Signature): |
| 11 | + question: str = dspy.InputField() |
| 12 | + answer: str = dspy.OutputField() |
| 13 | + |
| 14 | + adapter = XMLAdapter() |
| 15 | + # Format output fields as XML |
| 16 | + fields_with_values = {FieldInfoWithName(name="answer", info=TestSignature.output_fields["answer"]): "Paris"} |
| 17 | + xml = adapter.format_field_with_value(fields_with_values) |
| 18 | + assert xml.strip() == "<answer>\nParis\n</answer>" |
| 19 | + |
| 20 | + # Parse XML output |
| 21 | + completion = "<answer>Paris</answer>" |
| 22 | + parsed = adapter.parse(TestSignature, completion) |
| 23 | + assert parsed == {"answer": "Paris"} |
| 24 | + |
| 25 | + |
| 26 | +def test_xml_adapter_parse_multiple_fields(): |
| 27 | + class TestSignature(dspy.Signature): |
| 28 | + question: str = dspy.InputField() |
| 29 | + answer: str = dspy.OutputField() |
| 30 | + explanation: str = dspy.OutputField() |
| 31 | + |
| 32 | + adapter = XMLAdapter() |
| 33 | + completion = """ |
| 34 | +<answer>Paris</answer> |
| 35 | +<explanation>The capital of France is Paris.</explanation> |
| 36 | +""" |
| 37 | + parsed = adapter.parse(TestSignature, completion) |
| 38 | + assert parsed == {"answer": "Paris", "explanation": "The capital of France is Paris."} |
| 39 | + |
| 40 | + |
| 41 | +def test_xml_adapter_parse_raises_on_missing_field(): |
| 42 | + class TestSignature(dspy.Signature): |
| 43 | + question: str = dspy.InputField() |
| 44 | + answer: str = dspy.OutputField() |
| 45 | + explanation: str = dspy.OutputField() |
| 46 | + |
| 47 | + adapter = XMLAdapter() |
| 48 | + completion = "<answer>Paris</answer>" |
| 49 | + with pytest.raises(dspy.utils.exceptions.AdapterParseError) as e: |
| 50 | + adapter.parse(TestSignature, completion) |
| 51 | + assert e.value.adapter_name == "XMLAdapter" |
| 52 | + assert e.value.signature == TestSignature |
| 53 | + assert e.value.lm_response == "<answer>Paris</answer>" |
| 54 | + assert "explanation" in str(e.value) |
| 55 | + |
| 56 | + |
| 57 | +def test_xml_adapter_parse_casts_types(): |
| 58 | + class TestSignature(dspy.Signature): |
| 59 | + number: int = dspy.OutputField() |
| 60 | + flag: bool = dspy.OutputField() |
| 61 | + |
| 62 | + adapter = XMLAdapter() |
| 63 | + completion = """ |
| 64 | +<number>42</number> |
| 65 | +<flag>true</flag> |
| 66 | +""" |
| 67 | + parsed = adapter.parse(TestSignature, completion) |
| 68 | + assert parsed == {"number": 42, "flag": True} |
| 69 | + |
| 70 | + |
| 71 | +def test_xml_adapter_parse_raises_on_type_error(): |
| 72 | + class TestSignature(dspy.Signature): |
| 73 | + number: int = dspy.OutputField() |
| 74 | + |
| 75 | + adapter = XMLAdapter() |
| 76 | + completion = "<number>not_a_number</number>" |
| 77 | + with pytest.raises(dspy.utils.exceptions.AdapterParseError) as e: |
| 78 | + adapter.parse(TestSignature, completion) |
| 79 | + assert "Failed to parse field" in str(e.value) |
| 80 | + |
| 81 | + |
| 82 | +def test_xml_adapter_format_and_parse_nested_model(): |
| 83 | + class InnerModel(pydantic.BaseModel): |
| 84 | + value: int |
| 85 | + label: str |
| 86 | + |
| 87 | + class TestSignature(dspy.Signature): |
| 88 | + question: str = dspy.InputField() |
| 89 | + result: InnerModel = dspy.OutputField() |
| 90 | + |
| 91 | + adapter = XMLAdapter() |
| 92 | + # Format output fields as XML |
| 93 | + fields_with_values = { |
| 94 | + FieldInfoWithName(name="result", info=TestSignature.output_fields["result"]): InnerModel(value=5, label="foo") |
| 95 | + } |
| 96 | + xml = adapter.format_field_with_value(fields_with_values) |
| 97 | + # The output will be a JSON string inside the XML tag |
| 98 | + assert xml.strip().startswith("<result>") |
| 99 | + assert '"value": 5' in xml |
| 100 | + assert '"label": "foo"' in xml |
| 101 | + assert xml.strip().endswith("</result>") |
| 102 | + |
| 103 | + # Parse XML output (should parse as string, not as model) |
| 104 | + completion = '<result>{"value": 5, "label": "foo"}</result>' |
| 105 | + parsed = adapter.parse(TestSignature, completion) |
| 106 | + # The parse_value helper will try to cast to InnerModel |
| 107 | + assert isinstance(parsed["result"], InnerModel) |
| 108 | + assert parsed["result"].value == 5 |
| 109 | + assert parsed["result"].label == "foo" |
| 110 | + |
| 111 | + |
| 112 | +def test_xml_adapter_format_and_parse_list_of_models(): |
| 113 | + class Item(pydantic.BaseModel): |
| 114 | + name: str |
| 115 | + score: float |
| 116 | + |
| 117 | + class TestSignature(dspy.Signature): |
| 118 | + items: list[Item] = dspy.OutputField() |
| 119 | + |
| 120 | + adapter = XMLAdapter() |
| 121 | + items = [Item(name="a", score=1.1), Item(name="b", score=2.2)] |
| 122 | + fields_with_values = {FieldInfoWithName(name="items", info=TestSignature.output_fields["items"]): items} |
| 123 | + xml = adapter.format_field_with_value(fields_with_values) |
| 124 | + assert xml.strip().startswith("<items>") |
| 125 | + assert '"name": "a"' in xml |
| 126 | + assert '"score": 2.2' in xml |
| 127 | + assert xml.strip().endswith("</items>") |
| 128 | + |
| 129 | + # Parse XML output |
| 130 | + import json |
| 131 | + |
| 132 | + completion = f"<items>{json.dumps([i.model_dump() for i in items])}</items>" |
| 133 | + parsed = adapter.parse(TestSignature, completion) |
| 134 | + assert isinstance(parsed["items"], list) |
| 135 | + assert all(isinstance(i, Item) for i in parsed["items"]) |
| 136 | + assert parsed["items"][0].name == "a" |
| 137 | + assert parsed["items"][1].score == 2.2 |
| 138 | + |
| 139 | + |
| 140 | +def test_xml_adapter_with_tool_like_output(): |
| 141 | + # XMLAdapter does not natively support tool calls, but we can test structured output |
| 142 | + class ToolCall(pydantic.BaseModel): |
| 143 | + name: str |
| 144 | + args: dict |
| 145 | + result: str |
| 146 | + |
| 147 | + class TestSignature(dspy.Signature): |
| 148 | + question: str = dspy.InputField() |
| 149 | + tool_calls: list[ToolCall] = dspy.OutputField() |
| 150 | + answer: str = dspy.OutputField() |
| 151 | + |
| 152 | + adapter = XMLAdapter() |
| 153 | + tool_calls = [ |
| 154 | + ToolCall(name="get_weather", args={"city": "Tokyo"}, result="Sunny"), |
| 155 | + ToolCall(name="get_population", args={"country": "Japan", "year": 2023}, result="125M"), |
| 156 | + ] |
| 157 | + fields_with_values = { |
| 158 | + FieldInfoWithName(name="tool_calls", info=TestSignature.output_fields["tool_calls"]): tool_calls, |
| 159 | + FieldInfoWithName( |
| 160 | + name="answer", info=TestSignature.output_fields["answer"] |
| 161 | + ): "The weather is Sunny. Population is 125M.", |
| 162 | + } |
| 163 | + xml = adapter.format_field_with_value(fields_with_values) |
| 164 | + assert xml.strip().startswith("<tool_calls>") |
| 165 | + assert '"name": "get_weather"' in xml |
| 166 | + assert '"result": "125M"' in xml |
| 167 | + assert xml.strip().endswith("</answer>") |
| 168 | + |
| 169 | + import json |
| 170 | + |
| 171 | + completion = ( |
| 172 | + f"<tool_calls>{json.dumps([tc.model_dump() for tc in tool_calls])}</tool_calls>" |
| 173 | + f"\n<answer>The weather is Sunny. Population is 125M.</answer>" |
| 174 | + ) |
| 175 | + parsed = adapter.parse(TestSignature, completion) |
| 176 | + assert isinstance(parsed["tool_calls"], list) |
| 177 | + assert parsed["tool_calls"][0].name == "get_weather" |
| 178 | + assert parsed["tool_calls"][1].result == "125M" |
| 179 | + assert parsed["answer"] == "The weather is Sunny. Population is 125M." |
| 180 | + |
| 181 | + |
| 182 | +def test_xml_adapter_formats_nested_images(): |
| 183 | + class ImageWrapper(pydantic.BaseModel): |
| 184 | + images: list[dspy.Image] |
| 185 | + tag: list[str] |
| 186 | + |
| 187 | + class MySignature(dspy.Signature): |
| 188 | + image: ImageWrapper = dspy.InputField() |
| 189 | + text: str = dspy.OutputField() |
| 190 | + |
| 191 | + image1 = dspy.Image(url="https://example.com/image1.jpg") |
| 192 | + image2 = dspy.Image(url="https://example.com/image2.jpg") |
| 193 | + image3 = dspy.Image(url="https://example.com/image3.jpg") |
| 194 | + |
| 195 | + image_wrapper = ImageWrapper(images=[image1, image2, image3], tag=["test", "example"]) |
| 196 | + demos = [ |
| 197 | + dspy.Example( |
| 198 | + image=image_wrapper, |
| 199 | + text="This is a test image", |
| 200 | + ), |
| 201 | + ] |
| 202 | + |
| 203 | + image_wrapper_2 = ImageWrapper(images=[dspy.Image(url="https://example.com/image4.jpg")], tag=["test", "example"]) |
| 204 | + adapter = dspy.XMLAdapter() |
| 205 | + messages = adapter.format(MySignature, demos, {"image": image_wrapper_2}) |
| 206 | + |
| 207 | + assert len(messages) == 4 |
| 208 | + |
| 209 | + # Image information in the few-shot example's user message |
| 210 | + expected_image1_content = {"type": "image_url", "image_url": {"url": "https://example.com/image1.jpg"}} |
| 211 | + expected_image2_content = {"type": "image_url", "image_url": {"url": "https://example.com/image2.jpg"}} |
| 212 | + expected_image3_content = {"type": "image_url", "image_url": {"url": "https://example.com/image3.jpg"}} |
| 213 | + assert expected_image1_content in messages[1]["content"] |
| 214 | + assert expected_image2_content in messages[1]["content"] |
| 215 | + assert expected_image3_content in messages[1]["content"] |
| 216 | + |
| 217 | + # The query image is formatted in the last user message |
| 218 | + assert {"type": "image_url", "image_url": {"url": "https://example.com/image4.jpg"}} in messages[-1]["content"] |
0 commit comments