|
| 1 | +import dspy |
| 2 | +from typing import Optional |
| 3 | +import pydantic |
| 4 | + |
| 5 | + |
| 6 | +def test_basic_extract_custom_type_from_annotation(): |
| 7 | + class Event(dspy.BaseType): |
| 8 | + event_name: str |
| 9 | + start_date_time: str |
| 10 | + end_date_time: Optional[str] |
| 11 | + location: Optional[str] |
| 12 | + |
| 13 | + class ExtractEvent(dspy.Signature): |
| 14 | + """Extract all events from the email content.""" |
| 15 | + |
| 16 | + email: str = dspy.InputField() |
| 17 | + event: Event = dspy.OutputField() |
| 18 | + |
| 19 | + assert dspy.BaseType.extract_custom_type_from_annotation(ExtractEvent.output_fields["event"].annotation) == [Event] |
| 20 | + |
| 21 | + class ExtractEvents(dspy.Signature): |
| 22 | + """Extract all events from the email content.""" |
| 23 | + |
| 24 | + email: str = dspy.InputField() |
| 25 | + events: list[Event] = dspy.OutputField() |
| 26 | + |
| 27 | + assert dspy.BaseType.extract_custom_type_from_annotation(ExtractEvents.output_fields["events"].annotation) == [ |
| 28 | + Event |
| 29 | + ] |
| 30 | + |
| 31 | + |
| 32 | +def test_extract_custom_type_from_annotation_with_nested_type(): |
| 33 | + class Event(dspy.BaseType): |
| 34 | + event_name: str |
| 35 | + start_date_time: str |
| 36 | + end_date_time: Optional[str] |
| 37 | + location: Optional[str] |
| 38 | + |
| 39 | + class EventIdentifier(dspy.BaseType): |
| 40 | + model_config = pydantic.ConfigDict(frozen=True) # Make it hashable |
| 41 | + event_id: str |
| 42 | + event_name: str |
| 43 | + |
| 44 | + class ExtractEvents(dspy.Signature): |
| 45 | + """Extract all events from the email content.""" |
| 46 | + |
| 47 | + email: str = dspy.InputField() |
| 48 | + events: list[dict[EventIdentifier, Event]] = dspy.OutputField() |
| 49 | + |
| 50 | + assert dspy.BaseType.extract_custom_type_from_annotation(ExtractEvents.output_fields["events"].annotation) == [ |
| 51 | + EventIdentifier, |
| 52 | + Event, |
| 53 | + ] |
0 commit comments