Skip to content

Commit 82cf09c

Browse files
dspy.Audio (#8214)
* support for custom types in DSPy signatures * fix completed demos * rename custom formatting function * add dspy.Audio type * ruff fix * fix comment
1 parent e168105 commit 82cf09c

File tree

6 files changed

+156
-4
lines changed

6 files changed

+156
-4
lines changed

dspy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from dspy.evaluate import Evaluate # isort: skip
1010
from dspy.clients import * # isort: skip
11-
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, History, BaseType # isort: skip
11+
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, Audio, History, BaseType # isort: skip
1212
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
1313
from dspy.utils.asyncify import asyncify
1414
from dspy.utils.saving import load

dspy/adapters/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
from dspy.adapters.chat_adapter import ChatAdapter
33
from dspy.adapters.json_adapter import JSONAdapter
44
from dspy.adapters.two_step_adapter import TwoStepAdapter
5-
from dspy.adapters.types import History, Image, BaseType
5+
from dspy.adapters.types import History, Image, Audio, BaseType
66

77
__all__ = [
88
"Adapter",
99
"ChatAdapter",
1010
"BaseType",
1111
"History",
1212
"Image",
13+
"Audio",
1314
"JSONAdapter",
1415
"TwoStepAdapter",
1516
]

dspy/adapters/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,4 +352,4 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
352352
Returns:
353353
A dictionary of the output fields.
354354
"""
355-
raise NotImplementedError
355+
raise NotImplementedError

dspy/adapters/types/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dspy.adapters.types.history import History
22
from dspy.adapters.types.image import Image
3+
from dspy.adapters.types.audio import Audio
34
from dspy.adapters.types.base_type import BaseType
45

5-
__all__ = ["History", "Image", "BaseType"]
6+
__all__ = ["History", "Image", "Audio", "BaseType"]

dspy/adapters/types/audio.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import base64
2+
import io
3+
import mimetypes
4+
import os
5+
from typing import Any, Union
6+
7+
import pydantic
8+
import requests
9+
10+
from dspy.adapters.types.base_type import BaseType
11+
12+
try:
13+
import soundfile as sf
14+
15+
SF_AVAILABLE = True
16+
except ImportError:
17+
SF_AVAILABLE = False
18+
19+
20+
class Audio(BaseType):
    """Base64-encoded audio clip usable as a custom type in DSPy signatures.

    The model is built from two keys, ``data`` and ``format``. The format is
    stored on the instance as ``audio_format`` because the name ``format`` is
    taken by the :meth:`format` method; declaring both a field and a method
    called ``format`` makes pydantic treat the method as the field's default
    and leaves ``instance.format`` bound to a plain string, so ``format()``
    could never be called. The ``format`` alias keeps
    ``Audio(data=..., format=...)`` and dict input working unchanged.
    """

    # Base64-encoded audio bytes.
    data: str
    # Audio format name (e.g. "wav"); populated from the "format" key/kwarg.
    audio_format: str = pydantic.Field(alias="format")

    model_config = {
        "frozen": True,
        "extra": "forbid",
    }

    def format(self) -> Union[list[dict[str, Any]], str]:
        """Return the audio as a single ``input_audio`` content part."""
        return [{
            "type": "input_audio",
            "input_audio": {
                "data": self.data,
                "format": self.audio_format,
            },
        }]

    @pydantic.model_validator(mode="before")
    @classmethod
    def validate_input(cls, values: Any) -> Any:
        """
        Normalize raw input into a dict with 'data' and 'format' keys.

        An existing Audio instance is unpacked directly; anything else is
        delegated to `encode_audio`, which raises ValueError for input it
        does not recognize.
        """
        if isinstance(values, cls):
            return {"data": values.data, "format": values.audio_format}
        return encode_audio(values)

    @classmethod
    def from_url(cls, url: str) -> "Audio":
        """
        Download an audio file from URL and encode it as base64.

        The format is derived from the response's Content-Type header
        (assumed to be "audio/wav" when the header is absent).

        Raises:
            requests.HTTPError: if the response status indicates an error.
            ValueError: if the Content-Type is not an ``audio/*`` type.
        """
        response = requests.get(url)
        response.raise_for_status()
        mime_type = response.headers.get("Content-Type", "audio/wav")
        if not mime_type.startswith("audio/"):
            raise ValueError(f"Unsupported MIME type for audio: {mime_type}")
        audio_format = _audio_format_from_mime(mime_type)
        encoded_data = base64.b64encode(response.content).decode("utf-8")
        return cls(data=encoded_data, format=audio_format)

    @classmethod
    def from_file(cls, file_path: str) -> "Audio":
        """
        Read local audio file and encode it as base64.

        The format is derived from the MIME type guessed from the file name.

        Raises:
            ValueError: if the path is not an existing file, or its guessed
                MIME type is missing or not an ``audio/*`` type.
        """
        if not os.path.isfile(file_path):
            raise ValueError(f"File not found: {file_path}")

        mime_type, _ = mimetypes.guess_type(file_path)
        if not mime_type or not mime_type.startswith("audio/"):
            raise ValueError(f"Unsupported MIME type for audio: {mime_type}")

        with open(file_path, "rb") as file:
            file_data = file.read()

        audio_format = _audio_format_from_mime(mime_type)
        encoded_data = base64.b64encode(file_data).decode("utf-8")
        return cls(data=encoded_data, format=audio_format)

    @classmethod
    def from_array(
        cls, array: Any, sampling_rate: int, format: str = "wav"
    ) -> "Audio":
        """
        Process numpy-like array and encode it as base64. Uses sampling rate and audio format for encoding.

        The samples are written to an in-memory buffer with soundfile using
        the 16-bit PCM subtype.

        Raises:
            ImportError: if the optional `soundfile` dependency is missing.
        """
        if not SF_AVAILABLE:
            raise ImportError("soundfile is required to process audio arrays.")

        byte_buffer = io.BytesIO()
        sf.write(
            byte_buffer,
            array,
            sampling_rate,
            format=format.upper(),
            subtype="PCM_16",
        )
        encoded_data = base64.b64encode(byte_buffer.getvalue()).decode("utf-8")
        return cls(data=encoded_data, format=format)

    def __str__(self) -> str:
        return self.serialize_model()

    def __repr__(self) -> str:
        # Show only the payload length; the base64 string itself can be huge.
        length = len(self.data)
        return f"Audio(data=<AUDIO_BASE_64_ENCODED({length})>, format='{self.audio_format}')"


def _audio_format_from_mime(mime_type: str) -> str:
    """Return the format name carried by an ``audio/<subtype>`` MIME string.

    Parameters after ``;`` (e.g. ``audio/mpeg; charset=binary``) are dropped,
    and a leading ``x-`` on the subtype is removed so that ``audio/x-wav``
    yields ``wav``.
    """
    essence = mime_type.split(";", 1)[0].strip()
    subtype = essence.split("/", 1)[1]
    return subtype.removeprefix("x-")


def encode_audio(audio: Union[str, bytes, dict, "Audio", Any], sampling_rate: int = 16000, format: str = "wav") -> dict:
    """
    Encode audio to a dict with 'data' and 'format'.

    Accepts: local file path, URL, data URI, dict, Audio instance, numpy array, or bytes (with known format).

    Args:
        audio: The audio input in any of the accepted forms.
        sampling_rate: Sampling rate used only when `audio` is an array.
        format: Format name used only when `audio` is an array or raw bytes.

    Returns:
        A dict with the base64 payload under 'data' and the format name
        under 'format'. A dict input that already has both keys is returned
        as-is.

    Raises:
        ValueError: if a data URI is malformed or the input is unsupported.
    """
    if isinstance(audio, dict) and "data" in audio and "format" in audio:
        return audio
    elif isinstance(audio, Audio):
        return {"data": audio.data, "format": audio.audio_format}
    elif isinstance(audio, str) and audio.startswith("data:audio/"):
        try:
            # "data:audio/wav;base64,<payload>" -> header and payload.
            header, b64data = audio.split(",", 1)
            mime = header.split(";")[0].split(":")[1]
            audio_format = _audio_format_from_mime(mime)
        except (ValueError, IndexError) as e:
            raise ValueError(f"Malformed audio data URI: {e}") from e
        return {"data": b64data, "format": audio_format}
    elif isinstance(audio, str) and os.path.isfile(audio):
        a = Audio.from_file(audio)
        return {"data": a.data, "format": a.audio_format}
    elif isinstance(audio, str) and audio.startswith("http"):
        a = Audio.from_url(audio)
        return {"data": a.data, "format": a.audio_format}
    elif SF_AVAILABLE and hasattr(audio, "shape"):
        # Anything exposing `.shape` is treated as an array of samples.
        a = Audio.from_array(audio, sampling_rate=sampling_rate, format=format)
        return {"data": a.data, "format": a.audio_format}
    elif isinstance(audio, bytes):
        encoded = base64.b64encode(audio).decode("utf-8")
        return {"data": encoded, "format": format}
    else:
        raise ValueError(f"Unsupported type for encode_audio: {type(audio)}")

dspy/clients/base_lm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,11 @@ def _inspect_history(history, n: int = 1):
184184
else:
185185
image_str = f"<image_url: {c['image_url']['url']}>"
186186
print(_blue(image_str.strip()))
187+
elif c["type"] == "input_audio":
188+
audio_format = c["input_audio"]["format"]
189+
len_audio = len(c["input_audio"]["data"])
190+
audio_str = f"<audio format='{audio_format}' base64-encoded, length={len_audio}>"
191+
print(_blue(audio_str.strip()))
187192
print("\n")
188193

189194
print(_red("Response:"))

0 commit comments

Comments
 (0)