diff --git a/dspy/adapters/__init__.py b/dspy/adapters/__init__.py index f6a128a7e7..2c40e7517f 100644 --- a/dspy/adapters/__init__.py +++ b/dspy/adapters/__init__.py @@ -1,14 +1,14 @@ from dspy.adapters.base import Adapter from dspy.adapters.chat_adapter import ChatAdapter from dspy.adapters.json_adapter import JSONAdapter -from dspy.adapters.types import Image, History from dspy.adapters.two_step_adapter import TwoStepAdapter +from dspy.adapters.types import History, Image __all__ = [ "Adapter", "ChatAdapter", - "JSONAdapter", - "Image", "History", + "Image", + "JSONAdapter", "TwoStepAdapter", ] diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 0fa3b5d657..e2c37fe4ca 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -212,7 +212,7 @@ def format_assistant_message_content( self, signature: Type[Signature], outputs: dict[str, Any], - missing_field_message: str = None, + missing_field_message: Optional[str] = None, ) -> str: """Format the assistant message content. diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index f2228b756f..b11e5c1ca9 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -216,6 +216,6 @@ def format_finetune_data( assistant_message_content = self.format_assistant_message_content( # returns a string, without the role signature=signature, outputs=outputs ) - assistant_message = dict(role="assistant", content=assistant_message_content) + assistant_message = {"role": "assistant", "content": assistant_message_content} messages = system_user_messages + [assistant_message] - return dict(messages=messages) + return {"messages": messages} diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index 89d16c6ce7..605daf148c 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -1,11 +1,11 @@ import json -import regex import logging from typing import Any, Dict, Type, get_origin import json_repair import litellm import pydantic +import regex from pydantic.fields import FieldInfo from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName @@ -29,7 +29,7 @@ def _has_open_ended_mapping(signature: SignatureMeta) -> bool: such as dict[str, Any]. Structured Outputs require explicit properties, so such fields are incompatible. """ - for name, field in signature.output_fields.items(): + for field in signature.output_fields.values(): annotation = field.annotation if get_origin(annotation) is dict: return True @@ -121,9 +121,9 @@ def format_assistant_message_content( return self.format_field_with_value(fields_with_values, role="assistant") def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]: - pattern = r'\{(?:[^{}]|(?R))*\}' - match = regex.search(pattern, completion, regex.DOTALL) - if match: + pattern = r"\{(?:[^{}]|(?R))*\}" + match = regex.search(pattern, completion, regex.DOTALL) + if match: completion = match.group(0) fields = json_repair.loads(completion) @@ -196,10 +196,14 @@ def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[py fields[name] = (annotation, default) # Build the model with extra fields forbidden. - Model = pydantic.create_model("DSPyProgramOutputs", **fields, __config__=type("Config", (), {"extra": "forbid"})) + pydantic_model = pydantic.create_model( + "DSPyProgramOutputs", + **fields, + __config__=type("Config", (), {"extra": "forbid"}), + ) # Generate the initial schema. - schema = Model.model_json_schema() + schema = pydantic_model.model_json_schema() # Remove any DSPy-specific metadata. for prop in schema.get("properties", {}).values(): @@ -208,9 +212,9 @@ def _get_structured_outputs_response_format(signature: SignatureMeta) -> type[py def enforce_required(schema_part: dict): """ Recursively ensure that: - - for any object schema, a "required" key is added with all property names (or [] if no properties) - - additionalProperties is set to False regardless of the previous value. - - the same enforcement is run for nested arrays and definitions. + - for any object schema, a "required" key is added with all property names (or [] if no properties) + - additionalProperties is set to False regardless of the previous value. + - the same enforcement is run for nested arrays and definitions. """ if schema_part.get("type") == "object": props = schema_part.get("properties") @@ -237,6 +241,6 @@ def enforce_required(schema_part: dict): enforce_required(schema) # Override the model's JSON schema generation to return our precomputed schema. - Model.model_json_schema = lambda *args, **kwargs: schema + pydantic_model.model_json_schema = lambda *args, **kwargs: schema - return Model + return pydantic_model diff --git a/dspy/adapters/two_step_adapter.py b/dspy/adapters/two_step_adapter.py index 02cb7750ea..a27a9de859 100644 --- a/dspy/adapters/two_step_adapter.py +++ b/dspy/adapters/two_step_adapter.py @@ -1,4 +1,4 @@ -from typing import Any, Type +from typing import Any, Optional, Type from dspy.adapters.base import Adapter from dspy.adapters.chat_adapter import ChatAdapter @@ -175,7 +175,7 @@ def format_assistant_message_content( self, signature: Type[Signature], outputs: dict[str, Any], - missing_field_message: str = None, + missing_field_message: Optional[str] = None, ) -> str: parts = [] diff --git a/dspy/adapters/types/image.py b/dspy/adapters/types/image.py index 34341d459f..1ca1dc6c31 100644 --- a/dspy/adapters/types/image.py +++ b/dspy/adapters/types/image.py @@ -1,10 +1,10 @@ import base64 import io +import mimetypes import os +import re from typing import Any, Dict, List, Union from urllib.parse import urlparse -import re -import mimetypes import pydantic import requests @@ -19,14 +19,14 @@ class Image(pydantic.BaseModel): url: str - + model_config = { - 'frozen': True, - 'str_strip_whitespace': True, - 'validate_assignment': True, - 'extra': 'forbid', + "frozen": True, + "str_strip_whitespace": True, + "validate_assignment": True, + "extra": "forbid", } - + @pydantic.model_validator(mode="before") @classmethod def validate_input(cls, values): @@ -52,7 +52,7 @@ def from_file(cls, file_path: str): return cls(url=encode_image(file_path)) @classmethod - def from_PIL(cls, pil_image): + def from_PIL(cls, pil_image): # noqa: N802 return cls(url=encode_image(pil_image)) @pydantic.model_serializer() @@ -66,9 +66,10 @@ def __repr__(self): if "base64" in self.url: len_base64 = len(self.url.split("base64,")[1]) image_type = self.url.split(";")[0].split("/")[-1] - return f"Image(url=data:image/{image_type};base64,)" + return f"Image(url=data:image/{image_type};base64,)" return f"Image(url='{self.url}')" + def is_url(string: str) -> bool: """Check if a string is a valid URL.""" try: @@ -162,7 +163,7 @@ def _encode_image_from_url(image_url: str) -> str: return f"data:{mime_type};base64,{encoded_data}" -def _encode_pil_image(image: 'PILImage') -> str: +def _encode_pil_image(image: "PILImage") -> str: """Encode a PIL Image object to a base64 data URI.""" buffered = io.BytesIO() file_format = image.format or "PNG" @@ -197,6 +198,7 @@ def is_image(obj) -> bool: return True return False + def try_expand_image_tags(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Try to expand image tags in the messages.""" for message in messages: @@ -205,43 +207,44 @@ def try_expand_image_tags(messages: List[Dict[str, Any]]) -> List[Dict[str, Any] message["content"] = expand_image_tags(message["content"]) return messages + def expand_image_tags(text: str) -> Union[str, List[Dict[str, Any]]]: - """Expand image tags in the text. If there are any image tags, + """Expand image tags in the text. If there are any image tags, turn it from a content string into a content list of texts and image urls. - + Args: text: The text content that may contain image tags - + Returns: Either the original string if no image tags, or a list of content dicts with text and image_url entries """ image_tag_regex = r'"?(.*?)"?' - + # If no image tags, return original text if not re.search(image_tag_regex, text): return text - + final_list = [] remaining_text = text - + while remaining_text: match = re.search(image_tag_regex, remaining_text) if not match: if remaining_text.strip(): final_list.append({"type": "text", "text": remaining_text.strip()}) break - + # Get text before the image tag - prefix = remaining_text[:match.start()].strip() + prefix = remaining_text[: match.start()].strip() if prefix: final_list.append({"type": "text", "text": prefix}) - + # Add the image image_url = match.group(1) final_list.append({"type": "image_url", "image_url": {"url": image_url}}) - + # Update remaining text - remaining_text = remaining_text[match.end():].strip() - + remaining_text = remaining_text[match.end() :].strip() + return final_list diff --git a/dspy/adapters/utils.py b/dspy/adapters/utils.py index 01a0e5a89a..21540bd32c 100644 --- a/dspy/adapters/utils.py +++ b/dspy/adapters/utils.py @@ -154,7 +154,7 @@ def parse_value(value, annotation): if v in allowed: return v - + raise ValueError(f"{value!r} is not one of {allowed!r}") if not isinstance(value, str): @@ -174,6 +174,7 @@ def parse_value(value, annotation): return str(candidate) raise + def get_annotation_name(annotation): origin = get_origin(annotation) args = get_args(annotation) @@ -193,6 +194,7 @@ def get_annotation_name(annotation): args_str = ", ".join(get_annotation_name(a) for a in args) return f"{get_annotation_name(origin)}[{args_str}]" + def get_field_description_string(fields: dict) -> str: field_descriptions = [] for idx, (k, v) in enumerate(fields.items()): @@ -220,7 +222,7 @@ def _format_input_list_field_value(value: List[Any]) -> str: if len(value) == 1: return _format_blob(value[0]) - return "\n".join([f"[{idx+1}] {_format_blob(txt)}" for idx, txt in enumerate(value)]) + return "\n".join([f"[{idx + 1}] {_format_blob(txt)}" for idx, txt in enumerate(value)]) def _format_blob(blob: str) -> str: