Skip to content

dspy.Audio #8214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dspy.evaluate import Evaluate # isort: skip
from dspy.clients import * # isort: skip
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, History, BaseType # isort: skip
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, TwoStepAdapter, Image, Audio, History, BaseType # isort: skip
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
from dspy.utils.asyncify import asyncify
from dspy.utils.saving import load
Expand Down
3 changes: 2 additions & 1 deletion dspy/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.json_adapter import JSONAdapter
from dspy.adapters.two_step_adapter import TwoStepAdapter
from dspy.adapters.types import History, Image, BaseType
from dspy.adapters.types import History, Image, Audio, BaseType

# Names exported by a star-import of this package (``dspy/adapters``).
__all__ = [
    "Adapter",
    "ChatAdapter",
    "BaseType",
    "History",
    "Image",
    "Audio",
    "JSONAdapter",
    "TwoStepAdapter",
]
2 changes: 1 addition & 1 deletion dspy/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,4 +352,4 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
Returns:
A dictionary of the output fields.
"""
raise NotImplementedError
raise NotImplementedError
3 changes: 2 additions & 1 deletion dspy/adapters/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dspy.adapters.types.history import History
from dspy.adapters.types.image import Image
from dspy.adapters.types.audio import Audio
from dspy.adapters.types.base_type import BaseType

__all__ = ["History", "Image", "BaseType"]
__all__ = ["History", "Image", "Audio", "BaseType"]
145 changes: 145 additions & 0 deletions dspy/adapters/types/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import base64
import io
import mimetypes
import os
from typing import Any, Union

import pydantic
import requests

from dspy.adapters.types.base_type import BaseType

try:
import soundfile as sf

SF_AVAILABLE = True
except ImportError:
SF_AVAILABLE = False


class Audio(BaseType):
    """Base64-encoded audio payload together with its format label.

    Instances are immutable (``frozen``) and reject unknown fields
    (``extra="forbid"``). Besides ``data``/``audio_format`` keywords, the
    constructor accepts anything :func:`encode_audio` understands, and the
    legacy ``format`` keyword is still honoured as a spelling of
    ``audio_format``.
    """

    # Base64 text of the raw audio bytes.
    data: str
    # Format label such as "wav". It is deliberately NOT called ``format``:
    # a field and the ``format()`` method cannot share one name -- on instances
    # the field's string value shadows the method, so ``audio.format()`` used
    # to fail with "'str' object is not callable".
    audio_format: str

    model_config = {
        "frozen": True,
        "extra": "forbid",
    }

    def format(self) -> Union[list[dict[str, Any]], str]:
        """Return the audio as a one-element list of ``input_audio`` content parts."""
        return [
            {
                "type": "input_audio",
                "input_audio": {
                    "data": self.data,
                    "format": self.audio_format,
                },
            }
        ]

    @pydantic.model_validator(mode="before")
    @classmethod
    def validate_input(cls, values: Any) -> Any:
        """Normalise raw constructor input into ``data``/``audio_format`` keys.

        Accepts an existing ``Audio``, a dict that already uses
        ``audio_format``, or any input supported by :func:`encode_audio`
        (including dicts using the legacy ``format`` key).
        """
        if isinstance(values, cls):
            return {"data": values.data, "audio_format": values.audio_format}
        if isinstance(values, dict) and "data" in values and "audio_format" in values:
            # Already in field form. A stray "format" key is left in place so
            # that extra="forbid" reports the ambiguity instead of hiding it.
            return values

        encoded = encode_audio(values)
        # encode_audio reports the label under "format"; move it to the real
        # field name and keep every other key so unknown keys are still rejected.
        normalized = {key: value for key, value in encoded.items() if key != "format"}
        normalized["audio_format"] = encoded["format"]
        return normalized

    @classmethod
    def from_url(cls, url: str) -> "Audio":
        """Download an audio file from ``url`` and encode it as base64.

        Raises:
            requests.HTTPError: if the response status indicates failure.
            ValueError: if the response's Content-Type is not ``audio/*``.
        """
        response = requests.get(url)
        response.raise_for_status()
        content_type = response.headers.get("Content-Type", "audio/wav")
        # Drop header parameters ("audio/wav; charset=binary" -> "audio/wav")
        # so they do not leak into the format label.
        mime_type = content_type.split(";", 1)[0].strip()
        if not mime_type.startswith("audio/"):
            raise ValueError(f"Unsupported MIME type for audio: {mime_type}")
        audio_format = mime_type.split("/")[1]
        encoded_data = base64.b64encode(response.content).decode("utf-8")
        return cls(data=encoded_data, audio_format=audio_format)

    @classmethod
    def from_file(cls, file_path: str) -> "Audio":
        """Read a local audio file and encode it as base64.

        Raises:
            ValueError: if ``file_path`` is not an existing file, or its
                guessed MIME type is missing or not ``audio/*``.
        """
        if not os.path.isfile(file_path):
            raise ValueError(f"File not found: {file_path}")

        mime_type, _ = mimetypes.guess_type(file_path)
        if not mime_type or not mime_type.startswith("audio/"):
            raise ValueError(f"Unsupported MIME type for audio: {mime_type}")

        with open(file_path, "rb") as file:
            file_data = file.read()

        audio_format = mime_type.split("/")[1]
        encoded_data = base64.b64encode(file_data).decode("utf-8")
        return cls(data=encoded_data, audio_format=audio_format)

    @classmethod
    def from_array(
        cls, array: Any, sampling_rate: int, format: str = "wav"
    ) -> "Audio":
        """Encode a numpy-like sample array as base64 audio.

        The array is written with ``soundfile`` as 16-bit PCM using the given
        sampling rate and container ``format``.

        Raises:
            ImportError: if ``soundfile`` is not installed.
        """
        if not SF_AVAILABLE:
            raise ImportError("soundfile is required to process audio arrays.")

        byte_buffer = io.BytesIO()
        sf.write(
            byte_buffer,
            array,
            sampling_rate,
            format=format.upper(),
            subtype="PCM_16",
        )
        encoded_data = base64.b64encode(byte_buffer.getvalue()).decode("utf-8")
        return cls(data=encoded_data, audio_format=format)

    def __str__(self) -> str:
        return self.serialize_model()

    def __repr__(self) -> str:
        # Show only the payload length; the base64 text itself can be huge.
        length = len(self.data)
        return f"Audio(data=<AUDIO_BASE_64_ENCODED({length})>, audio_format='{self.audio_format}')"


def encode_audio(audio: Union[str, bytes, dict, "Audio", Any], sampling_rate: int = 16000, format: str = "wav") -> dict:
    """Encode audio to a dict with 'data' and 'format' keys.

    Accepts: local file path, URL, data URI, dict, Audio instance, numpy-like
    array (anything with a ``shape`` attribute, when soundfile is installed),
    or raw bytes.

    Args:
        audio: The audio input in one of the supported forms.
        sampling_rate: Sampling rate used only when encoding an array.
        format: Format label used for arrays and raw bytes.

    Returns:
        A dict holding the base64 string under 'data' and the label under 'format'.

    Raises:
        ValueError: if the input type is unsupported or a data URI is malformed.
    """
    if isinstance(audio, dict) and "data" in audio and "format" in audio:
        return audio
    if isinstance(audio, Audio):
        return {"data": audio.data, "format": audio.audio_format}
    if isinstance(audio, str) and audio.startswith("data:audio/"):
        try:
            header, b64data = audio.split(",", 1)
            mime = header.split(";")[0].split(":")[1]
            audio_format = mime.split("/")[1]
        except (ValueError, IndexError) as e:
            raise ValueError(f"Malformed audio data URI: {e}") from e
        return {"data": b64data, "format": audio_format}
    if isinstance(audio, str) and os.path.isfile(audio):
        loaded = Audio.from_file(audio)
        return {"data": loaded.data, "format": loaded.audio_format}
    if isinstance(audio, str) and audio.startswith("http"):
        loaded = Audio.from_url(audio)
        return {"data": loaded.data, "format": loaded.audio_format}
    if SF_AVAILABLE and hasattr(audio, "shape"):
        loaded = Audio.from_array(audio, sampling_rate=sampling_rate, format=format)
        return {"data": loaded.data, "format": loaded.audio_format}
    if isinstance(audio, bytes):
        encoded = base64.b64encode(audio).decode("utf-8")
        return {"data": encoded, "format": format}
    raise ValueError(f"Unsupported type for encode_audio: {type(audio)}")
5 changes: 5 additions & 0 deletions dspy/clients/base_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ def _inspect_history(history, n: int = 1):
else:
image_str = f"<image_url: {c['image_url']['url']}>"
print(_blue(image_str.strip()))
elif c["type"] == "input_audio":
audio_format = c["input_audio"]["format"]
len_audio = len(c["input_audio"]["data"])
audio_str = f"<audio format='{audio_format}' base64-encoded, length={len_audio}>"
print(_blue(audio_str.strip()))
print("\n")

print(_red("Response:"))
Expand Down