Skip to content

feat(dspy): custom type resolution in Signatures #8232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions docs/docs/learn/programming/signatures.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ Your signatures can also have multiple input/output fields with types:

**Tip:** For fields, any valid variable names work! Field names should be semantically meaningful, but start simple and don't prematurely optimize keywords! Leave that kind of hacking to the DSPy compiler. For example, for summarization, it's probably fine to say `"document -> summary"`, `"text -> gist"`, or `"long_context -> tldr"`.

You can also add instructions to your inline signature, which can use variables at runtime. Use the `instructions` keyword argument to add instructions to your signature.

```python
toxicity = dspy.Predict(
'comment -> toxic: bool',
instructions="Mark as 'toxic' if the comment includes insults, harassment, or sarcastic derogatory remarks.")
```

### Example A: Sentiment Classification

Expand Down Expand Up @@ -157,6 +164,35 @@ Prediction(
)
```

## Type Resolution in Signatures

DSPy signatures support various annotation types:

1. **Basic types** like `str`, `int`, `bool`
2. **Typing module types** like `List[str]`, `Dict[str, int]`, `Optional[float]`, `Union[str, int]`
3. **Custom types** defined in your code
4. **Dot notation** for nested types with proper configuration
5. **Special data types** like `dspy.Image`, `dspy.History`

### Working with Custom Types

```python
# Simple custom type
class QueryResult(pydantic.BaseModel):
text: str
score: float

signature = dspy.Signature("query: str -> result: QueryResult")

class Container:
class Query(pydantic.BaseModel):
text: str
class Score(pydantic.BaseModel):
score: float

