Skip to content

Commit 538fbb3

Browse files
Rename dspy.BaseType to dspy.Type (#8510)
* rename BaseType to Type * fix tests
1 parent b4d1a7e commit 538fbb3

File tree

9 files changed

+25
-27
lines changed

9 files changed

+25
-27
lines changed

dspy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from dspy.evaluate import Evaluate # isort: skip
88
from dspy.clients import * # isort: skip
9-
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, History, BaseType, Tool, ToolCalls # isort: skip
9+
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, History, Type, Tool, ToolCalls # isort: skip
1010
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
1111
from dspy.utils.asyncify import asyncify
1212
from dspy.utils.saving import load

dspy/adapters/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
from dspy.adapters.chat_adapter import ChatAdapter
33
from dspy.adapters.json_adapter import JSONAdapter
44
from dspy.adapters.two_step_adapter import TwoStepAdapter
5-
from dspy.adapters.types import Audio, BaseType, History, Image, Tool, ToolCalls
5+
from dspy.adapters.types import Audio, History, Image, Tool, ToolCalls, Type
66
from dspy.adapters.xml_adapter import XMLAdapter
77

88
__all__ = [
99
"Adapter",
1010
"ChatAdapter",
11-
"BaseType",
11+
"Type",
1212
"History",
1313
"Image",
1414
"Audio",

dspy/adapters/types/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dspy.adapters.types.audio import Audio
2-
from dspy.adapters.types.base_type import BaseType
2+
from dspy.adapters.types.base_type import Type
33
from dspy.adapters.types.history import History
44
from dspy.adapters.types.image import Image
55
from dspy.adapters.types.tool import Tool, ToolCalls
66

7-
__all__ = ["History", "Image", "Audio", "BaseType", "Tool", "ToolCalls"]
7+
__all__ = ["History", "Image", "Audio", "Type", "Tool", "ToolCalls"]

dspy/adapters/types/audio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pydantic
88
import requests
99

10-
from dspy.adapters.types.base_type import BaseType
10+
from dspy.adapters.types.base_type import Type
1111

1212
try:
1313
import soundfile as sf
@@ -17,7 +17,7 @@
1717
SF_AVAILABLE = False
1818

1919

20-
class Audio(BaseType):
20+
class Audio(Type):
2121
data: str
2222
audio_format: str
2323

dspy/adapters/types/base_type.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
CUSTOM_TYPE_END_IDENTIFIER = "<<CUSTOM-TYPE-END-IDENTIFIER>>"
1010

1111

12-
class BaseType(pydantic.BaseModel):
12+
class Type(pydantic.BaseModel):
1313
"""Base class to support creating custom types for DSPy signatures.
1414
1515
This is the parent class of DSPy custom types, e.g, dspy.Image. Subclasses must implement the `format` method to
@@ -18,7 +18,7 @@ class BaseType(pydantic.BaseModel):
1818
Example:
1919
2020
```python
21-
class Image(BaseType):
21+
class Image(Type):
2222
url: str
2323
2424
def format(self) -> list[dict[str, Any]]:
@@ -85,7 +85,7 @@ def split_message_content_for_custom_types(messages: list[dict[str, Any]]) -> li
8585
8686
This is implemented by finding the `<<CUSTOM-TYPE-START-IDENTIFIER>>` and `<<CUSTOM-TYPE-END-IDENTIFIER>>`
8787
in the user message content and splitting the content around them. The `<<CUSTOM-TYPE-START-IDENTIFIER>>`
88-
and `<<CUSTOM-TYPE-END-IDENTIFIER>>` are the reserved identifiers for the custom types as in `dspy.BaseType`.
88+
and `<<CUSTOM-TYPE-END-IDENTIFIER>>` are the reserved identifiers for the custom types as in `dspy.Type`.
8989
9090
Args:
9191
messages: a list of messages sent to the LM. The format is the same as [OpenAI API's messages

dspy/adapters/types/image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pydantic
99
import requests
1010

11-
from dspy.adapters.types.base_type import BaseType
11+
from dspy.adapters.types.base_type import Type
1212

1313
try:
1414
from PIL import Image as PILImage
@@ -18,7 +18,7 @@
1818
PIL_AVAILABLE = False
1919

2020

21-
class Image(BaseType):
21+
class Image(Type):
2222
url: str
2323

2424
model_config = {

dspy/adapters/types/tool.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from jsonschema import ValidationError, validate
66
from pydantic import BaseModel, TypeAdapter, create_model
77

8-
from dspy.adapters.types.base_type import BaseType
8+
from dspy.adapters.types.base_type import Type
99
from dspy.dsp.utils.settings import settings
1010
from dspy.utils.callback import with_callbacks
1111

@@ -16,7 +16,7 @@
1616
_TYPE_MAPPING = {"string": str, "integer": int, "number": float, "boolean": bool, "array": list, "object": dict}
1717

1818

19-
class Tool(BaseType):
19+
class Tool(Type):
2020
"""Tool class.
2121
2222
This class is used to simplify the creation of tools for tool calling (function calling) in LLMs. Only supports
@@ -254,7 +254,7 @@ def __str__(self):
254254
return f"{self.name}{desc} {arg_desc}"
255255

256256

257-
class ToolCalls(BaseType):
257+
class ToolCalls(Type):
258258
class ToolCall(BaseModel):
259259
name: str
260260
args: dict[str, Any]
@@ -303,7 +303,8 @@ def format(self) -> list[dict[str, Any]]:
303303
"name": tool_call.name,
304304
"arguments": tool_call.args,
305305
},
306-
} for tool_call in self.tool_calls
306+
}
307+
for tool_call in self.tool_calls
307308
],
308309
}
309310
]

dspy/adapters/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pydantic import TypeAdapter
1111
from pydantic.fields import FieldInfo
1212

13-
from dspy.adapters.types.base_type import BaseType
13+
from dspy.adapters.types.base_type import Type
1414
from dspy.signatures.utils import get_dspy_field_type
1515

1616

@@ -204,7 +204,7 @@ def get_field_description_string(fields: dict) -> str:
204204
field_message += f" ({get_annotation_name(v.annotation)})"
205205
desc = v.json_schema_extra["desc"] if v.json_schema_extra["desc"] != f"${{{k}}}" else ""
206206

207-
custom_types = BaseType.extract_custom_type_from_annotation(v.annotation)
207+
custom_types = Type.extract_custom_type_from_annotation(v.annotation)
208208
for custom_type in custom_types:
209209
if len(custom_type.description()) > 0:
210210
desc += f"\n Type description of {get_annotation_name(custom_type)}: {custom_type.description()}"

tests/adapters/test_base_type.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
21
import pydantic
32

43
import dspy
54

65

76
def test_basic_extract_custom_type_from_annotation():
8-
class Event(dspy.BaseType):
7+
class Event(dspy.Type):
98
event_name: str
109
start_date_time: str
1110
end_date_time: str | None
@@ -17,27 +16,25 @@ class ExtractEvent(dspy.Signature):
1716
email: str = dspy.InputField()
1817
event: Event = dspy.OutputField()
1918

20-
assert dspy.BaseType.extract_custom_type_from_annotation(ExtractEvent.output_fields["event"].annotation) == [Event]
19+
assert dspy.Type.extract_custom_type_from_annotation(ExtractEvent.output_fields["event"].annotation) == [Event]
2120

2221
class ExtractEvents(dspy.Signature):
2322
"""Extract all events from the email content."""
2423

2524
email: str = dspy.InputField()
2625
events: list[Event] = dspy.OutputField()
2726

28-
assert dspy.BaseType.extract_custom_type_from_annotation(ExtractEvents.output_fields["events"].annotation) == [
29-
Event
30-
]
27+
assert dspy.Type.extract_custom_type_from_annotation(ExtractEvents.output_fields["events"].annotation) == [Event]
3128

3229

3330
def test_extract_custom_type_from_annotation_with_nested_type():
34-
class Event(dspy.BaseType):
31+
class Event(dspy.Type):
3532
event_name: str
3633
start_date_time: str
3734
end_date_time: str | None
3835
location: str | None
3936

40-
class EventIdentifier(dspy.BaseType):
37+
class EventIdentifier(dspy.Type):
4138
model_config = pydantic.ConfigDict(frozen=True) # Make it hashable
4239
event_id: str
4340
event_name: str
@@ -48,7 +45,7 @@ class ExtractEvents(dspy.Signature):
4845
email: str = dspy.InputField()
4946
events: list[dict[EventIdentifier, Event]] = dspy.OutputField()
5047

51-
assert dspy.BaseType.extract_custom_type_from_annotation(ExtractEvents.output_fields["events"].annotation) == [
48+
assert dspy.Type.extract_custom_type_from_annotation(ExtractEvents.output_fields["events"].annotation) == [
5249
EventIdentifier,
5350
Event,
5451
]

0 commit comments

Comments
 (0)