Skip to content

feat(adapter): rewrite XMLAdapter for nested-data support #8482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 107 additions & 35 deletions dspy/adapters/xml_adapter.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,134 @@
import re
from typing import Any, Dict, Type
import pydantic
import xml.etree.ElementTree as ET
from typing import Any, Dict, Type, get_origin

from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName
from dspy.adapters.utils import format_field_value
from dspy.signatures.signature import Signature
from dspy.utils.callback import BaseCallback
from dspy.primitives.prediction import Prediction


class XMLAdapter(ChatAdapter):
def __init__(self, callbacks: list[BaseCallback] | None = None):
    """Initialize the adapter and precompile the flat XML field matcher.

    Args:
        callbacks: Optional callbacks, passed straight to the parent
            initializer.
    """
    super().__init__(callbacks)
    # The backreference forces the closing tag to repeat the opening tag's
    # name; DOTALL lets the lazy content group span line breaks.
    tag_regex = r"<(?P<name>\w+)>((?P<content>.*?))</\1>"
    self.field_pattern = re.compile(tag_regex, re.DOTALL)

def format_field_with_value(self, fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
output = []
for field, field_value in fields_with_values.items():
formatted = format_field_value(field_info=field.info, value=field_value)
output.append(f"<{field.name}>\n{formatted}\n</{field.name}>")
return "\n\n".join(output).strip()
return self._dict_to_xml(
{field.name: field_value for field, field_value in fields_with_values.items()},
)

def user_message_output_requirements(self, signature: Type[Signature]) -> str:
message = "Respond with the corresponding output fields wrapped in XML tags"
message += ", then ".join(f"`<{f}>`" for f in signature.output_fields)
message += ", and then end with the `<completed>` tag."
return message
# TODO: Add a more detailed message that describes the expected output structure.
return "Respond with the corresponding output fields wrapped in XML tags."

def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
fields = {}
for match in self.field_pattern.finditer(completion):
name = match.group("name")
content = match.group("content").strip()
if name in signature.output_fields and name not in fields:
fields[name] = content
# Cast values using base class parse_value helper
for k, v in fields.items():
fields[k] = self._parse_field_value(signature.output_fields[k], v, completion, signature)
if fields.keys() != signature.output_fields.keys():
if isinstance(completion, Prediction):
completion = completion.completion
try:
# Wrap completion in a root tag to handle multiple top-level elements
root = ET.fromstring(f"<root>{completion}</root>")
parsed_dict = self._xml_to_dict(root)

# Create a dynamic Pydantic model for the output fields only
output_field_definitions = {
name: (field.annotation, field) for name, field in signature.output_fields.items()
}
OutputModel = pydantic.create_model(
f"{signature.__name__}Output",
**output_field_definitions,
)

# If there's a single output field, the LM might not wrap it in the field name.
if len(signature.output_fields) == 1:
field_name = next(iter(signature.output_fields))
if field_name not in parsed_dict:
parsed_dict = {field_name: parsed_dict}

# Pre-process the dictionary to handle empty list cases
for name, field in signature.output_fields.items():
# Check if the field is a list type and the parsed value is an empty string
if (
get_origin(field.annotation) is list
and name in parsed_dict
and parsed_dict[name] == ""
):
parsed_dict[name] = []

# Validate the parsed dictionary against the dynamic output model
validated_data = OutputModel(**parsed_dict)

# Return a dictionary of field names to values (which can be Pydantic models)
return {name: getattr(validated_data, name) for name in signature.output_fields}

except ET.ParseError as e:
from dspy.utils.exceptions import AdapterParseError

raise AdapterParseError(
adapter_name="XMLAdapter",
signature=signature,
lm_response=completion,
parsed_result=fields,
)
return fields

def _parse_field_value(self, field_info, raw, completion, signature):
from dspy.adapters.utils import parse_value

try:
return parse_value(raw, field_info.annotation)
except Exception as e:
message=f"Failed to parse XML: {e}",
) from e
except pydantic.ValidationError as e:
from dspy.utils.exceptions import AdapterParseError