signature = dspy.Signature("query: Container.Query -> score: Container.Score")
```

## Using signatures to build modules & compiling them

While signatures are convenient for prototyping with structured inputs/outputs, that's not the only reason to use them!
Expand Down
132 changes: 121 additions & 11 deletions dspy/signatures/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class MySignature(dspy.Signature):
import re
import types
import typing
import sys
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple, Type, Union

Expand All @@ -41,9 +42,97 @@ class SignatureMeta(type(BaseModel)):
def __call__(cls, *args, **kwargs):
    """Build a new Signature subclass when called on `Signature` itself; otherwise instantiate normally.

    Args:
        *args: Positional arguments forwarded to `make_signature` (when `cls is Signature`)
            or to the regular class call.
        **kwargs: Keyword arguments forwarded the same way. For `Signature`, an optional
            `custom_types` entry (a mapping of type names to type objects) is popped here and
            passed on explicitly.

    Returns:
        The result of `make_signature(...)` when `cls is Signature`, otherwise the result of
        the parent metaclass `__call__`.
    """
    if cls is Signature:
        # We don't create an actual Signature instance, instead, we create a new Signature class.
        custom_types = kwargs.pop('custom_types', None)

        # The signature may be given positionally or as the `signature` keyword.
        signature_arg = args[0] if args else kwargs.get('signature')

        # Only string signatures carry type names that need resolving, and an explicit
        # `custom_types` mapping always takes precedence over frame-based detection.
        if custom_types is None and isinstance(signature_arg, str):
            custom_types = cls._detect_custom_types_from_caller(signature_arg)

        return make_signature(*args, custom_types=custom_types, **kwargs)
    return super().__call__(*args, **kwargs)

@staticmethod
def _detect_custom_types_from_caller(signature_str):
    """Detect custom types from the caller's frame based on the signature string.

    Candidate names are the identifiers used in field annotations: the name right after a
    colon, names nested inside generic parameters (e.g. `MyType` in `List[MyType]`), and the
    leading component of dotted names (e.g. `Container` in `Container.Query`). Names present
    in the `typing` module namespace or among the builtins are skipped. Each remaining name is
    looked up in the locals, then the globals, of every calling frame, innermost first.

    Note: This method relies on Python's frame introspection which has some limitations:
        1. May not work in all Python implementations (e.g., compiled with optimizations)
        2. Looks up a limited number of frames in the call stack
        3. Cannot find types that are imported but not in the caller's namespace

    For more reliable custom type resolution, explicitly provide types using the
    `custom_types` parameter when creating a Signature.

    Args:
        signature_str: The string signature, e.g. "query: Container.Query -> score: float".

    Returns:
        A dict mapping each resolved name to the object found for it, or None when the
        signature has no annotations, needs no custom names, or none could be found.
    """
    import ast
    import builtins
    import logging

    # Extract potential type names from the signature string, including dotted names
    # Match both simple types like 'MyType' and dotted names like 'Module.Type'
    type_pattern = r':\s*([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*)'
    type_names = re.findall(type_pattern, signature_str)
    if not type_names:
        return None

    # For a dotted name only the leading component has to be found in a frame.
    candidate_names = {type_name.split('.')[0] for type_name in type_names}

    # The regex only sees the identifier directly after a colon, so names nested in generics
    # (List[MyType], Optional[MyType], Dict[str, MyType]) are collected from the parsed
    # annotations instead.
    try:
        for side in signature_str.split('->'):
            for arg in ast.parse(f"def f({side}): pass").body[0].args.args:
                if arg.annotation is not None:
                    candidate_names.update(
                        node.id for node in ast.walk(arg.annotation) if isinstance(node, ast.Name)
                    )
    except (SyntaxError, ValueError):
        # Malformed field list: keep the regex results and let the signature parser report it.
        pass

    # Use the `builtins` module rather than `__builtins__`, which may be a module or a dict.
    builtin_names = vars(builtins)
    needed_types = {
        name
        for name in candidate_names
        if name not in typing.__dict__ and name not in builtin_names
    }
    if not needed_types:
        return None

    # Get type references from caller frames by walking the stack
    found_types = {}
    frame = None
    try:
        frame = sys._getframe(1)  # Start one level up (skip this function)

        max_frames = 100
        frame_count = 0

        while frame and needed_types and frame_count < max_frames:
            frame_count += 1

            # Iterate over a copy because resolved names are removed from the set.
            for type_name in list(needed_types):
                if type_name in frame.f_locals:
                    found_types[type_name] = frame.f_locals[type_name]
                    needed_types.remove(type_name)
                elif frame.f_globals and type_name in frame.f_globals:
                    found_types[type_name] = frame.f_globals[type_name]
                    needed_types.remove(type_name)

            # If we found all needed types, stop looking
            if not needed_types:
                break

            frame = frame.f_back

        if needed_types and frame_count >= max_frames:
            logging.getLogger("dspy").warning(
                f"Reached maximum frame search depth ({max_frames}) while looking for types: {needed_types}. "
                "Consider providing custom_types explicitly to Signature."
            )
    except (AttributeError, ValueError):
        # Handle environments where frame introspection is not available
        logging.getLogger("dspy").debug(
            "Frame introspection failed while trying to resolve custom types. "
            "Consider providing custom_types explicitly to Signature."
        )
    finally:
        # Drop the frame reference so it cannot keep a reference cycle alive.
        del frame

    return found_types or None

def __new__(mcs, signature_name, bases, namespace, **kwargs): # noqa: N804
# At this point, the orders have been swapped already.
field_order = [name for name, value in namespace.items() if isinstance(value, FieldInfo)]
Expand Down Expand Up @@ -282,6 +371,7 @@ def make_signature(
signature: Union[str, Dict[str, Tuple[type, FieldInfo]]],
instructions: Optional[str] = None,
signature_name: str = "StringSignature",
custom_types: Optional[Dict[str, Type]] = None,
) -> Type[Signature]:
"""Create a new Signature subclass with the specified fields and instructions.

Expand All @@ -292,6 +382,8 @@ def make_signature(
If not provided, defaults to a basic description of inputs and outputs.
signature_name: Optional string to name the generated Signature subclass.
Defaults to "StringSignature".
custom_types: Optional dictionary mapping type names to their actual type objects.
Useful for resolving custom types that aren't built-ins or in the typing module.

Returns:
A new signature class with the specified fields and instructions.
Expand All @@ -307,9 +399,21 @@ def make_signature(
"question": (str, InputField()),
"answer": (str, OutputField())
})

# Using custom types
class MyType:
pass

sig3 = make_signature("input: MyType -> output", custom_types={"MyType": MyType})
```
"""
fields = _parse_signature(signature) if isinstance(signature, str) else signature
# Prepare the names dictionary for type resolution
names = None
if custom_types:
names = dict(typing.__dict__)
names.update(custom_types)

fields = _parse_signature(signature, names) if isinstance(signature, str) else signature

# Validate the fields, this is important because we sometimes forget the
# slightly unintuitive syntax with tuples of (type, Field)
Expand Down Expand Up @@ -347,22 +451,22 @@ def make_signature(
)


def _parse_signature(signature: str, names=None) -> Dict[str, Tuple[Type, Field]]:
    """Parse a string signature such as "question, context: str -> answer: int" into fields.

    Args:
        signature: The signature string. It must contain exactly one "->" separating the
            input field list from the output field list.
        names: Optional mapping of names to objects, forwarded to `_parse_field_string` for
            resolving the annotations in the string.

    Returns:
        A dict mapping each field name to a `(type, field)` tuple, where the field is an
        `InputField()` for names left of the arrow and an `OutputField()` for names right of it.

    Raises:
        ValueError: If the string does not contain exactly one "->".
    """
    if signature.count("->") != 1:
        raise ValueError(f"Invalid signature format: '{signature}', must contain exactly one '->'.")

    inputs_str, outputs_str = signature.split("->")

    fields = {}
    for field_name, field_type in _parse_field_string(inputs_str, names):
        fields[field_name] = (field_type, InputField())
    for field_name, field_type in _parse_field_string(outputs_str, names):
        fields[field_name] = (field_type, OutputField())

    return fields


def _parse_field_string(field_string: str) -> Dict[str, str]:
def _parse_field_string(field_string: str, names=None) -> Dict[str, str]:
"""Extract the field name and type from field string in the string-based Signature.

