Skip to content

Fix coding style in dspy/adapters #8155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions dspy/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from dspy.adapters.base import Adapter
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.json_adapter import JSONAdapter
from dspy.adapters.types import Image, History
from dspy.adapters.two_step_adapter import TwoStepAdapter
from dspy.adapters.types import History, Image

__all__ = [
"Adapter",
"ChatAdapter",
"JSONAdapter",
"Image",
"History",
"Image",
"JSONAdapter",
"TwoStepAdapter",
]
2 changes: 1 addition & 1 deletion dspy/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def format_assistant_message_content(
self,
signature: Type[Signature],
outputs: dict[str, Any],
missing_field_message: str = None,
missing_field_message: Optional[str] = None,
) -> str:
"""Format the assistant message content.

Expand Down
4 changes: 2 additions & 2 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,6 @@ def format_finetune_data(
assistant_message_content = self.format_assistant_message_content( # returns a string, without the role
signature=signature, outputs=outputs
)
assistant_message = dict(role="assistant", content=assistant_message_content)
assistant_message = {"role": "assistant", "content": assistant_message_content}
messages = system_user_messages + [assistant_message]
return dict(messages=messages)
return {"messages": messages}
28 changes: 16 additions & 12 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import json
import regex
import logging
from typing import Any, Dict, Type, get_origin

import json_repair
import litellm
import pydantic
import regex
from pydantic.fields import FieldInfo

from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName
Expand All @@ -29,7 +29,7 @@ def _has_open_ended_mapping(signature: SignatureMeta) -> bool:
such as dict[str, Any]. Structured Outputs require explicit properties, so such fields
are incompatible.
"""
for name, field in signature.output_fields.items():
for field in signature.output_fields.values():
annotation = field.annotation
if get_origin(annotation) is dict:
return True
Expand Down Expand Up @@ -121,9 +121,9 @@ def format_assistant_message_content(
return self.format_field_with_value(fields_with_values, role="assistant")

def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
pattern = r'\{(?:[^{}]|(?R))*\}'
match = regex.search(pattern, completion, regex.DOTALL)
if match:
pattern = r"\{(?:[^{}]|(?R))*\}"
match = regex.search(pattern, completion, regex.DOTALL)
if match:
completion = match.group(0)
fields = json_repair.loads(completion)

Expand Down Expand Up @@ -196,10 +196,14 @@ def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[py
fields[name] = (annotation, default)

# Build the model with extra fields forbidden.
Model = pydantic.create_model("DSPyProgramOutputs", **fields, __config__=type("Config", (), {"extra": "forbid"}))
pydantic_model = pydantic.create_model(
"DSPyProgramOutputs",
**fields,
__config__=type("Config", (), {"extra": "forbid"}),
)

# Generate the initial schema.
schema = Model.model_json_schema()
schema = pydantic_model.model_json_schema()

# Remove any DSPy-specific metadata.
for prop in schema.get("properties", {}).values():
Expand All @@ -208,9 +212,9 @@ def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[py
def enforce_required(schema_part: dict):
"""
Recursively ensure that:
- for any object schema, a "required" key is added with all property names (or [] if no properties)
- additionalProperties is set to False regardless of the previous value.
- the same enforcement is run for nested arrays and definitions.
- for any object schema, a "required" key is added with all property names (or [] if no properties)
- additionalProperties is set to False regardless of the previous value.
- the same enforcement is run for nested arrays and definitions.
"""
if schema_part.get("type") == "object":
props = schema_part.get("properties")
Expand All @@ -237,6 +241,6 @@ def enforce_required(schema_part: dict):
enforce_required(schema)

# Override the model's JSON schema generation to return our precomputed schema.
Model.model_json_schema = lambda *args, **kwargs: schema
pydantic_model.model_json_schema = lambda *args, **kwargs: schema

return Model
return pydantic_model
4 changes: 2 additions & 2 deletions dspy/adapters/two_step_adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Type
from typing import Any, Optional, Type

from dspy.adapters.base import Adapter
from dspy.adapters.chat_adapter import ChatAdapter
Expand Down Expand Up @@ -175,7 +175,7 @@ def format_assistant_message_content(
self,
signature: Type[Signature],
outputs: dict[str, Any],
missing_field_message: str = None,
missing_field_message: Optional[str] = None,
) -> str:
parts = []

Expand Down
49 changes: 26 additions & 23 deletions dspy/adapters/types/image.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import base64
import io
import mimetypes
import os
import re
from typing import Any, Dict, List, Union
from urllib.parse import urlparse
import re
import mimetypes

import pydantic
import requests
Expand All @@ -19,14 +19,14 @@

class Image(pydantic.BaseModel):
url: str

model_config = {
'frozen': True,
'str_strip_whitespace': True,
'validate_assignment': True,
'extra': 'forbid',
"frozen": True,
"str_strip_whitespace": True,
"validate_assignment": True,
"extra": "forbid",
}

@pydantic.model_validator(mode="before")
@classmethod
def validate_input(cls, values):
Expand All @@ -52,7 +52,7 @@ def from_file(cls, file_path: str):
return cls(url=encode_image(file_path))

@classmethod
def from_PIL(cls, pil_image):
def from_PIL(cls, pil_image): # noqa: N802
return cls(url=encode_image(pil_image))

@pydantic.model_serializer()
Expand All @@ -66,9 +66,10 @@ def __repr__(self):
if "base64" in self.url:
len_base64 = len(self.url.split("base64,")[1])
image_type = self.url.split(";")[0].split("/")[-1]
return f"Image(url=data:image/{image_type};base64,<IMAGE_BASE_64_ENCODED({str(len_base64)})>)"
return f"Image(url=data:image/{image_type};base64,<IMAGE_BASE_64_ENCODED({len_base64!s})>)"
return f"Image(url='{self.url}')"


def is_url(string: str) -> bool:
"""Check if a string is a valid URL."""
try:
Expand Down Expand Up @@ -162,7 +163,7 @@ def _encode_image_from_url(image_url: str) -> str:
return f"data:{mime_type};base64,{encoded_data}"


def _encode_pil_image(image: 'PILImage') -> str:
def _encode_pil_image(image: "PILImage") -> str:
"""Encode a PIL Image object to a base64 data URI."""
buffered = io.BytesIO()
file_format = image.format or "PNG"
Expand Down Expand Up @@ -197,6 +198,7 @@ def is_image(obj) -> bool:
return True
return False


def try_expand_image_tags(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Try to expand image tags in the messages."""
for message in messages:
Expand All @@ -205,43 +207,44 @@ def try_expand_image_tags(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]
message["content"] = expand_image_tags(message["content"])
return messages


def expand_image_tags(text: str) -> Union[str, List[Dict[str, Any]]]:
"""Expand image tags in the text. If there are any image tags,
"""Expand image tags in the text. If there are any image tags,
turn it from a content string into a content list of texts and image urls.

Args:
text: The text content that may contain image tags

Returns:
Either the original string if no image tags, or a list of content dicts
with text and image_url entries
"""
image_tag_regex = r'"?<DSPY_IMAGE_START>(.*?)<DSPY_IMAGE_END>"?'

# If no image tags, return original text
if not re.search(image_tag_regex, text):
return text

final_list = []
remaining_text = text

while remaining_text:
match = re.search(image_tag_regex, remaining_text)
if not match:
if remaining_text.strip():
final_list.append({"type": "text", "text": remaining_text.strip()})
break

# Get text before the image tag
prefix = remaining_text[:match.start()].strip()
prefix = remaining_text[: match.start()].strip()
if prefix:
final_list.append({"type": "text", "text": prefix})

# Add the image
image_url = match.group(1)
final_list.append({"type": "image_url", "image_url": {"url": image_url}})

# Update remaining text
remaining_text = remaining_text[match.end():].strip()
remaining_text = remaining_text[match.end() :].strip()

return final_list
6 changes: 4 additions & 2 deletions dspy/adapters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def parse_value(value, annotation):

if v in allowed:
return v

raise ValueError(f"{value!r} is not one of {allowed!r}")

if not isinstance(value, str):
Expand All @@ -174,6 +174,7 @@ def parse_value(value, annotation):
return str(candidate)
raise


def get_annotation_name(annotation):
origin = get_origin(annotation)
args = get_args(annotation)
Expand All @@ -193,6 +194,7 @@ def get_annotation_name(annotation):
args_str = ", ".join(get_annotation_name(a) for a in args)
return f"{get_annotation_name(origin)}[{args_str}]"


def get_field_description_string(fields: dict) -> str:
field_descriptions = []
for idx, (k, v) in enumerate(fields.items()):
Expand Down Expand Up @@ -220,7 +222,7 @@ def _format_input_list_field_value(value: List[Any]) -> str:
if len(value) == 1:
return _format_blob(value[0])

return "\n".join([f"[{idx+1}] {_format_blob(txt)}" for idx, txt in enumerate(value)])
return "\n".join([f"[{idx + 1}] {_format_blob(txt)}" for idx, txt in enumerate(value)])


def _format_blob(blob: str) -> str:
Expand Down