"""XML adapter for DSPy, its pydantic helper, and the adapter's test-suite.

The three sections below correspond to the three files touched by the patch:

* ``dspy/adapters/xml_adapter.py``  -- ``XMLAdapter``
* ``dspy/utils/pydantic_utils.py``  -- ``get_pydantic_object_serializer``
* ``tests/adapters/test_xml_adapter.py`` -- the ``test_*`` functions
"""

import inspect
import xml.etree.ElementTree as ET
from typing import Any, Dict, Type, get_args, get_origin

import pydantic
import pytest
from pydantic.fields import FieldInfo

import dspy
from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName
from dspy.adapters.utils import translate_field_type  # noqa: F401  (kept: imported by the original module)
from dspy.primitives.prediction import Prediction
from dspy.signatures.signature import Signature
from dspy.utils.callback import BaseCallback

# ---------------------------------------------------------------------------
# dspy/adapters/xml_adapter.py
# ---------------------------------------------------------------------------

# One level of indentation in the pretty-printed schema shown to the LM.
# NOTE(review): the unit width is assumed to be two spaces -- confirm.
_INDENT = "  "

# (type, hint) pairs appended to the ``{field}`` placeholder of scalar fields.
# Matched by identity, so unhashable/exotic annotations are never hashed.
_TYPE_HINTS = (
    (int, " # must be a single int value"),
    (float, " # must be a single float value"),
    (bool, " # must be True or False"),
)

# (type, example) pairs used for the items of a list of scalars.
_TYPE_EXAMPLES = (
    (str, "..."),
    (int, "0"),
    (float, "0.0"),
    (bool, "true"),
)


def _is_pydantic_model(annotation: Any) -> bool:
    """Return True when *annotation* is a pydantic model class.

    Parameterised generics such as ``list[int]`` are rejected up front: on
    Python 3.9/3.10 they pass ``inspect.isclass`` but make ``issubclass``
    raise ``TypeError``.
    """
    return (
        get_origin(annotation) is None
        and inspect.isclass(annotation)
        and issubclass(annotation, pydantic.BaseModel)
        and hasattr(annotation, "model_fields")
    )


