Skip to content

Commit baae93e

Browse files
refactoring audio_format completely (#8349)
1 parent e4ec979 commit baae93e

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

dspy/adapters/types/audio.py

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -44,10 +44,10 @@ def format(self) -> list[dict[str, Any]]:
4444
@classmethod
4545
def validate_input(cls, values: Any) -> Any:
4646
"""
47-
Validate input for Audio, expecting 'data' and 'format' keys in dictionary.
47+
Validate input for Audio, expecting 'data' and 'audio_format' keys in dictionary.
4848
"""
4949
if isinstance(values, cls):
50-
return {"data": values.data, "format": values.format}
50+
return {"data": values.data, "audio_format": values.audio_format}
5151
return encode_audio(values)
5252

5353
@classmethod
@@ -62,7 +62,7 @@ def from_url(cls, url: str) -> "Audio":
6262
raise ValueError(f"Unsupported MIME type for audio: {mime_type}")
6363
audio_format = mime_type.split("/")[1]
6464
encoded_data = base64.b64encode(response.content).decode("utf-8")
65-
return cls(data=encoded_data, format=audio_format)
65+
return cls(data=encoded_data, audio_format=audio_format)
6666

6767
@classmethod
6868
def from_file(cls, file_path: str) -> "Audio":
@@ -81,7 +81,7 @@ def from_file(cls, file_path: str) -> "Audio":
8181

8282
audio_format = mime_type.split("/")[1]
8383
encoded_data = base64.b64encode(file_data).decode("utf-8")
84-
return cls(data=encoded_data, format=audio_format)
84+
return cls(data=encoded_data, audio_format=audio_format)
8585

8686
@classmethod
8787
def from_array(
@@ -102,44 +102,44 @@ def from_array(
102102
subtype="PCM_16",
103103
)
104104
encoded_data = base64.b64encode(byte_buffer.getvalue()).decode("utf-8")
105-
return cls(data=encoded_data, format=format)
105+
return cls(data=encoded_data, audio_format=format)
106106

107107
def __str__(self) -> str:
108108
return self.serialize_model()
109109

110110
def __repr__(self) -> str:
111111
length = len(self.data)
112-
return f"Audio(data=<AUDIO_BASE_64_ENCODED({length})>, format='{self.audio_format}')"
112+
return f"Audio(data=<AUDIO_BASE_64_ENCODED({length})>, audio_format='{self.audio_format}')"
113113

114114
def encode_audio(audio: Union[str, bytes, dict, "Audio", Any], sampling_rate: int = 16000, format: str = "wav") -> dict:
115115
"""
116-
Encode audio to a dict with 'data' and 'format'.
116+
Encode audio to a dict with 'data' and 'audio_format'.
117117
118118
Accepts: local file path, URL, data URI, dict, Audio instance, numpy array, or bytes (with known format).
119119
"""
120-
if isinstance(audio, dict) and "data" in audio and "format" in audio:
120+
if isinstance(audio, dict) and "data" in audio and "audio_format" in audio:
121121
return audio
122122
elif isinstance(audio, Audio):
123-
return {"data": audio.data, "format": audio.format}
123+
return {"data": audio.data, "audio_format": audio.audio_format}
124124
elif isinstance(audio, str) and audio.startswith("data:audio/"):
125125
try:
126126
header, b64data = audio.split(",", 1)
127127
mime = header.split(";")[0].split(":")[1]
128128
audio_format = mime.split("/")[1]
129-
return {"data": b64data, "format": audio_format}
129+
return {"data": b64data, "audio_format": audio_format}
130130
except Exception as e:
131131
raise ValueError(f"Malformed audio data URI: {e}")
132132
elif isinstance(audio, str) and os.path.isfile(audio):
133133
a = Audio.from_file(audio)
134-
return {"data": a.data, "format": a.format}
134+
return {"data": a.data, "audio_format": a.audio_format}
135135
elif isinstance(audio, str) and audio.startswith("http"):
136136
a = Audio.from_url(audio)
137-
return {"data": a.data, "format": a.format}
137+
return {"data": a.data, "audio_format": a.audio_format}
138138
elif SF_AVAILABLE and hasattr(audio, "shape"):
139139
a = Audio.from_array(audio, sampling_rate=sampling_rate, format=format)
140-
return {"data": a.data, "format": a.format}
140+
return {"data": a.data, "audio_format": a.audio_format}
141141
elif isinstance(audio, bytes):
142142
encoded = base64.b64encode(audio).decode("utf-8")
143-
return {"data": encoded, "format": format}
143+
return {"data": encoded, "audio_format": format}
144144
else:
145145
raise ValueError(f"Unsupported type for encode_audio: {type(audio)}")

0 commit comments

Comments
 (0)