Skip to content

Commit e05f6f9

Browse files
authored
feat(dspy): custom type resolution in Signatures (#8232)
* initial version - custom types working * refactor custom type class to be nicer * [will fail CI] Move tests to separate file + add module level test to fail * Dynamically walk call stack to find more types * ruff + extra * simplify docs
1 parent 7b325ad commit e05f6f9

File tree

4 files changed

+327
-25
lines changed

4 files changed

+327
-25
lines changed

docs/docs/learn/programming/signatures.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ Your signatures can also have multiple input/output fields with types:
3232

3333
**Tip:** For fields, any valid variable names work! Field names should be semantically meaningful, but start simple and don't prematurely optimize keywords! Leave that kind of hacking to the DSPy compiler. For example, for summarization, it's probably fine to say `"document -> summary"`, `"text -> gist"`, or `"long_context -> tldr"`.
3434

35+
You can also add instructions to your inline signature, which can use variables at runtime. Use the `instructions` keyword argument to add instructions to your signature.
36+
37+
```python
38+
toxicity = dspy.Predict(
39+
'comment -> toxic: bool',
40+
instructions="Mark as 'toxic' if the comment includes insults, harassment, or sarcastic derogatory remarks.")
41+
```
3542

3643
### Example A: Sentiment Classification
3744

@@ -157,6 +164,35 @@ Prediction(
157164
)
158165
```
159166

167+
## Type Resolution in Signatures
168+
169+
DSPy signatures support various annotation types:
170+
171+
1. **Basic types** like `str`, `int`, `bool`
172+
2. **Typing module types** like `List[str]`, `Dict[str, int]`, `Optional[float]`, `Union[str, int]`
173+
3. **Custom types** defined in your code
174+
4. **Dot notation** for nested types with proper configuration
175+
5. **Special data types** like `dspy.Image`, `dspy.History`
176+
177+
### Working with Custom Types
178+
179+
```python
180+
# Simple custom type
181+
class QueryResult(pydantic.BaseModel):
182+
text: str
183+
score: float
184+
185+
signature = dspy.Signature("query: str -> result: QueryResult")
186+
187+
class Container:
188+
class Query(pydantic.BaseModel):
189+
text: str
190+
class Score(pydantic.BaseModel):
191+
score: float
192+
193+
signature = dspy.Signature("query: Container.Query -> score: Container.Score")
194+
```
195+
160196
## Using signatures to build modules & compiling them
161197

162198
While signatures are convenient for prototyping with structured inputs/outputs, that's not the only reason to use them!

dspy/signatures/signature.py

Lines changed: 121 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class MySignature(dspy.Signature):
2121
import re
2222
import types
2323
import typing
24+
import sys
2425
from copy import deepcopy
2526
from typing import Any, Dict, Optional, Tuple, Type, Union
2627

@@ -40,9 +41,97 @@ class SignatureMeta(type(BaseModel)):
4041
def __call__(cls, *args, **kwargs):
    """Build a new Signature class when `Signature(...)` itself is called.

    Calling any other class that uses this metaclass falls through to the
    normal construction path. For `Signature`, the optional `custom_types`
    keyword is taken out of `kwargs`; when it is absent and the first
    positional argument is a string, the types are looked up in the calling
    frames instead. The result is forwarded to `make_signature`.
    """
    if cls is not Signature:
        return super().__call__(*args, **kwargs)

    # We don't create an actual Signature instance, instead, we create a new Signature class.
    custom_types = kwargs.pop('custom_types', None)

    is_string_signature = bool(args) and isinstance(args[0], str)
    if custom_types is None and is_string_signature:
        custom_types = cls._detect_custom_types_from_caller(args[0])

    return make_signature(*args, custom_types=custom_types, **kwargs)
4551

52+
@staticmethod
def _detect_custom_types_from_caller(signature_str):
    """Detect custom types from the caller's frames based on the signature string.

    Every name that directly follows a ``:`` in ``signature_str`` is treated as a
    candidate type. For dotted names such as ``Container.Query`` only the first
    component (``Container``) is looked up. Names already provided by ``typing``
    or by the builtins are skipped; the rest are searched for in the locals and
    then the globals of each calling frame, walking outwards.

    Note: This method relies on Python's frame introspection which has some limitations:
    1. May not work in all Python implementations (e.g., compiled with optimizations)
    2. Looks up a limited number of frames in the call stack
    3. Cannot find types that are imported but not in the caller's namespace

    For more reliable custom type resolution, explicitly provide types using the
    `custom_types` parameter when creating a Signature.

    Args:
        signature_str: The string form of a signature, e.g. ``"q: MyType -> a: str"``.

    Returns:
        A dict mapping each resolved base name to the object found in the call
        stack, or ``None`` when there is nothing to resolve or nothing was found.
    """
    # Imported locally so the block does not depend on extra file-level imports.
    import builtins
    import logging

    # Extract potential type names from the signature string, including dotted names
    # Match both simple types like 'MyType' and dotted names like 'Module.Type'
    type_pattern = r':\s*([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*)'
    type_names = re.findall(type_pattern, signature_str)
    if not type_names:
        return None

    # `__builtins__` is a dict in imported modules but a module object in
    # `__main__`; going through the `builtins` module works in both cases.
    builtin_names = vars(builtins)
    typing_names = typing.__dict__

    # Only the leading component of a dotted name has to exist in a frame;
    # the attribute chain is resolved later by the type parser.
    needed_types = set()
    for type_name in type_names:
        base_name = type_name.split('.', 1)[0]
        if base_name not in typing_names and base_name not in builtin_names:
            needed_types.add(base_name)

    if not needed_types:
        return None

    found_types = {}
    logger = logging.getLogger("dspy")
    frame = None
    try:
        frame = sys._getframe(1)  # Start one level up (skip this function)

        max_frames = 100
        frame_count = 0

        while frame is not None and needed_types and frame_count < max_frames:
            frame_count += 1

            # Iterate over a snapshot because resolved names are removed as we go.
            for type_name in list(needed_types):
                if type_name in frame.f_locals:
                    found_types[type_name] = frame.f_locals[type_name]
                    needed_types.remove(type_name)
                elif frame.f_globals and type_name in frame.f_globals:
                    found_types[type_name] = frame.f_globals[type_name]
                    needed_types.remove(type_name)

            frame = frame.f_back

        if needed_types and frame_count >= max_frames:
            logger.warning(
                f"Reached maximum frame search depth ({max_frames}) while looking for types: {needed_types}. "
                "Consider providing custom_types explicitly to Signature."
            )
    except (AttributeError, ValueError):
        # Handle environments where frame introspection is not available
        logger.debug(
            "Frame introspection failed while trying to resolve custom types. "
            "Consider providing custom_types explicitly to Signature."
        )
    finally:
        # Drop the frame reference so this function does not keep frames alive.
        frame = None

    return found_types or None
134+
46135
def __new__(mcs, signature_name, bases, namespace, **kwargs): # noqa: N804
47136
# At this point, the orders have been swapped already.
48137
field_order = [name for name, value in namespace.items() if isinstance(value, FieldInfo)]
@@ -281,6 +370,7 @@ def make_signature(
281370
signature: Union[str, Dict[str, Tuple[type, FieldInfo]]],
282371
instructions: Optional[str] = None,
283372
signature_name: str = "StringSignature",
373+
custom_types: Optional[Dict[str, Type]] = None,
284374
) -> Type[Signature]:
285375
"""Create a new Signature subclass with the specified fields and instructions.
286376
@@ -291,6 +381,8 @@ def make_signature(
291381
If not provided, defaults to a basic description of inputs and outputs.
292382
signature_name: Optional string to name the generated Signature subclass.
293383
Defaults to "StringSignature".
384+
custom_types: Optional dictionary mapping type names to their actual type objects.
385+
Useful for resolving custom types that aren't built-ins or in the typing module.
294386
295387
Returns:
296388
A new signature class with the specified fields and instructions.
@@ -306,9 +398,21 @@ def make_signature(
306398
"question": (str, InputField()),
307399
"answer": (str, OutputField())
308400
})
401+
402+
# Using custom types
403+
class MyType:
404+
pass
405+
406+
sig3 = make_signature("input: MyType -> output", custom_types={"MyType": MyType})
309407
```
310408
"""
311-
fields = _parse_signature(signature) if isinstance(signature, str) else signature
409+
# Prepare the names dictionary for type resolution
410+
names = None
411+
if custom_types:
412+
names = dict(typing.__dict__)
413+
names.update(custom_types)
414+
415+
fields = _parse_signature(signature, names) if isinstance(signature, str) else signature
312416

313417
# Validate the fields, this is important because we sometimes forget the
314418
# slightly unintuitive syntax with tuples of (type, Field)
@@ -346,22 +450,22 @@ def make_signature(
346450
)
347451

348452

349-
def _parse_signature(signature: str, names=None) -> Dict[str, Tuple[Type, Field]]:
    """Split a string signature at its single '->' and build the field mapping.

    Fields parsed from the left-hand side are paired with `InputField()`, those
    from the right-hand side with `OutputField()`. `names` is passed through
    unchanged to `_parse_field_string` for type resolution.

    Raises:
        ValueError: If the string does not contain exactly one '->'.
    """
    arrow_count = signature.count("->")
    if arrow_count != 1:
        raise ValueError(f"Invalid signature format: '{signature}', must contain exactly one '->'.")

    inputs_str, outputs_str = signature.split("->")

    fields = {
        field_name: (field_type, InputField())
        for field_name, field_type in _parse_field_string(inputs_str, names)
    }
    # Outputs are added second, so an output with the same name as an input replaces it.
    fields.update(
        (field_name, (field_type, OutputField()))
        for field_name, field_type in _parse_field_string(outputs_str, names)
    )
    return fields
362466

363467

364-
def _parse_field_string(field_string: str) -> Dict[str, str]:
468+
def _parse_field_string(field_string: str, names=None) -> Dict[str, str]:
365469
"""Extract the field name and type from field string in the string-based Signature.
366470
367471
It takes a string like "x: int, y: str" and returns a dictionary mapping field names to their types.
@@ -370,9 +474,9 @@ def _parse_field_string(field_string: str) -> Dict[str, str]:
370474
"""
371475

372476
args = ast.parse(f"def f({field_string}): pass").body[0].args.args
373-
names = [arg.arg for arg in args]
374-
types = [str if arg.annotation is None else _parse_type_node(arg.annotation) for arg in args]
375-
return zip(names, types)
477+
field_names = [arg.arg for arg in args]
478+
types = [str if arg.annotation is None else _parse_type_node(arg.annotation, names) for arg in args]
479+
return zip(field_names, types)
376480

377481

378482
def _parse_type_node(node, names=None) -> Any:
@@ -445,10 +549,16 @@ def resolve_name(type_name: str):
445549
if isinstance(node, ast.Attribute):
446550
base = _parse_type_node(node.value, names)
447551
attr_name = node.attr
552+
448553
if hasattr(base, attr_name):
449554
return getattr(base, attr_name)
450-
else:
451-
raise ValueError(f"Unknown attribute: {attr_name} on {base}")
555+
556+
if isinstance(node.value, ast.Name):
557+
full_name = f"{node.value.id}.{attr_name}"
558+
if full_name in names:
559+
return names[full_name]
560+
561+
raise ValueError(f"Unknown attribute: {attr_name} on {base}")
452562

453563
if isinstance(node, ast.Subscript):
454564
base_type = _parse_type_node(node.value, names)

tests/signatures/test_custom_types.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import pydantic
2+
import pytest
3+
from typing import List, Dict, Any
4+
5+
import dspy
6+
from dspy import Signature
7+
from dspy.utils.dummies import DummyLM
8+
9+
10+
def test_basic_custom_type_resolution():
    """A function-local model resolves via `custom_types` and via caller-scope lookup."""

    class CustomType(pydantic.BaseModel):
        value: str

    signature_text = "input: CustomType -> output: str"

    # Explicit mapping from the name in the string to the class object.
    mapped = Signature(signature_text, custom_types={"CustomType": CustomType})
    assert mapped.input_fields["input"].annotation == CustomType

    # No mapping given: the name has to be found in this function's scope.
    inferred = Signature(signature_text)
    assert inferred.input_fields["input"].annotation == CustomType
25+
26+
27+
def test_type_alias_for_nested_types():
    """Nested classes resolve through a local alias and through dotted names."""

    class Container:
        class NestedType(pydantic.BaseModel):
            value: str

    # The alias name is what the signature string refers to, so it must stay a local.
    NestedType = Container.NestedType
    sig_from_alias = Signature("input: str -> output: NestedType")
    assert sig_from_alias.output_fields["output"].annotation == Container.NestedType

    class Container2:
        class Query(pydantic.BaseModel):
            text: str

        class Score(pydantic.BaseModel):
            score: float

    dotted_sig = dspy.Signature("query: Container2.Query -> score: Container2.Score")
    assert dotted_sig.output_fields["score"].annotation == Container2.Score
45+
46+
47+
class GlobalCustomType(pydantic.BaseModel):
    """A type defined at module level for testing module-level resolution."""
    # The tests in this file refer to this class by its bare name inside
    # signature strings, without passing `custom_types`.
    value: str
    notes: str = ""
51+
52+
53+
def test_module_level_type_resolution():
    """A model defined at module scope is resolved without an explicit mapping."""
    resolved = Signature("name: str -> result: GlobalCustomType")
    result_field = resolved.output_fields["result"]
    assert result_field.annotation == GlobalCustomType
58+
59+
60+
# Create module-level nested class for testing
61+
class OuterContainer:
    # Plain namespace class; the tests reach the model below via the dotted
    # name "OuterContainer.InnerType" in signature strings.
    class InnerType(pydantic.BaseModel):
        name: str
        value: int
65+
66+
67+
def test_recommended_patterns():
    """Walk through the supported ways of naming a custom type in a signature string."""

    def output_annotation(signature_text):
        # Types are looked up by walking the call stack, so locals of the
        # enclosing test are still reachable from inside this helper.
        return Signature(signature_text).output_fields["output"].annotation

    # PATTERN 1: Local type with auto-resolution
    class LocalType(pydantic.BaseModel):
        value: str

    assert output_annotation("input: str -> output: LocalType") == LocalType

    # PATTERN 2: Module-level type with auto-resolution
    assert output_annotation("input: str -> output: GlobalCustomType") == GlobalCustomType

    # PATTERN 3: Nested type with dot notation
    assert output_annotation("input: str -> output: OuterContainer.InnerType") == OuterContainer.InnerType

    # PATTERN 4: Nested type using alias
    InnerTypeAlias = OuterContainer.InnerType
    assert output_annotation("input: str -> output: InnerTypeAlias") == InnerTypeAlias

    # PATTERN 5: Nested type with dot notation
    assert output_annotation("input: str -> output: OuterContainer.InnerType") == OuterContainer.InnerType
93+
94+
def test_expected_failure():
    """A bare nested-class name must not resolve without its container prefix."""
    # `InnerType` does not exist on its own, only as `OuterContainer.InnerType`,
    # so building the signature is expected to raise. Nothing may follow the
    # call inside the `with` block: it would never run once the error is raised.
    with pytest.raises(ValueError):
        Signature("input: str -> output: InnerType")
99+
100+
def test_module_type_resolution():
    """A dotted custom type resolves when the signature is built inside a Module's __init__."""

    class TestModule(dspy.Module):
        def __init__(self):
            super().__init__()
            # The instance attribute set here takes precedence over the
            # `predict` method defined on the class below.
            self.predict = dspy.Predict("input: str -> output: OuterContainer.InnerType")

        def predict(self, input: str) -> str:
            return input

    instance = TestModule()
    output_field = instance.predict.signature.output_fields["output"]
    assert output_field.annotation == OuterContainer.InnerType
112+
113+
def test_custom_type_resolution_explicit_and_auto():
    """Resolve a local model first with an explicit mapping, then from the caller's scope.

    This test carries its own name so that it does not redefine another test
    function in the module; a redefinition would leave only the last one collected.
    """

    class CustomType(pydantic.BaseModel):
        value: str

    sig = Signature("input: CustomType -> output: str", custom_types={"CustomType": CustomType})
    assert sig.input_fields["input"].annotation == CustomType

    sig = Signature("input: CustomType -> output: str")
    assert sig.input_fields["input"].annotation == CustomType

0 commit comments

Comments
 (0)