Skip to content

Commit a8b4ea3

Browse files
authored
Add testing for PEP604 union types (#8475)
* Add testing for PEP604 union types * Minor PR fixes for pep604 tests * Remove extra newline
1 parent 252191d commit a8b4ea3

File tree

1 file changed

+137
-0
lines changed

1 file changed

+137
-0
lines changed

tests/signatures/test_signature.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from types import UnionType
12
from typing import Any, Dict, List, Optional, Tuple, Union
23

34
import pydantic
@@ -436,3 +437,139 @@ def test_custom_type_from_different_module():
436437
path_obj = Path("/test/path")
437438
pred = dspy.Predict(test_signature)(input=path_obj)
438439
assert pred.output == "/test/path"
440+
441+
def test_pep604_union_type_inline():
442+
sig = Signature(
443+
"input1: str | None, input2: None | int -> output_union: int | str"
444+
)
445+
446+
# input1 and input2 test that both 'T | None' and 'None | T' are interpreted as Optional types,
447+
# regardless of the order of None in the union expression.
448+
449+
assert "input1" in sig.input_fields
450+
input1_annotation = sig.input_fields["input1"].annotation
451+
assert input1_annotation == Optional[str] or (
452+
getattr(input1_annotation, "__origin__", None) is Union
453+
and str in input1_annotation.__args__
454+
and type(None) in input1_annotation.__args__
455+
)
456+
457+
assert "input2" in sig.input_fields
458+
input2_annotation = sig.input_fields["input2"].annotation
459+
assert input2_annotation == Optional[int] or (
460+
getattr(input2_annotation, "__origin__", None) is Union
461+
and int in input2_annotation.__args__
462+
and type(None) in input2_annotation.__args__
463+
)
464+
465+
assert "output_union" in sig.output_fields
466+
output_union_annotation = sig.output_fields["output_union"].annotation
467+
assert (
468+
getattr(output_union_annotation, "__origin__", None) is Union
469+
and int in output_union_annotation.__args__
470+
and str in output_union_annotation.__args__
471+
)
472+
473+
474+
def test_pep604_union_type_inline_equivalence():
475+
sig1 = Signature("input: str | None -> output: int | str")
476+
sig2 = Signature("input: Optional[str] -> output: Union[int, str]")
477+
478+
# PEP 604 union types in inline signatures should be equivalent to Optional and Union types
479+
assert sig1.equals(sig2)
480+
481+
# Check that the annotations are equivalent
482+
assert sig1.input_fields["input"].annotation == sig2.input_fields["input"].annotation
483+
assert sig1.output_fields["output"].annotation == sig2.output_fields["output"].annotation
484+
485+
486+
def test_pep604_union_type_inline_nested():
487+
sig = Signature(
488+
"input: str | (int | float) | None -> output: str"
489+
)
490+
assert "input" in sig.input_fields
491+
input_annotation = sig.input_fields["input"].annotation
492+
493+
# Check for the correct union: Union[str, int, float, NoneType]
494+
assert getattr(input_annotation, "__origin__", None) is Union
495+
assert set(input_annotation.__args__) == {str, int, float, type(None)}
496+
497+
498+
def test_pep604_union_type_class_nested():
499+
class Sig1(Signature):
500+
input: str | (int | float) | None = InputField()
501+
output: str = OutputField()
502+
503+
assert "input" in Sig1.input_fields
504+
input_annotation = Sig1.input_fields["input"].annotation
505+
506+
# Check for the correct union: UnionType[str, int, float, NoneType]
507+
assert isinstance(input_annotation, UnionType)
508+
assert set(input_annotation.__args__) == {str, int, float, type(None)}
509+
510+
511+
def test_pep604_union_type_class_equivalence():
512+
class Sig1(Signature):
513+
input: str | None = InputField()
514+
output: int | str = OutputField()
515+
516+
class Sig2(Signature):
517+
input: Optional[str] = InputField() # noqa: UP045
518+
output: Union[int, str] = OutputField() # noqa: UP007
519+
520+
# PEP 604 union types in class signatures should be equivalent to Optional and Union types
521+
assert Sig1.equals(Sig2)
522+
523+
# Check that the annotations are equivalent
524+
assert Sig1.input_fields["input"].annotation == Sig2.input_fields["input"].annotation
525+
assert Sig1.output_fields["output"].annotation == Sig2.output_fields["output"].annotation
526+
527+
# Check that the pep604 annotations are of type UnionType
528+
assert isinstance(Sig1.input_fields["input"].annotation, UnionType)
529+
assert isinstance(Sig1.output_fields["output"].annotation, UnionType)
530+
531+
532+
def test_pep604_union_type_insert():
533+
class PEP604Signature(Signature):
534+
input: str | None = InputField()
535+
output: int | str = OutputField()
536+
537+
# This test ensures that inserting a field into a signature with a PEP 604 UnionType works
538+
539+
# Insert a new input field at the start
540+
NewSig = PEP604Signature.prepend("new_input", InputField(), float | int)
541+
assert "new_input" in NewSig.input_fields
542+
543+
new_input_annotation = NewSig.input_fields["new_input"].annotation
544+
assert isinstance(new_input_annotation, UnionType)
545+
assert set(new_input_annotation.__args__) == {float, int}
546+
547+
# The original union type field should still be present and correct
548+
input_annotation = NewSig.input_fields["input"].annotation
549+
output_annotation = NewSig.output_fields["output"].annotation
550+
551+
assert isinstance(input_annotation, UnionType)
552+
assert str in input_annotation.__args__ and type(None) in input_annotation.__args__
553+
554+
assert isinstance(output_annotation, UnionType)
555+
assert set(output_annotation.__args__) == {int, str}
556+
557+
558+
def test_pep604_union_type_with_custom_types():
559+
class CustomType(pydantic.BaseModel):
560+
value: str
561+
562+
sig = Signature(
563+
"input: CustomType | None -> output: int | str",
564+
custom_types={"CustomType": CustomType}
565+
)
566+
567+
assert sig.input_fields["input"].annotation == Union[CustomType, None]
568+
assert sig.output_fields["output"].annotation == Union[int, str]
569+
570+
lm = DummyLM([{"output": "processed"}])
571+
dspy.settings.configure(lm=lm)
572+
573+
custom_obj = CustomType(value="test")
574+
pred = dspy.Predict(sig)(input=custom_obj)
575+
assert pred.output == "processed"

0 commit comments

Comments
 (0)