Skip to content

Commit 716e82c

Browse files
Add xml adapter (#8358)
* Add XML adapter and improve enum prompts
* add unit testing
* format
* add __init__ import

---------

Co-authored-by: Omar Khattab <okhat@users.noreply.github.com>
1 parent 47f3b49 commit 716e82c

File tree

5 files changed

+285
-2
lines changed

5 files changed

+285
-2
lines changed

dspy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from dspy.evaluate import Evaluate # isort: skip
1010
from dspy.clients import * # isort: skip
11-
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, Audio, History, BaseType, Tool, ToolCalls # isort: skip
11+
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, History, BaseType, Tool, ToolCalls # isort: skip
1212
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
1313
from dspy.utils.asyncify import asyncify
1414
from dspy.utils.saving import load

dspy/adapters/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from dspy.adapters.json_adapter import JSONAdapter
44
from dspy.adapters.two_step_adapter import TwoStepAdapter
55
from dspy.adapters.types import Audio, BaseType, History, Image, Tool, ToolCalls
6+
from dspy.adapters.xml_adapter import XMLAdapter
67

78
__all__ = [
89
"Adapter",
@@ -12,6 +13,7 @@
1213
"Image",
1314
"Audio",
1415
"JSONAdapter",
16+
"XMLAdapter",
1517
"TwoStepAdapter",
1618
"Tool",
1719
"ToolCalls",

dspy/adapters/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def translate_field_type(field_name, field_info):
9090
elif field_type in (int, float):
9191
desc = f"must be a single {field_type.__name__} value"
9292
elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
93-
desc = f"must be one of: {'; '.join(field_type.__members__)}"
93+
enum_vals = '; '.join(str(member.value) for member in field_type)
94+
desc = f"must be one of: {enum_vals}"
9495
elif hasattr(field_type, "__origin__") and field_type.__origin__ is Literal:
9596
desc = (
9697
# Strongly encourage the LM to avoid choosing values that don't appear in the

dspy/adapters/xml_adapter.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import re
2+
from typing import Any, Dict, Optional, Type
3+
4+
from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName
5+
from dspy.adapters.utils import format_field_value
6+
from dspy.signatures.signature import Signature
7+
from dspy.utils.callback import BaseCallback
8+
9+
10+
class XMLAdapter(ChatAdapter):
    """Chat adapter that exchanges field values wrapped in XML-style tags.

    Each field is rendered as ``<field_name>...</field_name>``. Completions are
    parsed by scanning for such tags and casting the enclosed text to the type
    annotated on the matching output field of the signature.
    """

    def __init__(self, callbacks: Optional[list[BaseCallback]] = None):
        """Initialize the adapter and compile the tag-matching pattern.

        Args:
            callbacks: Optional callbacks, forwarded unchanged to ``ChatAdapter``.
        """
        super().__init__(callbacks)
        # Non-greedy + DOTALL so a field's content may span several lines; the
        # ``\1`` backreference (the ``name`` group) forces the closing tag to
        # match the opening tag.
        self.field_pattern = re.compile(r"<(?P<name>\w+)>((?P<content>.*?))</\1>", re.DOTALL)

    def format_field_with_value(self, fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
        """Render every field as an XML block holding its formatted value.

        Args:
            fields_with_values: Mapping from field descriptor to the value to render.

        Returns:
            The blocks ``<name>\\n{value}\\n</name>`` joined by blank lines, with
            surrounding whitespace stripped.
        """
        blocks = [
            f"<{field.name}>\n{format_field_value(field_info=field.info, value=field_value)}\n</{field.name}>"
            for field, field_value in fields_with_values.items()
        ]
        return "\n\n".join(blocks).strip()

    def user_message_output_requirements(self, signature: Type[Signature]) -> str:
        """Build the instruction telling the LM which XML tags to respond with.

        Args:
            signature: Signature whose output fields are listed, in declaration order.

        Returns:
            A single sentence naming each output tag and the final ``<completed>`` tag.
        """
        tags = ", then ".join(f"`<{f}>`" for f in signature.output_fields)
        # The separating space was missing before, which glued the first tag
        # onto the word "tags" in the prompt.
        message = f"Respond with the corresponding output fields wrapped in XML tags {tags}"
        message += ", and then end with the `<completed>` tag."
        return message

    def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
        """Extract and type-cast the signature's output fields from a completion.

        Tags whose name is not an output field of ``signature`` are ignored, and
        when a tag occurs more than once only its first occurrence is used.

        Args:
            signature: Signature describing the expected output fields.
            completion: Raw LM response text.

        Returns:
            Mapping from output field name to its parsed value.

        Raises:
            AdapterParseError: If a value cannot be cast to its field's annotation,
                or if the parsed field names differ from the signature's output fields.
        """
        raw_fields: dict[str, str] = {}
        for match in self.field_pattern.finditer(completion):
            name = match.group("name")
            if name in signature.output_fields and name not in raw_fields:
                raw_fields[name] = match.group("content").strip()

        # Cast each raw string with `parse_value` (via `_parse_field_value`),
        # building a new dict instead of rewriting the one being iterated.
        fields = {
            name: self._parse_field_value(signature.output_fields[name], raw, completion, signature)
            for name, raw in raw_fields.items()
        }

        if fields.keys() != signature.output_fields.keys():
            from dspy.utils.exceptions import AdapterParseError

            raise AdapterParseError(
                adapter_name="XMLAdapter",
                signature=signature,
                lm_response=completion,
                parsed_result=fields,
            )
        return fields

    def _parse_field_value(self, field_info, raw, completion, signature):
        """Cast ``raw`` to ``field_info.annotation``.

        Args:
            field_info: Field descriptor providing the target ``annotation``.
            raw: Text extracted from between the field's tags.
            completion: Full LM response, attached to the error for context.
            signature: Signature being parsed, attached to the error for context.

        Returns:
            The value produced by ``parse_value``.

        Raises:
            AdapterParseError: If ``parse_value`` raises for any reason; the
                original exception is chained as the cause.
        """
        from dspy.adapters.utils import parse_value

        try:
            return parse_value(raw, field_info.annotation)
        except Exception as e:
            from dspy.utils.exceptions import AdapterParseError

            raise AdapterParseError(
                adapter_name="XMLAdapter",
                signature=signature,
                lm_response=completion,
                message=f"Failed to parse field {field_info} with value {raw}: {e}",
            ) from e

tests/adapters/test_xml_adapter.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
import pydantic
2+
import pytest
3+
4+
import dspy
5+
from dspy.adapters.chat_adapter import FieldInfoWithName
6+
from dspy.adapters.xml_adapter import XMLAdapter
7+
8+
9+
def test_xml_adapter_format_and_parse_basic():
    """Format a single string field as XML, then parse an XML completion back."""

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        answer: str = dspy.OutputField()

    xml_adapter = XMLAdapter()

    # Formatting wraps the value in a tag named after the field.
    answer_field = FieldInfoWithName(name="answer", info=TestSignature.output_fields["answer"])
    rendered = xml_adapter.format_field_with_value({answer_field: "Paris"})
    assert rendered.strip() == "<answer>\nParis\n</answer>"

    # Parsing returns the tag content keyed by the output field name.
    result = xml_adapter.parse(TestSignature, "<answer>Paris</answer>")
    assert result == {"answer": "Paris"}
24+
25+
26+
def test_xml_adapter_parse_multiple_fields():
    """Both output fields are recovered from a completion holding two tags."""

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        answer: str = dspy.OutputField()
        explanation: str = dspy.OutputField()

    lm_output = """
<answer>Paris</answer>
<explanation>The capital of France is Paris.</explanation>
"""
    expected = {"answer": "Paris", "explanation": "The capital of France is Paris."}

    assert XMLAdapter().parse(TestSignature, lm_output) == expected
39+
40+
41+
def test_xml_adapter_parse_raises_on_missing_field():
    """A completion lacking one output field raises AdapterParseError with context."""

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        answer: str = dspy.OutputField()
        explanation: str = dspy.OutputField()

    xml_adapter = XMLAdapter()
    incomplete = "<answer>Paris</answer>"

    with pytest.raises(dspy.utils.exceptions.AdapterParseError) as exc_info:
        xml_adapter.parse(TestSignature, incomplete)

    # The error carries the adapter name, signature and raw response, and its
    # text mentions the field that was not found.
    error = exc_info.value
    assert error.adapter_name == "XMLAdapter"
    assert error.signature == TestSignature
    assert error.lm_response == "<answer>Paris</answer>"
    assert "explanation" in str(error)
55+
56+
57+
def test_xml_adapter_parse_casts_types():
    """Tag contents are cast to the annotated int and bool types."""

    class TestSignature(dspy.Signature):
        number: int = dspy.OutputField()
        flag: bool = dspy.OutputField()

    lm_output = """
<number>42</number>
<flag>true</flag>
"""

    result = XMLAdapter().parse(TestSignature, lm_output)
    assert result == {"number": 42, "flag": True}
69+
70+
71+
def test_xml_adapter_parse_raises_on_type_error():
    """Content that cannot be cast to the field type raises AdapterParseError."""

    class TestSignature(dspy.Signature):
        number: int = dspy.OutputField()

    xml_adapter = XMLAdapter()
    bad_output = "<number>not_a_number</number>"

    with pytest.raises(dspy.utils.exceptions.AdapterParseError) as exc_info:
        xml_adapter.parse(TestSignature, bad_output)

    assert "Failed to parse field" in str(exc_info.value)
80+
81+
82+
def test_xml_adapter_format_and_parse_nested_model():
    """A pydantic-model field is rendered inside its tag and parsed back into the model."""

    class InnerModel(pydantic.BaseModel):
        value: int
        label: str

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        result: InnerModel = dspy.OutputField()

    xml_adapter = XMLAdapter()

    # Formatting: the model's fields appear as JSON-style key/value text
    # between the <result> tags.
    result_field = FieldInfoWithName(name="result", info=TestSignature.output_fields["result"])
    rendered = xml_adapter.format_field_with_value({result_field: InnerModel(value=5, label="foo")})
    trimmed = rendered.strip()
    assert trimmed.startswith("<result>")
    assert '"value": 5' in rendered
    assert '"label": "foo"' in rendered
    assert trimmed.endswith("</result>")

    # Parsing: the JSON text inside the tag is cast to an InnerModel instance.
    parsed = xml_adapter.parse(TestSignature, '<result>{"value": 5, "label": "foo"}</result>')
    model = parsed["result"]
    assert isinstance(model, InnerModel)
    assert model.value == 5
    assert model.label == "foo"
110+
111+
112+
def test_xml_adapter_format_and_parse_list_of_models():
    """A list-of-models field is rendered inside its tag and parsed back into models."""
    import json

    class Item(pydantic.BaseModel):
        name: str
        score: float

    class TestSignature(dspy.Signature):
        items: list[Item] = dspy.OutputField()

    xml_adapter = XMLAdapter()
    items = [Item(name="a", score=1.1), Item(name="b", score=2.2)]

    # Formatting: entries from both items show up between the <items> tags.
    items_field = FieldInfoWithName(name="items", info=TestSignature.output_fields["items"])
    rendered = xml_adapter.format_field_with_value({items_field: items})
    trimmed = rendered.strip()
    assert trimmed.startswith("<items>")
    assert '"name": "a"' in rendered
    assert '"score": 2.2' in rendered
    assert trimmed.endswith("</items>")

    # Parsing: a JSON array inside the tag becomes a list of Item instances.
    payload = json.dumps([i.model_dump() for i in items])
    parsed_items = xml_adapter.parse(TestSignature, f"<items>{payload}</items>")["items"]
    assert isinstance(parsed_items, list)
    assert all(isinstance(i, Item) for i in parsed_items)
    assert parsed_items[0].name == "a"
    assert parsed_items[1].score == 2.2
138+
139+
140+
def test_xml_adapter_with_tool_like_output():
    """Structured, tool-call-shaped output round-trips alongside a plain string field.

    This exercises ordinary structured output using a local pydantic model; it
    does not rely on any dedicated tool-call support in the adapter.
    """
    import json

    class ToolCall(pydantic.BaseModel):
        name: str
        args: dict
        result: str

    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        tool_calls: list[ToolCall] = dspy.OutputField()
        answer: str = dspy.OutputField()

    xml_adapter = XMLAdapter()
    tool_calls = [
        ToolCall(name="get_weather", args={"city": "Tokyo"}, result="Sunny"),
        ToolCall(name="get_population", args={"country": "Japan", "year": 2023}, result="125M"),
    ]
    answer_text = "The weather is Sunny. Population is 125M."

    # Formatting: fields are emitted in insertion order, so the output starts
    # with <tool_calls> and ends with </answer>.
    outputs = TestSignature.output_fields
    rendered = xml_adapter.format_field_with_value(
        {
            FieldInfoWithName(name="tool_calls", info=outputs["tool_calls"]): tool_calls,
            FieldInfoWithName(name="answer", info=outputs["answer"]): answer_text,
        }
    )
    trimmed = rendered.strip()
    assert trimmed.startswith("<tool_calls>")
    assert '"name": "get_weather"' in rendered
    assert '"result": "125M"' in rendered
    assert trimmed.endswith("</answer>")

    # Parsing: the JSON array is cast to ToolCall models; the answer stays a string.
    payload = json.dumps([tc.model_dump() for tc in tool_calls])
    lm_output = f"<tool_calls>{payload}</tool_calls>" + "\n<answer>The weather is Sunny. Population is 125M.</answer>"
    parsed = xml_adapter.parse(TestSignature, lm_output)
    assert isinstance(parsed["tool_calls"], list)
    assert parsed["tool_calls"][0].name == "get_weather"
    assert parsed["tool_calls"][1].result == "125M"
    assert parsed["answer"] == "The weather is Sunny. Population is 125M."
180+
181+
182+
def test_xml_adapter_formats_nested_images():
    """Images nested inside a pydantic input appear as image_url parts in the messages."""

    class ImageWrapper(pydantic.BaseModel):
        images: list[dspy.Image]
        tag: list[str]

    class MySignature(dspy.Signature):
        image: ImageWrapper = dspy.InputField()
        text: str = dspy.OutputField()

    def image_part(url):
        # Content part expected in a message for an image at `url`.
        return {"type": "image_url", "image_url": {"url": url}}

    demo_urls = [
        "https://example.com/image1.jpg",
        "https://example.com/image2.jpg",
        "https://example.com/image3.jpg",
    ]
    demo_wrapper = ImageWrapper(images=[dspy.Image(url=u) for u in demo_urls], tag=["test", "example"])
    demos = [dspy.Example(image=demo_wrapper, text="This is a test image")]

    query_wrapper = ImageWrapper(images=[dspy.Image(url="https://example.com/image4.jpg")], tag=["test", "example"])
    messages = dspy.XMLAdapter().format(MySignature, demos, {"image": query_wrapper})

    assert len(messages) == 4

    # The few-shot example's user message (index 1) carries all three demo images.
    for url in demo_urls:
        assert image_part(url) in messages[1]["content"]

    # The query image is placed in the last user message.
    assert image_part("https://example.com/image4.jpg") in messages[-1]["content"]

0 commit comments

Comments
 (0)