diff --git a/docs/docs/learn/programming/signatures.md b/docs/docs/learn/programming/signatures.md index 4e0a9c2f13..aa47236f0f 100644 --- a/docs/docs/learn/programming/signatures.md +++ b/docs/docs/learn/programming/signatures.md @@ -32,6 +32,13 @@ Your signatures can also have multiple input/output fields with types: **Tip:** For fields, any valid variable names work! Field names should be semantically meaningful, but start simple and don't prematurely optimize keywords! Leave that kind of hacking to the DSPy compiler. For example, for summarization, it's probably fine to say `"document -> summary"`, `"text -> gist"`, or `"long_context -> tldr"`. +You can also add instructions to your inline signature, which can use variables at runtime. Use the `instructions` keyword argument to add instructions to your signature. + +```python +toxicity = dspy.Predict( + 'comment -> toxic: bool', + instructions="Mark as 'toxic' if the comment includes insults, harassment, or sarcastic derogatory remarks.") +``` ### Example A: Sentiment Classification @@ -157,6 +164,35 @@ Prediction( ) ``` +## Type Resolution in Signatures + +DSPy signatures support various annotation types: + +1. **Basic types** like `str`, `int`, `bool` +2. **Typing module types** like `List[str]`, `Dict[str, int]`, `Optional[float]`, `Union[str, int]` +3. **Custom types** defined in your code +4. **Dot notation** for nested types with proper configuration +5. 
**Special data types** like `dspy.Image`, `dspy.History` + +### Working with Custom Types + +```python +# Simple custom type +class QueryResult(pydantic.BaseModel): + text: str + score: float + +signature = dspy.Signature("query: str -> result: QueryResult") + +class Container: + class Query(pydantic.BaseModel): + text: str + class Score(pydantic.BaseModel): + score: float + +signature = dspy.Signature("query: Container.Query -> score: Container.Score") +``` + ## Using signatures to build modules & compiling them While signatures are convenient for prototyping with structured inputs/outputs, that's not the only reason to use them! diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 89a07f897b..af4225588a 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -21,6 +21,7 @@ class MySignature(dspy.Signature): import re import types import typing +import sys from copy import deepcopy from typing import Any, Dict, Optional, Tuple, Type, Union @@ -41,9 +42,97 @@ class SignatureMeta(type(BaseModel)): def __call__(cls, *args, **kwargs): if cls is Signature: # We don't create an actual Signature instance, instead, we create a new Signature class. - return make_signature(*args, **kwargs) + custom_types = kwargs.pop('custom_types', None) + + if custom_types is None and args and isinstance(args[0], str): + custom_types = cls._detect_custom_types_from_caller(args[0]) + + return make_signature(*args, custom_types=custom_types, **kwargs) return super().__call__(*args, **kwargs) + @staticmethod + def _detect_custom_types_from_caller(signature_str): + """Detect custom types from the caller's frame based on the signature string. + + Note: This method relies on Python's frame introspection which has some limitations: + 1. May not work in all Python implementations (e.g., compiled with optimizations) + 2. Looks up a limited number of frames in the call stack + 3. 
Cannot find types that are imported but not in the caller's namespace + + For more reliable custom type resolution, explicitly provide types using the + `custom_types` parameter when creating a Signature. + """ + + # Extract potential type names from the signature string, including dotted names + # Match both simple types like 'MyType' and dotted names like 'Module.Type' + type_pattern = r':\s*([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)*)' + type_names = re.findall(type_pattern, signature_str) + if not type_names: + return None + + # Get type references from caller frames by walking the stack + found_types = {} + + needed_types = set() + dotted_types = {} + + for type_name in type_names: + parts = type_name.split('.') + base_name = parts[0] + + if base_name not in typing.__dict__ and base_name not in __builtins__: + if len(parts) > 1: + dotted_types[type_name] = base_name + needed_types.add(base_name) + else: + needed_types.add(type_name) + + if not needed_types: + return None + + frame = None + try: + frame = sys._getframe(1) # Start one level up (skip this function) + + max_frames = 100 + frame_count = 0 + + while frame and needed_types and frame_count < max_frames: + frame_count += 1 + + for type_name in list(needed_types): + if type_name in frame.f_locals: + found_types[type_name] = frame.f_locals[type_name] + needed_types.remove(type_name) + elif frame.f_globals and type_name in frame.f_globals: + found_types[type_name] = frame.f_globals[type_name] + needed_types.remove(type_name) + + # If we found all needed types, stop looking + if not needed_types: + break + + frame = frame.f_back + + if needed_types and frame_count >= max_frames: + import logging + logging.getLogger("dspy").warning( + f"Reached maximum frame search depth ({max_frames}) while looking for types: {needed_types}. " + "Consider providing custom_types explicitly to Signature." 
+ ) + except (AttributeError, ValueError): + # Handle environments where frame introspection is not available + import logging + logging.getLogger("dspy").debug( + "Frame introspection failed while trying to resolve custom types. " + "Consider providing custom_types explicitly to Signature." + ) + finally: + if frame: + del frame + + return found_types or None + def __new__(mcs, signature_name, bases, namespace, **kwargs): # noqa: N804 # At this point, the orders have been swapped already. field_order = [name for name, value in namespace.items() if isinstance(value, FieldInfo)] @@ -282,6 +371,7 @@ def make_signature( signature: Union[str, Dict[str, Tuple[type, FieldInfo]]], instructions: Optional[str] = None, signature_name: str = "StringSignature", + custom_types: Optional[Dict[str, Type]] = None, ) -> Type[Signature]: """Create a new Signature subclass with the specified fields and instructions. @@ -292,6 +382,8 @@ def make_signature( If not provided, defaults to a basic description of inputs and outputs. signature_name: Optional string to name the generated Signature subclass. Defaults to "StringSignature". + custom_types: Optional dictionary mapping type names to their actual type objects. + Useful for resolving custom types that aren't built-ins or in the typing module. Returns: A new signature class with the specified fields and instructions. 
@@ -307,9 +399,21 @@ def make_signature( "question": (str, InputField()), "answer": (str, OutputField()) }) + + # Using custom types + class MyType: + pass + + sig3 = make_signature("input: MyType -> output", custom_types={"MyType": MyType}) ``` """ - fields = _parse_signature(signature) if isinstance(signature, str) else signature + # Prepare the names dictionary for type resolution + names = None + if custom_types: + names = dict(typing.__dict__) + names.update(custom_types) + + fields = _parse_signature(signature, names) if isinstance(signature, str) else signature # Validate the fields, this is important because we sometimes forget the # slightly unintuitive syntax with tuples of (type, Field) @@ -347,22 +451,22 @@ def make_signature( ) -def _parse_signature(signature: str) -> Dict[str, Tuple[Type, Field]]: +def _parse_signature(signature: str, names=None) -> Dict[str, Tuple[Type, Field]]: if signature.count("->") != 1: raise ValueError(f"Invalid signature format: '{signature}', must contain exactly one '->'.") inputs_str, outputs_str = signature.split("->") fields = {} - for field_name, field_type in _parse_field_string(inputs_str): + for field_name, field_type in _parse_field_string(inputs_str, names): fields[field_name] = (field_type, InputField()) - for field_name, field_type in _parse_field_string(outputs_str): + for field_name, field_type in _parse_field_string(outputs_str, names): fields[field_name] = (field_type, OutputField()) return fields -def _parse_field_string(field_string: str) -> Dict[str, str]: +def _parse_field_string(field_string: str, names=None) -> Dict[str, str]: """Extract the field name and type from field string in the string-based Signature. It takes a string like "x: int, y: str" and returns a dictionary mapping field names to their types. 
@@ -371,9 +475,9 @@ def _parse_field_string(field_string: str) -> Dict[str, str]: """ args = ast.parse(f"def f({field_string}): pass").body[0].args.args - names = [arg.arg for arg in args] - types = [str if arg.annotation is None else _parse_type_node(arg.annotation) for arg in args] - return zip(names, types) + field_names = [arg.arg for arg in args] + types = [str if arg.annotation is None else _parse_type_node(arg.annotation, names) for arg in args] + return zip(field_names, types) def _parse_type_node(node, names=None) -> Any: @@ -446,10 +550,16 @@ def resolve_name(type_name: str): if isinstance(node, ast.Attribute): base = _parse_type_node(node.value, names) attr_name = node.attr + if hasattr(base, attr_name): return getattr(base, attr_name) - else: - raise ValueError(f"Unknown attribute: {attr_name} on {base}") + + if isinstance(node.value, ast.Name): + full_name = f"{node.value.id}.{attr_name}" + if full_name in names: + return names[full_name] + + raise ValueError(f"Unknown attribute: {attr_name} on {base}") if isinstance(node, ast.Subscript): base_type = _parse_type_node(node.value, names) diff --git a/tests/signatures/test_custom_types.py b/tests/signatures/test_custom_types.py new file mode 100644 index 0000000000..3a43150393 --- /dev/null +++ b/tests/signatures/test_custom_types.py @@ -0,0 +1,121 @@ +import pydantic +import pytest +from typing import List, Dict, Any + +import dspy +from dspy import Signature +from dspy.utils.dummies import DummyLM + + +def test_basic_custom_type_resolution(): + """Test basic custom type resolution with both explicit and automatic mapping.""" + class CustomType(pydantic.BaseModel): + value: str + + # Custom types can be explicitly mapped + explicit_sig = Signature( + "input: CustomType -> output: str", + custom_types={"CustomType": CustomType} + ) + assert explicit_sig.input_fields["input"].annotation == CustomType + + # Custom types can also be auto-resolved from caller's scope + auto_sig = Signature("input: CustomType 
-> output: str") + assert auto_sig.input_fields["input"].annotation == CustomType + + +def test_type_alias_for_nested_types(): + """Test using type aliases for nested types.""" + class Container: + class NestedType(pydantic.BaseModel): + value: str + + NestedType = Container.NestedType + alias_sig = Signature("input: str -> output: NestedType") + assert alias_sig.output_fields["output"].annotation == Container.NestedType + + class Container2: + class Query(pydantic.BaseModel): + text: str + class Score(pydantic.BaseModel): + score: float + + signature = dspy.Signature("query: Container2.Query -> score: Container2.Score") + assert signature.output_fields["score"].annotation == Container2.Score + + +class GlobalCustomType(pydantic.BaseModel): + """A type defined at module level for testing module-level resolution.""" + value: str + notes: str = "" + + +def test_module_level_type_resolution(): + """Test resolution of types defined at module level.""" + # Module-level types can be auto-resolved + sig = Signature("name: str -> result: GlobalCustomType") + assert sig.output_fields["result"].annotation == GlobalCustomType + + +# Create module-level nested class for testing +class OuterContainer: + class InnerType(pydantic.BaseModel): + name: str + value: int + + +def test_recommended_patterns(): + """Test recommended patterns for working with custom types in signatures.""" + + # PATTERN 1: Local type with auto-resolution + class LocalType(pydantic.BaseModel): + value: str + + sig1 = Signature("input: str -> output: LocalType") + assert sig1.output_fields["output"].annotation == LocalType + + # PATTERN 2: Module-level type with auto-resolution + sig2 = Signature("input: str -> output: GlobalCustomType") + assert sig2.output_fields["output"].annotation == GlobalCustomType + + # PATTERN 3: Nested type with dot notation + sig3 = Signature("input: str -> output: OuterContainer.InnerType") + assert sig3.output_fields["output"].annotation == OuterContainer.InnerType + + # 
PATTERN 4: Nested type using alias + InnerTypeAlias = OuterContainer.InnerType + sig4 = Signature("input: str -> output: InnerTypeAlias") + assert sig4.output_fields["output"].annotation == InnerTypeAlias + + # PATTERN 5: Nested type with dot notation + sig5 = Signature("input: str -> output: OuterContainer.InnerType") + assert sig5.output_fields["output"].annotation == OuterContainer.InnerType + +def test_expected_failure(): + # A bare InnerType does not exist outside OuterContainer (only OuterContainer.InnerType does), so this type should not be resolved + with pytest.raises(ValueError): + sig4 = Signature("input: str -> output: InnerType") + assert sig4.output_fields["output"].annotation == InnerType + +def test_module_type_resolution(): + class TestModule(dspy.Module): + def __init__(self): + super().__init__() + self.predict = dspy.Predict("input: str -> output: OuterContainer.InnerType") + + def predict(self, input: str) -> str: + return input + + module = TestModule() + sig = module.predict.signature + assert sig.output_fields["output"].annotation == OuterContainer.InnerType + +def test_basic_custom_type_resolution(): + class CustomType(pydantic.BaseModel): + value: str + + sig = Signature("input: CustomType -> output: str", custom_types={"CustomType": CustomType}) + assert sig.input_fields["input"].annotation == CustomType + + sig = Signature("input: CustomType -> output: str") + assert sig.input_fields["input"].annotation == CustomType \ No newline at end of file diff --git a/tests/signatures/test_signature.py b/tests/signatures/test_signature.py index 564a749061..b1c79ae017 100644 --- a/tests/signatures/test_signature.py +++ b/tests/signatures/test_signature.py @@ -96,7 +96,8 @@ def test_signature_from_dict(): - signature = Signature({"input1": InputField(), "input2": InputField(), "output": OutputField()}) + signature = Signature( + {"input1": InputField(), "input2": InputField(), "output": OutputField()}) for k in ["input1", "input2", "output"]: assert k in 
signature.fields assert signature.fields[k].annotation == str @@ -155,7 +156,8 @@ class ExampleSignature(dspy.Signature): def test_infer_prefix(): - assert infer_prefix("someAttributeName42IsCool") == "Some Attribute Name 42 Is Cool" + assert infer_prefix( + "someAttributeName42IsCool") == "Some Attribute Name 42 Is Cool" assert infer_prefix("version2Update") == "Version 2 Update" assert infer_prefix("modelT45Enhanced") == "Model T 45 Enhanced" assert infer_prefix("someAttributeName") == "Some Attribute Name" @@ -241,7 +243,6 @@ class CustomSignature2(dspy.Signature): def test_typed_signatures_basic_types(): - # Simple built-in types sig = Signature("input1: int, input2: str -> output: float") assert "input1" in sig.input_fields assert sig.input_fields["input1"].annotation == int @@ -252,8 +253,8 @@ def test_typed_signatures_basic_types(): def test_typed_signatures_generics(): - # More complex generic types - sig = Signature("input_list: List[int], input_dict: Dict[str, float] -> output_tuple: Tuple[str, int]") + sig = Signature( + "input_list: List[int], input_dict: Dict[str, float] -> output_tuple: Tuple[str, int]") assert "input_list" in sig.input_fields assert sig.input_fields["input_list"].annotation == List[int] assert "input_dict" in sig.input_fields @@ -263,7 +264,8 @@ def test_typed_signatures_generics(): def test_typed_signatures_unions_and_optionals(): - sig = Signature("input_opt: Optional[str], input_union: Union[int, None] -> output_union: Union[int, str]") + sig = Signature( + "input_opt: Optional[str], input_union: Union[int, None] -> output_union: Union[int, str]") assert "input_opt" in sig.input_fields # Optional[str] is actually Union[str, None] # Depending on the environment, it might resolve to Union[str, None] or Optional[str], either is correct. 
@@ -301,8 +303,8 @@ def test_typed_signatures_any(): def test_typed_signatures_nested(): - # Nested generics and unions - sig = Signature("input_nested: List[Union[str, int]] -> output_nested: Tuple[int, Optional[float], List[str]]") + sig = Signature( + "input_nested: List[Union[str, int]] -> output_nested: Tuple[int, Optional[float], List[str]]") input_nested_ann = sig.input_fields["input_nested"].annotation assert getattr(input_nested_ann, "__origin__", None) is list assert len(input_nested_ann.__args__) == 1 @@ -324,7 +326,6 @@ def test_typed_signatures_nested(): def test_typed_signatures_from_dict(): - # Creating a Signature directly from a dictionary with types fields = { "input_str_list": (List[str], InputField()), "input_dict_int": (Dict[str, int], InputField()), @@ -340,8 +341,6 @@ def test_typed_signatures_from_dict(): def test_typed_signatures_complex_combinations(): - # Test a very complex signature with multiple nested constructs - # input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]] sig = Signature( "input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]]" ) @@ -366,15 +365,18 @@ def test_typed_signatures_complex_combinations(): # Expecting List[str] and Dict[str, Any] # Because sets don't preserve order, just check membership. 
# Find the List[str] arg - list_arg = next(a for a in possible_args if getattr(a, "__origin__", None) is list) - dict_arg = next(a for a in possible_args if getattr(a, "__origin__", None) is dict) + list_arg = next(a for a in possible_args if getattr( + a, "__origin__", None) is list) + dict_arg = next(a for a in possible_args if getattr( + a, "__origin__", None) is dict) assert list_arg.__args__ == (str,) k, v = dict_arg.__args__ assert k == str and v == Any def test_make_signature_from_string(): - sig = Signature("input1: int, input2: Dict[str, int] -> output1: List[str], output2: Union[int, str]") + sig = Signature( + "input1: int, input2: Dict[str, int] -> output1: List[str], output2: Union[int, str]") assert "input1" in sig.input_fields assert sig.input_fields["input1"].annotation == int assert "input2" in sig.input_fields @@ -401,3 +403,36 @@ class MySignature(Signature): output2_constraints = MySignature.output_fields["outputs2"].json_schema_extra["constraints"] assert "greater than or equal to: 5" in output2_constraints assert "less than or equal to: 10" in output2_constraints + + +def test_basic_custom_type(): + class CustomType(pydantic.BaseModel): + value: str + + test_signature = dspy.Signature( + "input: CustomType -> output: str", + custom_types={"CustomType": CustomType} + ) + + assert test_signature.input_fields["input"].annotation == CustomType + + lm = DummyLM([{"output": "processed"}]) + dspy.settings.configure(lm=lm) + + custom_obj = CustomType(value="test") + pred = dspy.Predict(test_signature)(input=custom_obj) + assert pred.output == "processed" + + +def test_custom_type_from_different_module(): + from pathlib import Path + + test_signature = dspy.Signature("input: Path -> output: str") + assert test_signature.input_fields["input"].annotation == Path + + lm = DummyLM([{"output": "/test/path"}]) + dspy.settings.configure(lm=lm) + + path_obj = Path("/test/path") + pred = dspy.Predict(test_signature)(input=path_obj) + assert pred.output == 
"/test/path" \ No newline at end of file