raise AdapterParseError(
adapter_name="XMLAdapter",
signature=signature,
lm_response=completion,
message=f"Failed to parse field {field_info} with value {raw}: {e}",
)
parsed_result=parsed_dict,
message=f"Pydantic validation failed: {e}",
) from e

def _dict_to_xml(self, data: Any, root_tag: str = "output") -> str:
    """Serialize ``data`` into an XML string.

    Pydantic models are first reduced to plain dicts. For a dict, every key
    becomes an element and the resulting elements are concatenated without a
    wrapper: nested dicts become nested elements, and each list item becomes
    a sibling element repeating the list's tag (an empty list becomes one
    empty element). Any other value is rendered as the text of a single
    ``root_tag`` element.

    Args:
        data: Dict, Pydantic model, or scalar value to serialize.
        root_tag: Tag of the scratch root element. It only appears in the
            output when ``data`` is not a dict.

    Returns:
        The XML markup as a unicode string.
    """

    def to_plain(obj):
        # Reduce Pydantic models, at any nesting depth, to plain containers.
        if isinstance(obj, pydantic.BaseModel):
            if hasattr(obj, "model_dump"):
                return obj.model_dump()
            return obj.dict()  # Fallback for Pydantic v1
        if isinstance(obj, dict):
            return {k: to_plain(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [to_plain(i) for i in obj]
        return obj

    def build_element(parent, tag, content):
        if isinstance(content, dict):
            element = ET.SubElement(parent, tag)
            for key, val in content.items():
                build_element(element, key, val)
        elif isinstance(content, list):
            if not content:
                # Keep a placeholder so an empty list is still visible.
                ET.SubElement(parent, tag)
            for item in content:
                # Items attach to the same parent, repeating the list's tag.
                build_element(parent, tag, item)
        else:
            ET.SubElement(parent, tag).text = str(content)

    data = to_plain(data)
    root = ET.Element(root_tag)

    if not isinstance(data, dict):
        # A bare value creates no child elements, so joining over the root's
        # children would return "" and drop the value; emit the root itself.
        root.text = str(data)
        return ET.tostring(root, encoding="unicode")

    for key, val in data.items():
        build_element(root, key, val)

    # The root is only scaffolding for dict input: return its children.
    return "".join(ET.tostring(child, encoding="unicode") for child in root)

def _xml_to_dict(self, element: ET.Element) -> Any:
    """Convert an element tree into nested Python data.

    A childless element yields its text with surrounding whitespace removed
    (``""`` when it has no text). An element with children yields a dict
    keyed by child tag; tags that repeat under the same parent are gathered
    into a list in document order. Text that sits alongside child elements
    is ignored.

    Args:
        element: The element to convert.

    Returns:
        A ``str`` for a leaf element, otherwise a ``dict`` whose values are
        strings, dicts, or lists of those.
    """
    children = list(element)
    if not children:
        # Whitespace padding inside a tag (e.g. a value on its own line) is
        # not part of the value, and a whitespace-only element must compare
        # equal to "" so it is recognised as empty.
        return (element.text or "").strip()

    result = {}
    for child in children:
        value = self._xml_to_dict(child)
        if child.tag not in result:
            result[child.tag] = value
        elif isinstance(result[child.tag], list):
            result[child.tag].append(value)
        else:
            # Second occurrence of a tag: promote the earlier value to a list.
            result[child.tag] = [result[child.tag], value]
    return result
11 changes: 11 additions & 0 deletions dspy/utils/pydantic_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pydantic


def get_pydantic_object_serializer():
    """Return the ``pydantic_encoder`` function for the installed Pydantic.

    Returns:
        ``pydantic.v1.json.pydantic_encoder`` when ``pydantic.__version__``
        starts with ``"2."``; otherwise ``pydantic.json.pydantic_encoder``.
    """
    # NOTE(review): on 2.x the encoder is taken from the ``pydantic.v1``
    # compatibility namespace -- confirm it handles the model classes that
    # callers actually pass in.
    installed_version = getattr(pydantic, "__version__", "")
    if not installed_version.startswith("2."):
        from pydantic.json import pydantic_encoder as encoder
    else:
        from pydantic.v1.json import pydantic_encoder as encoder
    return encoder
Loading
Loading