Skip to content

Commit 4f154a7

Browse files
Fix the custom type extraction in dspy.BaseType (#8320)
* init * increment * fix comments
1 parent 60d6f94 commit 4f154a7

File tree

2 files changed

+57
-10
lines changed

2 files changed

+57
-10
lines changed

dspy/adapters/types/base_type.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import json
22
import re
3-
import inspect
43
from typing import Any, Union, get_args, get_origin
54

65
import json_repair
@@ -42,18 +41,13 @@ def extract_custom_type_from_annotation(cls, annotation):
4241
This is used to extract all custom types from the annotation of a field, while the annotation can
4342
have arbitrary level of nesting. For example, we detect `Tool` is in `list[dict[str, Tool]]`.
4443
"""
45-
# Direct match. Some typing constructs (like `typing.Any`, `TypeAlias`,
46-
# or weird internals) may pass `isinstance(..., type)` but are not
47-
# valid classes for `issubclass`. We defensively guard against this by
48-
# using `inspect.isclass` and wrapping the call in a try/except block.
44+
# Direct match. Nested type like `list[dict[str, Event]]` passes `isinstance(annotation, type)` in python 3.10
45+
# while fails in python 3.11. To accomodate users using python 3.10, we need to capture the error and ignore it.
4946
try:
50-
if inspect.isclass(annotation) and issubclass(annotation, cls):
47+
if isinstance(annotation, type) and issubclass(annotation, cls):
5148
return [annotation]
5249
except TypeError:
53-
# `issubclass` can raise `TypeError` if the argument is not actually
54-
# a class (even if `inspect.isclass` thought otherwise). In these
55-
# cases we just ignore the annotation.
56-
return []
50+
pass
5751

5852
origin = get_origin(annotation)
5953
if origin is None:

tests/adapters/test_base_type.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import dspy
2+
from typing import Optional
3+
import pydantic
4+
5+
6+
def test_basic_extract_custom_type_from_annotation():
7+
class Event(dspy.BaseType):
8+
event_name: str
9+
start_date_time: str
10+
end_date_time: Optional[str]
11+
location: Optional[str]
12+
13+
class ExtractEvent(dspy.Signature):
14+
"""Extract all events from the email content."""
15+
16+
email: str = dspy.InputField()
17+
event: Event = dspy.OutputField()
18+
19+
assert dspy.BaseType.extract_custom_type_from_annotation(ExtractEvent.output_fields["event"].annotation) == [Event]
20+
21+
class ExtractEvents(dspy.Signature):
22+
"""Extract all events from the email content."""
23+
24+
email: str = dspy.InputField()
25+
events: list[Event] = dspy.OutputField()
26+
27+
assert dspy.BaseType.extract_custom_type_from_annotation(ExtractEvents.output_fields["events"].annotation) == [
28+
Event
29+
]
30+
31+
32+
def test_extract_custom_type_from_annotation_with_nested_type():
33+
class Event(dspy.BaseType):
34+
event_name: str
35+
start_date_time: str
36+
end_date_time: Optional[str]
37+
location: Optional[str]
38+
39+
class EventIdentifier(dspy.BaseType):
40+
model_config = pydantic.ConfigDict(frozen=True) # Make it hashable
41+
event_id: str
42+
event_name: str
43+
44+
class ExtractEvents(dspy.Signature):
45+
"""Extract all events from the email content."""
46+
47+
email: str = dspy.InputField()
48+
events: list[dict[EventIdentifier, Event]] = dspy.OutputField()
49+
50+
assert dspy.BaseType.extract_custom_type_from_annotation(ExtractEvents.output_fields["events"].annotation) == [
51+
EventIdentifier,
52+
Event,
53+
]

0 commit comments

Comments
 (0)