class XMLAdapter(ChatAdapter):
    """Adapter that exchanges field values with the LM as (nested) XML.

    Pydantic models become nested elements and lists become repeated sibling
    elements carrying the field's name.
    """

    def __init__(self, callbacks: list[BaseCallback] | None = None):
        super().__init__(callbacks)

    # -- formatting ---------------------------------------------------------

    def format_field_with_value(self, fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
        """Serialise the given field values as concatenated XML elements."""
        return self._dict_to_xml(
            {field.name: field_value for field, field_value in fields_with_values.items()},
        )

    def format_field_structure(self, signature: Type[Signature]) -> str:
        """Describe the XML layout of the signature's input and output fields.

        The returned text is meant for the prompt, so the language model can
        see the structure it will receive and must produce.
        """
        parts = [
            "All interactions will be structured in the following way, with the appropriate values filled in.",
        ]

        if signature.input_fields:
            parts.append("Inputs will have the following structure:")
            parts.append(self._generate_fields_xml_structure(signature.input_fields))

        parts.append("Outputs will have the following structure:")
        parts.append(self._generate_fields_xml_structure(signature.output_fields))

        return "\n\n".join(parts).strip()

    def user_message_output_requirements(self, signature: Type[Signature]) -> str:
        """Return a one-line reminder of the expected XML output structure."""
        if not signature.output_fields:
            return "Respond with XML tags as specified."

        schemas = [
            self._generate_compact_xml_schema(field_name, field_info.annotation)
            for field_name, field_info in signature.output_fields.items()
        ]

        if len(schemas) == 1:
            return f"Respond with XML in the following structure: {schemas[0]}"
        schema_list = ", ".join(schemas)
        return f"Respond with XML containing the following structures: {schema_list}"

    # -- parsing ------------------------------------------------------------

    def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
        """Parse an XML completion into ``{output_field_name: typed_value}``.

        Args:
            signature: Signature whose ``output_fields`` define the expected
                element names and their types.
            completion: Raw LM response (a ``Prediction`` is unwrapped to its
                ``completion`` attribute first).

        Returns:
            One entry per output field, validated and cast by pydantic.

        Raises:
            AdapterParseError: If the completion is not well-formed XML, if
                it has no elements while several output fields are expected,
                or if pydantic validation of the parsed values fails.
        """
        # Imported lazily, as in the original code (presumably to avoid an
        # import cycle -- not verifiable from this file).
        from dspy.utils.exceptions import AdapterParseError

        if isinstance(completion, Prediction):
            completion = completion.completion

        try:
            # A synthetic root lets the completion hold several top-level
            # elements, which bare XML documents do not allow.
            root = ET.fromstring(f"<root>{completion}</root>")
        except ET.ParseError as e:
            raise AdapterParseError(
                adapter_name="XMLAdapter",
                signature=signature,
                lm_response=completion,
                message=f"Failed to parse XML: {e}",
            ) from e

        parsed = self._xml_to_dict(root)
        output_fields = signature.output_fields

        # With a single output field the LM may omit the wrapping field tag.
        # ``parsed`` is a plain string when the completion had no elements, so
        # test the type before the membership check (``in`` on a str would be
        # a substring test).
        if len(output_fields) == 1:
            only_name = next(iter(output_fields))
            if not isinstance(parsed, dict) or only_name not in parsed:
                parsed = {only_name: parsed}

        if not isinstance(parsed, dict):
            raise AdapterParseError(
                adapter_name="XMLAdapter",
                signature=signature,
                lm_response=completion,
                parsed_result=parsed,
                message="Expected one XML element per output field but found no XML elements.",
            )

        # Repeated tags only become a list when there are two or more of
        # them; normalise the zero- and one-element cases for list fields.
        # (Only top-level fields are normalised, not lists nested in models.)
        for name, field in output_fields.items():
            if get_origin(field.annotation) is not list or name not in parsed:
                continue
            value = parsed[name]
            if value == "":
                parsed[name] = []
            elif not isinstance(value, list):
                parsed[name] = [value]

        # Dynamic model containing only the output fields; pydantic performs
        # the casting (str -> int, dict -> nested model, ...).
        OutputModel = pydantic.create_model(
            f"{signature.__name__}Output",
            **{name: (field.annotation, field) for name, field in output_fields.items()},
        )

        try:
            validated = OutputModel(**parsed)
        except pydantic.ValidationError as e:
            raise AdapterParseError(
                adapter_name="XMLAdapter",
                signature=signature,
                lm_response=completion,
                parsed_result=parsed,
                message=f"Pydantic validation failed: {e}",
            ) from e

        return {name: getattr(validated, name) for name in output_fields}

    # -- schema rendering helpers --------------------------------------------

    def _generate_fields_xml_structure(self, fields: Dict[str, FieldInfo]) -> str:
        """Render the multi-line XML structure of every field in *fields*."""
        if not fields:
            return ""
        return "\n".join(
            self._generate_xml_schema_structure(field_name, field_info.annotation)
            for field_name, field_info in fields.items()
        )

    def _generate_xml_schema_structure(self, field_name: str, field_annotation: Type, indent: int = 0) -> str:
        """Render one field as indented example XML, recursing into models.

        * pydantic model  -> element wrapping one child per model field
        * ``list[T]``     -> the element for ``T`` shown twice (repetition)
        * anything else   -> ``<name>{name}[ hint]</name>``
        """
        pad = _INDENT * indent

        if _is_pydantic_model(field_annotation):
            lines = [f"{pad}<{field_name}>"]
            lines.extend(
                self._generate_xml_schema_structure(sub_name, sub_info.annotation, indent + 1)
                for sub_name, sub_info in field_annotation.model_fields.items()
            )
            lines.append(f"{pad}</{field_name}>")
            return "\n".join(lines)

        if get_origin(field_annotation) is list:
            args = get_args(field_annotation)
            if not args:
                return f"{pad}<{field_name}>...</{field_name}>"

            item_type = args[0]
            if _is_pydantic_model(item_type):
                example = self._generate_xml_schema_structure(field_name, item_type, indent)
            else:
                placeholder = self._get_type_placeholder(item_type)
                example = f"{pad}<{field_name}>{placeholder}</{field_name}>"
            # Two copies signal "repeat this element once per list item".
            return f"{example}\n{example}"

        placeholder = self._get_type_placeholder_with_hint(field_annotation, field_name)
        return f"{pad}<{field_name}>{placeholder}</{field_name}>"

    def _get_type_placeholder_with_hint(self, type_annotation: Type, field_name: str) -> str:
        """Return ``{field_name}``, plus a type hint for int/float/bool."""
        placeholder = f"{{{field_name}}}"
        for candidate, hint in _TYPE_HINTS:
            if type_annotation is candidate:
                return placeholder + hint
        return placeholder

    def _generate_compact_xml_schema(self, field_name: str, field_annotation: Type) -> str:
        """Render one field as single-line XML, e.g. ``<a><b>...</b></a>``.

        Lists are represented by the schema of one item, since list items are
        repeated elements carrying the field's own name.
        """
        if _is_pydantic_model(field_annotation):
            inner_content = "".join(
                self._generate_compact_xml_schema(sub_name, sub_info.annotation)
                for sub_name, sub_info in field_annotation.model_fields.items()
            )
            return f"<{field_name}>{inner_content}</{field_name}>"

        if get_origin(field_annotation) is list:
            args = get_args(field_annotation)
            if args:
                return self._generate_compact_xml_schema(field_name, args[0])

        return f"<{field_name}>...</{field_name}>"

    def _get_type_placeholder(self, type_annotation: Type) -> str:
        """Return a short example value for a scalar type (``...`` if unknown)."""
        for candidate, example in _TYPE_EXAMPLES:
            if type_annotation is candidate:
                return example
        return "..."

    # -- XML <-> Python conversion ---------------------------------------------

    def _dict_to_xml(self, data: Any, root_tag: str = "output") -> str:
        """Serialise *data* to XML and return the markup *inside* the root.

        Dict keys become element names, lists become repeated elements (an
        empty list becomes one empty element), pydantic models are dumped to
        dicts first, and every other value is written with ``str()``.
        ``root_tag`` only names the temporary root, which is not returned.
        """

        def to_plain(obj):
            # Reduce pydantic models (recursively) to dicts/lists/scalars.
            if isinstance(obj, pydantic.BaseModel):
                if hasattr(obj, "model_dump"):
                    return obj.model_dump()
                return obj.dict()  # Fallback for Pydantic v1
            if isinstance(obj, dict):
                return {k: to_plain(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [to_plain(i) for i in obj]
            return obj

        def build_element(parent, tag, content):
            if isinstance(content, dict):
                element = ET.SubElement(parent, tag)
                for key, val in content.items():
                    build_element(element, key, val)
            elif isinstance(content, list):
                if not content:  # keep a marker element for an empty list
                    ET.SubElement(parent, tag)
                for item in content:
                    build_element(parent, tag, item)
            else:
                ET.SubElement(parent, tag).text = str(content)

        data = to_plain(data)
        root = ET.Element(root_tag)
        if isinstance(data, dict):
            for key, val in data.items():
                build_element(root, key, val)
        else:
            root.text = str(data)

        return "".join(ET.tostring(child, encoding="unicode") for child in root)

    def _xml_to_dict(self, element: ET.Element) -> Any:
        """Convert *element* to nested dicts / lists / strings.

        A leaf element yields its text, stripped of surrounding whitespace
        (``""`` when empty). An element with children yields a dict keyed by
        child tag; a tag that occurs more than once maps to a list. Text mixed
        in between child elements is ignored.
        """
        children = list(element)
        if not children:
            return (element.text or "").strip()

        result: dict[str, Any] = {}
        for child in children:
            value = self._xml_to_dict(child)
            if child.tag not in result:
                result[child.tag] = value
                continue
            existing = result[child.tag]
            if not isinstance(existing, list):
                existing = result[child.tag] = [existing]
            existing.append(value)
        return result


# ---------------------------------------------------------------------------
# dspy/utils/pydantic_utils.py
# ---------------------------------------------------------------------------


def get_pydantic_object_serializer():
    """Return pydantic's ``pydantic_encoder`` for the installed version.

    Under pydantic 2.x the encoder is taken from the bundled ``pydantic.v1``
    namespace; otherwise from ``pydantic.json``.

    NOTE(review): the ``pydantic.v1`` encoder presumably only recognises
    ``pydantic.v1.BaseModel`` instances, not v2 models -- confirm whether
    ``pydantic_core.to_jsonable_python`` is what callers actually need.
    """
    # Pydantic V2 has a more robust JSON encoder, but we need to handle V1 as well.
    if hasattr(pydantic, "__version__") and pydantic.__version__.startswith("2."):
        from pydantic.v1.json import pydantic_encoder
    else:
        from pydantic.json import pydantic_encoder
    return pydantic_encoder


# ---------------------------------------------------------------------------
# tests/adapters/test_xml_adapter.py
#
# NOTE(review): the headers of the first five tests lie outside the patch
# hunks; their names and signature fields are reconstructed from the visible
# assertions -- confirm against the repository.
# ---------------------------------------------------------------------------


def test_xml_adapter_format_and_parse_basic():
    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        answer: str = dspy.OutputField()

    adapter = XMLAdapter()
    # Format output fields as XML
    fields_with_values = {FieldInfoWithName(name="answer", info=TestSignature.output_fields["answer"]): "Paris"}
    xml = adapter.format_field_with_value(fields_with_values)
    assert xml.strip() == "<answer>Paris</answer>"

    # Parse XML output
    completion = "<answer>Paris</answer>"
    parsed = adapter.parse(TestSignature, completion)
    assert parsed == {"answer": "Paris"}


def test_xml_adapter_parse_multiple_fields():
    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        answer: str = dspy.OutputField()
        explanation: str = dspy.OutputField()

    adapter = XMLAdapter()
    completion = """
<answer>Paris</answer><explanation>The capital of France is Paris.</explanation>"""
    parsed = adapter.parse(TestSignature, completion)
    assert parsed == {"answer": "Paris", "explanation": "The capital of France is Paris."}


def test_xml_adapter_parse_raises_on_missing_field():
    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        answer: str = dspy.OutputField()
        explanation: str = dspy.OutputField()

    adapter = XMLAdapter()
    completion = "<answer>Paris</answer>"
    with pytest.raises(dspy.utils.exceptions.AdapterParseError):
        adapter.parse(TestSignature, completion)


def test_xml_adapter_parse_casts_types():
    class TestSignature(dspy.Signature):
        number: int = dspy.OutputField()
        flag: bool = dspy.OutputField()

    adapter = XMLAdapter()
    completion = """
<number>42</number><flag>true</flag>"""
    parsed = adapter.parse(TestSignature, completion)
    assert parsed == {"number": 42, "flag": True}


def test_xml_adapter_parse_raises_on_type_error():
    class TestSignature(dspy.Signature):
        number: int = dspy.OutputField()

    adapter = XMLAdapter()
    completion = "<number>not_a_number</number>"
    with pytest.raises(dspy.utils.exceptions.AdapterParseError):
        adapter.parse(TestSignature, completion)


def test_xml_adapter_handles_true_nested_xml_parsing():
    class InnerModel(pydantic.BaseModel):
        value: int
        label: str

    class TestSignature(dspy.Signature):
        result: InnerModel = dspy.OutputField()

    adapter = XMLAdapter()
    completion = """
<result>
    <value>5</value>
    <label>foo</label>
</result>
"""
    parsed = adapter.parse(TestSignature, completion)
    assert isinstance(parsed["result"], InnerModel)
    assert parsed["result"].value == 5
    assert parsed["result"].label == "foo"


def test_xml_adapter_formats_true_nested_xml():
    class InnerModel(pydantic.BaseModel):
        value: int
        label: str

    class TestSignature(dspy.Signature):
        result: InnerModel = dspy.OutputField()

    adapter = XMLAdapter()
    fields_with_values = {
        FieldInfoWithName(name="result", info=TestSignature.output_fields["result"]): InnerModel(value=5, label="foo")
    }
    xml = adapter.format_field_with_value(fields_with_values)

    # The output should be a true nested XML string
    expected_xml = "<result><value>5</value><label>foo</label></result>"
    assert xml.strip() == expected_xml.strip()


def test_xml_adapter_handles_lists_as_repeated_tags():
    class Item(pydantic.BaseModel):
        name: str
        score: float

    class TestSignature(dspy.Signature):
        items: list[Item] = dspy.OutputField()

    adapter = XMLAdapter()

    # Test parsing repeated tags into a list
    completion = """
<items>
    <name>a</name>
    <score>1.1</score>
</items>
<items>
    <name>b</name>
    <score>2.2</score>
</items>
"""
    parsed = adapter.parse(TestSignature, completion)
    assert isinstance(parsed["items"], list)
    assert len(parsed["items"]) == 2
    assert all(isinstance(i, Item) for i in parsed["items"])
    assert parsed["items"][0].name == "a"
    assert parsed["items"][1].score == 2.2

    # Test formatting a list into repeated tags
    items = [Item(name="x", score=3.3), Item(name="y", score=4.4)]
    fields_with_values = {FieldInfoWithName(name="items", info=TestSignature.output_fields["items"]): items}
    xml = adapter.format_field_with_value(fields_with_values)

    expected_xml = (
        "<items><name>x</name><score>3.3</score></items>"
        "<items><name>y</name><score>4.4</score></items>"
    )
    assert xml.strip() == expected_xml.strip()


def test_xml_adapter_parses_single_item_list():
    # Regression: one repeated tag must still validate as a list.
    class TestSignature(dspy.Signature):
        items: list[str] = dspy.OutputField()

    parsed = XMLAdapter().parse(TestSignature, "<items>only</items>")
    assert parsed == {"items": ["only"]}


def test_xml_adapter_plain_text_with_multiple_fields_raises():
    # Regression: text without any element used to escape as a TypeError.
    class TestSignature(dspy.Signature):
        answer: str = dspy.OutputField()
        explanation: str = dspy.OutputField()

    with pytest.raises(dspy.utils.exceptions.AdapterParseError):
        XMLAdapter().parse(TestSignature, "just some text")


def test_parse_malformed_xml():
    class TestSignature(dspy.Signature):
        data: str = dspy.OutputField()

    adapter = XMLAdapter()
    completion = "<data><unclosed>text</data>"
    with pytest.raises(dspy.utils.exceptions.AdapterParseError):
        adapter.parse(TestSignature, completion)


def test_format_and_parse_deeply_nested_model():
    class Inner(pydantic.BaseModel):
        text: str

    class Middle(pydantic.BaseModel):
        inner: Inner
        num: int

    class TestSignature(dspy.Signature):
        middle: Middle = dspy.OutputField()

    adapter = XMLAdapter()
    data = Middle(inner=Inner(text="deep"), num=123)
    fields_with_values = {FieldInfoWithName(name="middle", info=TestSignature.output_fields["middle"]): data}

    # Test formatting
    xml = adapter.format_field_with_value(fields_with_values)
    expected_xml = "<middle><inner><text>deep</text></inner><num>123</num></middle>"
    assert xml.strip() == expected_xml

    # Test parsing
    parsed = adapter.parse(TestSignature, xml)
    assert isinstance(parsed["middle"], Middle)
    assert parsed["middle"].inner.text == "deep"
    assert parsed["middle"].num == 123


def test_format_and_parse_empty_list():
    class TestSignature(dspy.Signature):
        items: list[str] = dspy.OutputField()

    adapter = XMLAdapter()

    # Test formatting
    fields_with_values = {FieldInfoWithName(name="items", info=TestSignature.output_fields["items"]): []}
    xml = adapter.format_field_with_value(fields_with_values)
    assert xml.strip() in ["<items />", "<items></items>"]

    # Test parsing
    parsed = adapter.parse(TestSignature, xml)
    assert parsed["items"] == []


def test_end_to_end_with_predict():
    class TestSignature(dspy.Signature):
        question: str = dspy.InputField()
        answer: str = dspy.OutputField()

    # Mock LM
    class MockLM(dspy.LM):
        def __init__(self):
            self.history = []
            self.kwargs = {}

        def __call__(self, messages, **kwargs):
            self.history.append(messages)
            completion = "<answer>mocked answer</answer>"
            return [completion]

    lm = MockLM()
    lm.model = "mock-model"

    # Scoped override, so the mock LM and adapter do not leak into other tests.
    with dspy.context(lm=lm, adapter=XMLAdapter()):
        predict = dspy.Predict(TestSignature)
        result = predict(question="test question")

    assert result.answer == "mocked answer"