It takes a string like "x: int, y: str" and returns a dictionary mapping field names to their types.
Expand All @@ -371,9 +475,9 @@ def _parse_field_string(field_string: str) -> Dict[str, str]:
"""

args = ast.parse(f"def f({field_string}): pass").body[0].args.args
names = [arg.arg for arg in args]
types = [str if arg.annotation is None else _parse_type_node(arg.annotation) for arg in args]
return zip(names, types)
field_names = [arg.arg for arg in args]
types = [str if arg.annotation is None else _parse_type_node(arg.annotation, names) for arg in args]
return zip(field_names, types)


def _parse_type_node(node, names=None) -> Any:
Expand Down Expand Up @@ -446,10 +550,16 @@ def resolve_name(type_name: str):
if isinstance(node, ast.Attribute):
base = _parse_type_node(node.value, names)
attr_name = node.attr

if hasattr(base, attr_name):
return getattr(base, attr_name)
else:
raise ValueError(f"Unknown attribute: {attr_name} on {base}")

if isinstance(node.value, ast.Name):
full_name = f"{node.value.id}.{attr_name}"
if full_name in names:
return names[full_name]

raise ValueError(f"Unknown attribute: {attr_name} on {base}")

if isinstance(node, ast.Subscript):
base_type = _parse_type_node(node.value, names)
Expand Down
121 changes: 121 additions & 0 deletions tests/signatures/test_custom_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import pydantic
import pytest
from typing import List, Dict, Any

import dspy
from dspy import Signature
from dspy.utils.dummies import DummyLM


def test_basic_custom_type_resolution():
    """A locally defined model resolves through an explicit mapping and through caller-scope lookup."""
    class CustomType(pydantic.BaseModel):
        value: str

    # Explicit route: the name used in the string is supplied via `custom_types`.
    type_map = {"CustomType": CustomType}
    explicit_sig = Signature("input: CustomType -> output: str", custom_types=type_map)
    explicit_annotation = explicit_sig.input_fields["input"].annotation
    assert explicit_annotation == CustomType

    # Automatic route: no mapping is given, so the name has to come from this function's scope.
    auto_sig = Signature("input: CustomType -> output: str")
    auto_annotation = auto_sig.input_fields["input"].annotation
    assert auto_annotation == CustomType


def test_type_alias_for_nested_types():
    """Nested models resolve both through a local alias and through dotted names."""
    class Container:
        class NestedType(pydantic.BaseModel):
            value: str

    # A local alias makes the nested class reachable under a plain name.
    NestedType = Container.NestedType
    aliased = Signature("input: str -> output: NestedType")
    aliased_annotation = aliased.output_fields["output"].annotation
    assert aliased_annotation == Container.NestedType

    class Container2:
        class Query(pydantic.BaseModel):
            text: str

        class Score(pydantic.BaseModel):
            score: float

    # Dotted names are written directly in the signature string.
    dotted = dspy.Signature("query: Container2.Query -> score: Container2.Score")
    score_annotation = dotted.output_fields["score"].annotation
    assert score_annotation == Container2.Score


# Module-level pydantic model that tests in this file reference by name inside signature strings.
class GlobalCustomType(pydantic.BaseModel):
    """A type defined at module level for testing module-level resolution."""
    value: str  # required
    notes: str = ""  # optional; defaults to the empty string


def test_module_level_type_resolution():
    """A model defined at module scope is resolved without an explicit `custom_types` mapping."""
    # `GlobalCustomType` is not a local of this function; it only exists at module level.
    resolved = Signature("name: str -> result: GlobalCustomType")
    result_annotation = resolved.output_fields["result"].annotation
    assert result_annotation == GlobalCustomType


# Create module-level nested class for testing
# `OuterContainer` is a plain class used only as a namespace; tests in this file reach the
# pydantic model through the dotted name "OuterContainer.InnerType".
class OuterContainer:
    class InnerType(pydantic.BaseModel):
        name: str
        value: int


def test_recommended_patterns():
    """Test recommended patterns for working with custom types in signatures."""

    # PATTERN 1: Local type with auto-resolution
    class LocalType(pydantic.BaseModel):
        value: str

    sig1 = Signature("input: str -> output: LocalType")
    assert sig1.output_fields["output"].annotation == LocalType

    # PATTERN 2: Module-level type with auto-resolution
    sig2 = Signature("input: str -> output: GlobalCustomType")
    assert sig2.output_fields["output"].annotation == GlobalCustomType

    # PATTERN 3: Nested type with dot notation
    sig3 = Signature("input: str -> output: OuterContainer.InnerType")
    assert sig3.output_fields["output"].annotation == OuterContainer.InnerType

    # PATTERN 4: Nested type using alias
    InnerTypeAlias = OuterContainer.InnerType
    sig4 = Signature("input: str -> output: InnerTypeAlias")
    assert sig4.output_fields["output"].annotation == InnerTypeAlias

    # PATTERN 5: Nested type supplied through an explicit custom_types mapping.
    # No name called `InnerType` exists in this scope; the mapping alone provides it.
    sig5 = Signature(
        "input: str -> output: InnerType",
        custom_types={"InnerType": OuterContainer.InnerType},
    )
    assert sig5.output_fields["output"].annotation == OuterContainer.InnerType

def test_expected_failure():
    """A nested type referenced by its bare name cannot be resolved and must raise."""
    # `InnerType` only exists as `OuterContainer.InnerType`; no bare `InnerType` is in scope,
    # so building the signature is expected to fail. Only the raising call belongs inside the
    # `pytest.raises` block: anything after it would never execute.
    with pytest.raises(ValueError):
        Signature("input: str -> output: InnerType")

def test_module_type_resolution():
    """A dotted custom type resolves when the signature string is built inside a module's __init__."""
    class TestModule(dspy.Module):
        def __init__(self):
            super().__init__()
            # The signature string is handed to dspy.Predict here, inside __init__, rather
            # than passed to Signature directly by the test function.
            self.predict = dspy.Predict("input: str -> output: OuterContainer.InnerType")

    module = TestModule()
    sig = module.predict.signature
    assert sig.output_fields["output"].annotation == OuterContainer.InnerType

def test_basic_custom_type_resolution_explicit_then_auto():
    """Resolve the same local model first with an explicit mapping, then by auto-detection.

    The name is unique within this module: a second function called
    `test_basic_custom_type_resolution` would rebind that module attribute, and pytest would
    collect only the later definition.
    """
    class CustomType(pydantic.BaseModel):
        value: str

    sig = Signature("input: CustomType -> output: str", custom_types={"CustomType": CustomType})
    assert sig.input_fields["input"].annotation == CustomType

    sig = Signature("input: CustomType -> output: str")
    assert sig.input_fields["input"].annotation == CustomType
Loading