Skip to content

Commit 5e17bcf

Browse files
Provide a standard base class for creating custom Signature field type (#8217)
* support for custom types in DSPy signatures * fix completed demos * rename custom formatting function * init * increment * increment * increment * add test * better arrangement of code * fix test * address comments * add comment --------- Co-authored-by: Arnav Singhvi <arnav11.singhvi@gmail.com>
1 parent 5d31cd1 commit 5e17bcf

File tree

10 files changed

+366
-68
lines changed

10 files changed

+366
-68
lines changed

dspy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from dspy.evaluate import Evaluate # isort: skip
1010
from dspy.clients import * # isort: skip
11-
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, History # isort: skip
11+
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, History, BaseType # isort: skip
1212
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
1313
from dspy.utils.asyncify import asyncify
1414
from dspy.utils.saving import load

dspy/adapters/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from dspy.adapters.chat_adapter import ChatAdapter
33
from dspy.adapters.json_adapter import JSONAdapter
44
from dspy.adapters.two_step_adapter import TwoStepAdapter
5-
from dspy.adapters.types import History, Image
5+
from dspy.adapters.types import History, Image, BaseType
66

77
__all__ = [
88
"Adapter",
99
"ChatAdapter",
10+
"BaseType",
1011
"History",
1112
"Image",
1213
"JSONAdapter",

dspy/adapters/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import TYPE_CHECKING, Any, Optional, Type
22

33
from dspy.adapters.types import History
4-
from dspy.adapters.types.image import try_expand_image_tags
4+
from dspy.adapters.types.base_type import split_message_content_for_custom_types
55
from dspy.signatures.signature import Signature
66
from dspy.utils.callback import BaseCallback, with_callbacks
77

@@ -141,7 +141,7 @@ def format(
141141
content = self.format_user_message_content(signature, inputs_copy, main_request=True)
142142
messages.append({"role": "user", "content": content})
143143

144-
messages = try_expand_image_tags(messages)
144+
messages = split_message_content_for_custom_types(messages)
145145
return messages
146146

147147
def format_field_description(self, signature: Type[Signature]) -> str:

dspy/adapters/types/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dspy.adapters.types.history import History
22
from dspy.adapters.types.image import Image
3+
from dspy.adapters.types.base_type import BaseType
34

4-
__all__ = ["History", "Image"]
5+
__all__ = ["History", "Image", "BaseType"]

dspy/adapters/types/base_type.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import json
2+
import re
3+
from typing import Any
4+
5+
import json_repair
6+
import pydantic
7+
8+
CUSTOM_TYPE_START_IDENTIFIER = "<<CUSTOM-TYPE-START-IDENTIFIER>>"
9+
CUSTOM_TYPE_END_IDENTIFIER = "<<CUSTOM-TYPE-END-IDENTIFIER>>"
10+
11+
12+
class BaseType(pydantic.BaseModel):
    """Common parent for custom field types used in DSPy signatures.

    Built-in custom types such as `dspy.Image` derive from this class. A subclass overrides `format` to
    return a list of dictionaries, shaped like the array of content parts accepted in the `content` field
    of an OpenAI API user message.

    When an instance is serialized, the result of `format` is rendered as a string and wrapped between
    `CUSTOM_TYPE_START_IDENTIFIER` and `CUSTOM_TYPE_END_IDENTIFIER`.

    Example:

        ```python
        class Image(BaseType):
            url: str

            def format(self) -> list[dict[str, Any]]:
                return [{"type": "image_url", "image_url": {"url": self.url}}]
        ```
    """

    def format(self) -> list[dict[str, Any]]:
        # The base class has no content of its own; every subclass must provide this.
        raise NotImplementedError

    @pydantic.model_serializer()
    def serialize_model(self):
        # str() of the formatted list is the same text an f-string interpolation would produce.
        formatted_parts = str(self.format())
        return CUSTOM_TYPE_START_IDENTIFIER + formatted_parts + CUSTOM_TYPE_END_IDENTIFIER
35+
36+
37+
def split_message_content_for_custom_types(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Split user message content into a list of content blocks.

    Each user message in `messages` whose content contains a serialized custom type (e.g. `dspy.Image`) has
    its content replaced, in place, by a list of content blocks. For example, the split content may look like
    below if the user message has a `dspy.Image` object:

    ```
    [
        {"type": "text", "text": "{text_before_image}"},
        {"type": "image_url", "image_url": {"url": "{image_url}"}},
        {"type": "text", "text": "{text_after_image}"},
    ]
    ```

    This is implemented by finding `<<CUSTOM-TYPE-START-IDENTIFIER>>` and `<<CUSTOM-TYPE-END-IDENTIFIER>>`
    in the user message content and splitting the content around them. These are the reserved identifiers
    emitted when a `dspy.BaseType` is serialized.

    Messages that are not user messages, whose content is not a string, or that contain no custom type are
    left untouched.

    Args:
        messages: a list of messages sent to the LM. The format is the same as [OpenAI API's messages
            format](https://platform.openai.com/docs/guides/chat-completions/response-format).

    Returns:
        The same `messages` list, with matching user messages' content split into a list of content blocks
        around the custom type content.
    """
    # Compile once for all messages; escaping guarantees the identifiers are always matched literally.
    pattern = re.compile(
        re.escape(CUSTOM_TYPE_START_IDENTIFIER) + r"(.*?)" + re.escape(CUSTOM_TYPE_END_IDENTIFIER),
        re.DOTALL,
    )

    for message in messages:
        if message["role"] != "user":
            # Custom type messages are only in user messages
            continue

        # DSPy adapters format user input into a string before splitting. Anything else (for example a
        # message that was already split) cannot be scanned with a regex and is left as is.
        content = message["content"]
        if not isinstance(content, str):
            continue

        result = []
        last_end = 0

        for match in pattern.finditer(content):
            start, end = match.span()

            # Add text before the current block
            if start > last_end:
                result.append({"type": "text", "text": content[last_end:start]})

            # Parse the payload inside the block. json_repair may return a best-effort value instead of
            # raising, so the parsed type is checked explicitly rather than relying on the exception alone.
            raw_block = match.group(1).strip()
            try:
                parsed = json_repair.loads(raw_block)
            except json.JSONDecodeError:
                parsed = None

            if isinstance(parsed, list):
                result.extend(parsed)
            elif isinstance(parsed, dict) and parsed:
                # A single content part that was not wrapped in a list.
                result.append(parsed)
            else:
                # Fall back to the raw string if the payload is not usable as content parts
                result.append({"type": "text", "text": raw_block})

            last_end = end

        if last_end == 0:
            # No custom type found, keep the original message
            continue

        # Add any remaining text after the last match
        if last_end < len(content):
            result.append({"type": "text", "text": content[last_end:]})

        message["content"] = result

    return messages

dspy/adapters/types/image.py

Lines changed: 11 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
import io
33
import mimetypes
44
import os
5-
import re
6-
from typing import Any, Dict, List, Union
5+
from typing import Any, Union
76
from urllib.parse import urlparse
87

98
import pydantic
109
import requests
1110

11+
from dspy.adapters.types.base_type import BaseType
12+
1213
try:
1314
from PIL import Image as PILImage
1415

@@ -17,7 +18,7 @@
1718
PIL_AVAILABLE = False
1819

1920

20-
class Image(pydantic.BaseModel):
21+
class Image(BaseType):
2122
url: str
2223

2324
model_config = {
@@ -27,6 +28,13 @@ class Image(pydantic.BaseModel):
2728
"extra": "forbid",
2829
}
2930

31+
def format(self) -> list[dict[str, Any]]:
    """Return this image as a one-element list holding an `image_url` content part.

    The stored `url` is passed through `encode_image` and the result is used as the content part's URL.

    Raises:
        ValueError: if `encode_image` fails; the original exception is chained as the cause.
    """
    try:
        image_url = encode_image(self.url)
    except Exception as e:
        # Chain the cause so the underlying failure stays visible in the traceback.
        raise ValueError(f"Failed to format image for DSPy: {e}") from e
    return [{"type": "image_url", "image_url": {"url": image_url}}]
37+
3038
@pydantic.model_validator(mode="before")
3139
@classmethod
3240
def validate_input(cls, values):
@@ -55,10 +63,6 @@ def from_file(cls, file_path: str):
5563
def from_PIL(cls, pil_image): # noqa: N802
5664
return cls(url=encode_image(pil_image))
5765

58-
@pydantic.model_serializer()
59-
def serialize_model(self):
60-
return "<DSPY_IMAGE_START>" + self.url + "<DSPY_IMAGE_END>"
61-
6266
def __str__(self):
6367
return self.serialize_model()
6468

@@ -197,54 +201,3 @@ def is_image(obj) -> bool:
197201
elif is_url(obj):
198202
return True
199203
return False
200-
201-
202-
def try_expand_image_tags(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
203-
"""Try to expand image tags in the messages."""
204-
for message in messages:
205-
# NOTE: Assumption that content is a string
206-
if "content" in message and "<DSPY_IMAGE_START>" in message["content"]:
207-
message["content"] = expand_image_tags(message["content"])
208-
return messages
209-
210-
211-
def expand_image_tags(text: str) -> Union[str, List[Dict[str, Any]]]:
212-
"""Expand image tags in the text. If there are any image tags,
213-
turn it from a content string into a content list of texts and image urls.
214-
215-
Args:
216-
text: The text content that may contain image tags
217-
218-
Returns:
219-
Either the original string if no image tags, or a list of content dicts
220-
with text and image_url entries
221-
"""
222-
image_tag_regex = r'"?<DSPY_IMAGE_START>(.*?)<DSPY_IMAGE_END>"?'
223-
224-
# If no image tags, return original text
225-
if not re.search(image_tag_regex, text):
226-
return text
227-
228-
final_list = []
229-
remaining_text = text
230-
231-
while remaining_text:
232-
match = re.search(image_tag_regex, remaining_text)
233-
if not match:
234-
if remaining_text.strip():
235-
final_list.append({"type": "text", "text": remaining_text.strip()})
236-
break
237-
238-
# Get text before the image tag
239-
prefix = remaining_text[: match.start()].strip()
240-
if prefix:
241-
final_list.append({"type": "text", "text": prefix})
242-
243-
# Add the image
244-
image_url = match.group(1)
245-
final_list.append({"type": "image_url", "image_url": {"url": image_url}})
246-
247-
# Update remaining text
248-
remaining_text = remaining_text[match.end() :].strip()
249-
250-
return final_list

dspy/signatures/signature.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class MySignature(dspy.Signature):
2727
from pydantic import BaseModel, Field, create_model
2828
from pydantic.fields import FieldInfo
2929

30-
from dspy.adapters.types.image import Image # noqa: F401
3130
from dspy.signatures.field import InputField, OutputField
3231

3332

tests/adapters/test_chat_adapter.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
import dspy
7+
import pydantic
78

89

910
@pytest.mark.parametrize(
@@ -94,3 +95,121 @@ async def test_chat_adapter_async_call():
9495
lm = dspy.utils.DummyLM([{"answer": "Paris"}])
9596
result = await adapter.acall(lm, {}, signature, [], {"question": "What is the capital of France?"})
9697
assert result == [{"answer": "Paris"}]
98+
99+
100+
def test_chat_adapter_formats_image():
    """A single `dspy.Image` input is expanded into content blocks in the user message."""

    class MySignature(dspy.Signature):
        image: dspy.Image = dspy.InputField()
        text: str = dspy.OutputField()

    sample_image = dspy.Image(url="https://example.com/image.jpg")
    messages = dspy.ChatAdapter().format(MySignature, [], {"image": sample_image})

    assert len(messages) == 2
    blocks = messages[1]["content"]
    assert blocks is not None

    # Three blocks are expected, with text at both ends
    assert len(blocks) == 3
    assert blocks[0]["type"] == "text"
    assert blocks[2]["type"] == "text"

    # The image itself must appear as an image_url content part
    assert {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}} in blocks
123+
124+
125+
def test_chat_adapter_formats_image_with_few_shot_examples():
    """Images in few-shot demos and in the final query are each expanded in their own user message."""

    class MySignature(dspy.Signature):
        image: dspy.Image = dspy.InputField()
        text: str = dspy.OutputField()

    first_demo = dspy.Example(
        image=dspy.Image(url="https://example.com/image1.jpg"),
        text="This is a test image",
    )
    second_demo = dspy.Example(
        image=dspy.Image(url="https://example.com/image2.jpg"),
        text="This is another test image",
    )
    query_inputs = {"image": dspy.Image(url="https://example.com/image3.jpg")}

    messages = dspy.ChatAdapter().format(MySignature, [first_demo, second_demo], query_inputs)

    # 1 system message, 2 few shot examples (1 user and assistant message for each example), 1 user message
    assert len(messages) == 6

    # User messages sit at the odd indices; each one carries its own image
    expected_url_by_index = (
        (1, "https://example.com/image1.jpg"),
        (3, "https://example.com/image2.jpg"),
        (5, "https://example.com/image3.jpg"),
    )
    for message_index, url in expected_url_by_index:
        assert {"type": "image_url", "image_url": {"url": url}} in messages[message_index]["content"]
150+
151+
152+
def test_chat_adapter_formats_image_with_nested_images():
    """Images nested inside a pydantic model field are still expanded into image_url content parts."""

    class ImageWrapper(pydantic.BaseModel):
        images: list[dspy.Image]
        tag: list[str]

    class MySignature(dspy.Signature):
        image: ImageWrapper = dspy.InputField()
        text: str = dspy.OutputField()

    urls = (
        "https://example.com/image1.jpg",
        "https://example.com/image2.jpg",
        "https://example.com/image3.jpg",
    )
    image_wrapper = ImageWrapper(images=[dspy.Image(url=url) for url in urls], tag=["test", "example"])

    messages = dspy.ChatAdapter().format(MySignature, [], {"image": image_wrapper})

    # Every nested image must show up in the user message content
    for url in urls:
        assert {"type": "image_url", "image_url": {"url": url}} in messages[1]["content"]
177+
178+
179+
def test_chat_adapter_formats_image_with_few_shot_examples_with_nested_images():
    """Nested images are expanded both in a few-shot demo and in the final query message."""

    class ImageWrapper(pydantic.BaseModel):
        images: list[dspy.Image]
        tag: list[str]

    class MySignature(dspy.Signature):
        image: ImageWrapper = dspy.InputField()
        text: str = dspy.OutputField()

    demo_urls = (
        "https://example.com/image1.jpg",
        "https://example.com/image2.jpg",
        "https://example.com/image3.jpg",
    )
    demo_wrapper = ImageWrapper(images=[dspy.Image(url=url) for url in demo_urls], tag=["test", "example"])
    demos = [dspy.Example(image=demo_wrapper, text="This is a test image")]

    query_wrapper = ImageWrapper(images=[dspy.Image(url="https://example.com/image4.jpg")], tag=["test", "example"])

    messages = dspy.ChatAdapter().format(MySignature, demos, {"image": query_wrapper})

    assert len(messages) == 4

    # Image information in the few-shot example's user message
    for url in demo_urls:
        assert {"type": "image_url", "image_url": {"url": url}} in messages[1]["content"]

    # The query image is formatted in the last user message
    assert {"type": "image_url", "image_url": {"url": "https://example.com/image4.jpg"}} in messages[-1]["content"]

0 commit comments

Comments
 (0)