|
| 1 | +from types import UnionType |
1 | 2 | from typing import Any, Dict, List, Optional, Tuple, Union
|
2 | 3 |
|
3 | 4 | import pydantic
|
@@ -436,3 +437,139 @@ def test_custom_type_from_different_module():
|
436 | 437 | path_obj = Path("/test/path")
|
437 | 438 | pred = dspy.Predict(test_signature)(input=path_obj)
|
438 | 439 | assert pred.output == "/test/path"
|
| 440 | + |
| 441 | +def test_pep604_union_type_inline(): |
| 442 | + sig = Signature( |
| 443 | + "input1: str | None, input2: None | int -> output_union: int | str" |
| 444 | + ) |
| 445 | + |
| 446 | + # input1 and input2 test that both 'T | None' and 'None | T' are interpreted as Optional types, |
| 447 | + # regardless of the order of None in the union expression. |
| 448 | + |
| 449 | + assert "input1" in sig.input_fields |
| 450 | + input1_annotation = sig.input_fields["input1"].annotation |
| 451 | + assert input1_annotation == Optional[str] or ( |
| 452 | + getattr(input1_annotation, "__origin__", None) is Union |
| 453 | + and str in input1_annotation.__args__ |
| 454 | + and type(None) in input1_annotation.__args__ |
| 455 | + ) |
| 456 | + |
| 457 | + assert "input2" in sig.input_fields |
| 458 | + input2_annotation = sig.input_fields["input2"].annotation |
| 459 | + assert input2_annotation == Optional[int] or ( |
| 460 | + getattr(input2_annotation, "__origin__", None) is Union |
| 461 | + and int in input2_annotation.__args__ |
| 462 | + and type(None) in input2_annotation.__args__ |
| 463 | + ) |
| 464 | + |
| 465 | + assert "output_union" in sig.output_fields |
| 466 | + output_union_annotation = sig.output_fields["output_union"].annotation |
| 467 | + assert ( |
| 468 | + getattr(output_union_annotation, "__origin__", None) is Union |
| 469 | + and int in output_union_annotation.__args__ |
| 470 | + and str in output_union_annotation.__args__ |
| 471 | + ) |
| 472 | + |
| 473 | + |
| 474 | +def test_pep604_union_type_inline_equivalence(): |
| 475 | + sig1 = Signature("input: str | None -> output: int | str") |
| 476 | + sig2 = Signature("input: Optional[str] -> output: Union[int, str]") |
| 477 | + |
| 478 | + # PEP 604 union types in inline signatures should be equivalent to Optional and Union types |
| 479 | + assert sig1.equals(sig2) |
| 480 | + |
| 481 | + # Check that the annotations are equivalent |
| 482 | + assert sig1.input_fields["input"].annotation == sig2.input_fields["input"].annotation |
| 483 | + assert sig1.output_fields["output"].annotation == sig2.output_fields["output"].annotation |
| 484 | + |
| 485 | + |
| 486 | +def test_pep604_union_type_inline_nested(): |
| 487 | + sig = Signature( |
| 488 | + "input: str | (int | float) | None -> output: str" |
| 489 | + ) |
| 490 | + assert "input" in sig.input_fields |
| 491 | + input_annotation = sig.input_fields["input"].annotation |
| 492 | + |
| 493 | + # Check for the correct union: Union[str, int, float, NoneType] |
| 494 | + assert getattr(input_annotation, "__origin__", None) is Union |
| 495 | + assert set(input_annotation.__args__) == {str, int, float, type(None)} |
| 496 | + |
| 497 | + |
| 498 | +def test_pep604_union_type_class_nested(): |
| 499 | + class Sig1(Signature): |
| 500 | + input: str | (int | float) | None = InputField() |
| 501 | + output: str = OutputField() |
| 502 | + |
| 503 | + assert "input" in Sig1.input_fields |
| 504 | + input_annotation = Sig1.input_fields["input"].annotation |
| 505 | + |
| 506 | + # Check for the correct union: UnionType[str, int, float, NoneType] |
| 507 | + assert isinstance(input_annotation, UnionType) |
| 508 | + assert set(input_annotation.__args__) == {str, int, float, type(None)} |
| 509 | + |
| 510 | + |
| 511 | +def test_pep604_union_type_class_equivalence(): |
| 512 | + class Sig1(Signature): |
| 513 | + input: str | None = InputField() |
| 514 | + output: int | str = OutputField() |
| 515 | + |
| 516 | + class Sig2(Signature): |
| 517 | + input: Optional[str] = InputField() # noqa: UP045 |
| 518 | + output: Union[int, str] = OutputField() # noqa: UP007 |
| 519 | + |
| 520 | + # PEP 604 union types in class signatures should be equivalent to Optional and Union types |
| 521 | + assert Sig1.equals(Sig2) |
| 522 | + |
| 523 | + # Check that the annotations are equivalent |
| 524 | + assert Sig1.input_fields["input"].annotation == Sig2.input_fields["input"].annotation |
| 525 | + assert Sig1.output_fields["output"].annotation == Sig2.output_fields["output"].annotation |
| 526 | + |
| 527 | + # Check that the pep604 annotations are of type UnionType |
| 528 | + assert isinstance(Sig1.input_fields["input"].annotation, UnionType) |
| 529 | + assert isinstance(Sig1.output_fields["output"].annotation, UnionType) |
| 530 | + |
| 531 | + |
| 532 | +def test_pep604_union_type_insert(): |
| 533 | + class PEP604Signature(Signature): |
| 534 | + input: str | None = InputField() |
| 535 | + output: int | str = OutputField() |
| 536 | + |
| 537 | + # This test ensures that inserting a field into a signature with a PEP 604 UnionType works |
| 538 | + |
| 539 | + # Insert a new input field at the start |
| 540 | + NewSig = PEP604Signature.prepend("new_input", InputField(), float | int) |
| 541 | + assert "new_input" in NewSig.input_fields |
| 542 | + |
| 543 | + new_input_annotation = NewSig.input_fields["new_input"].annotation |
| 544 | + assert isinstance(new_input_annotation, UnionType) |
| 545 | + assert set(new_input_annotation.__args__) == {float, int} |
| 546 | + |
| 547 | + # The original union type field should still be present and correct |
| 548 | + input_annotation = NewSig.input_fields["input"].annotation |
| 549 | + output_annotation = NewSig.output_fields["output"].annotation |
| 550 | + |
| 551 | + assert isinstance(input_annotation, UnionType) |
| 552 | + assert str in input_annotation.__args__ and type(None) in input_annotation.__args__ |
| 553 | + |
| 554 | + assert isinstance(output_annotation, UnionType) |
| 555 | + assert set(output_annotation.__args__) == {int, str} |
| 556 | + |
| 557 | + |
| 558 | +def test_pep604_union_type_with_custom_types(): |
| 559 | + class CustomType(pydantic.BaseModel): |
| 560 | + value: str |
| 561 | + |
| 562 | + sig = Signature( |
| 563 | + "input: CustomType | None -> output: int | str", |
| 564 | + custom_types={"CustomType": CustomType} |
| 565 | + ) |
| 566 | + |
| 567 | + assert sig.input_fields["input"].annotation == Union[CustomType, None] |
| 568 | + assert sig.output_fields["output"].annotation == Union[int, str] |
| 569 | + |
| 570 | + lm = DummyLM([{"output": "processed"}]) |
| 571 | + dspy.settings.configure(lm=lm) |
| 572 | + |
| 573 | + custom_obj = CustomType(value="test") |
| 574 | + pred = dspy.Predict(sig)(input=custom_obj) |
| 575 | + assert pred.output == "processed" |
0 commit comments