1
+ import base64
2
+ import io
3
+ import mimetypes
4
+ import os
5
+ from typing import Any , Union
6
+
7
+ import pydantic
8
+ import requests
9
+
10
+ from dspy .adapters .types .base_type import BaseType
11
+
12
+ try :
13
+ import soundfile as sf
14
+
15
+ SF_AVAILABLE = True
16
+ except ImportError :
17
+ SF_AVAILABLE = False
18
+
19
+
20
def _normalize_audio_format(mime_type: str) -> str:
    """Return the bare format name for an ``audio/*`` MIME type string.

    Any ``;``-separated parameters and a leading ``x-`` on the subtype are
    dropped, so ``"audio/wav; charset=binary"`` and ``"audio/x-wav"`` both
    yield ``"wav"``.

    Args:
        mime_type: A MIME type of the form ``audio/<subtype>[;params]``. The
            caller is responsible for checking the ``audio/`` prefix.

    Returns:
        The subtype without parameters or ``x-`` prefix.
    """
    subtype = mime_type.split(";", 1)[0].strip().split("/", 1)[1]
    return subtype.removeprefix("x-")


class Audio(BaseType):
    """Audio clip held as a base64 string plus a format name.

    Instances are built with the keywords ``data`` and ``format``, or from any
    input accepted by :func:`encode_audio` (see ``validate_input``). The
    format string is read back from the ``audio_format`` attribute; ``format``
    itself is the method that renders the clip as a message content part.
    """

    # Audio payload; the constructors in this class always store base64 text.
    data: str
    # Format name such as "wav". It is populated through the "format"
    # keyword/key (alias) but stored under another attribute name, because a
    # field and the ``format()`` method cannot share one name on the class.
    audio_format: str = pydantic.Field(alias="format")

    model_config = {
        "frozen": True,
        "extra": "forbid",
    }

    def format(self) -> Union[list[dict[str, Any]], str]:
        """Render the clip as a one-element list holding an ``input_audio`` part.

        Returns:
            ``[{"type": "input_audio", "input_audio": {"data": ..., "format": ...}}]``
        """
        return [
            {
                "type": "input_audio",
                "input_audio": {
                    "data": self.data,
                    "format": self.audio_format,
                },
            }
        ]

    @pydantic.model_validator(mode="before")
    @classmethod
    def validate_input(cls, values: Any) -> Any:
        """Convert raw input into a dict with ``data`` and ``format`` keys.

        An existing instance is unpacked directly; anything else is handed to
        :func:`encode_audio`.
        """
        if isinstance(values, cls):
            return {"data": values.data, "format": values.audio_format}
        return encode_audio(values)

    @classmethod
    def from_url(cls, url: str) -> "Audio":
        """Download an audio file from ``url`` and encode it as base64.

        The format is derived from the response's ``Content-Type`` header,
        falling back to ``audio/wav`` when the header is absent.

        Raises:
            requests.HTTPError: If the response status signals an error.
            ValueError: If the content type does not start with ``audio/``.
        """
        response = requests.get(url)
        response.raise_for_status()
        mime_type = response.headers.get("Content-Type", "audio/wav")
        if not mime_type.startswith("audio/"):
            raise ValueError(f"Unsupported MIME type for audio: {mime_type}")
        encoded_data = base64.b64encode(response.content).decode("utf-8")
        return cls(data=encoded_data, format=_normalize_audio_format(mime_type))

    @classmethod
    def from_file(cls, file_path: str) -> "Audio":
        """Read a local audio file and encode it as base64.

        The format is derived from the MIME type guessed from the file name.

        Raises:
            ValueError: If ``file_path`` is not an existing file, or its
                guessed MIME type is missing or not ``audio/*``.
        """
        if not os.path.isfile(file_path):
            raise ValueError(f"File not found: {file_path}")

        mime_type, _ = mimetypes.guess_type(file_path)
        if not mime_type or not mime_type.startswith("audio/"):
            raise ValueError(f"Unsupported MIME type for audio: {mime_type}")

        with open(file_path, "rb") as file:
            file_data = file.read()

        encoded_data = base64.b64encode(file_data).decode("utf-8")
        return cls(data=encoded_data, format=_normalize_audio_format(mime_type))

    @classmethod
    def from_array(
        cls, array: Any, sampling_rate: int, format: str = "wav"
    ) -> "Audio":
        """Encode a numpy-like sample array as base64 audio.

        The samples are written to an in-memory buffer with ``soundfile``
        using subtype ``PCM_16`` and the upper-cased ``format`` as container.

        Args:
            array: Sample data accepted by ``soundfile.write``.
            sampling_rate: Sample rate passed to ``soundfile.write``.
            format: Format name; stored as given on the returned instance.

        Raises:
            ImportError: If ``soundfile`` could not be imported.
        """
        if not SF_AVAILABLE:
            raise ImportError("soundfile is required to process audio arrays.")

        byte_buffer = io.BytesIO()
        sf.write(
            byte_buffer,
            array,
            sampling_rate,
            format=format.upper(),
            subtype="PCM_16",
        )
        encoded_data = base64.b64encode(byte_buffer.getvalue()).decode("utf-8")
        return cls(data=encoded_data, format=format)

    def __str__(self) -> str:
        # serialize_model is not defined in this class; presumably it is
        # inherited from BaseType -- confirm against the base class.
        return self.serialize_model()

    def __repr__(self) -> str:
        # Show only the payload length so large clips do not flood logs.
        length = len(self.data)
        return f"Audio(data=<AUDIO_BASE_64_ENCODED({length})>, format='{self.audio_format}')"


def encode_audio(
    audio: Union[str, bytes, dict, "Audio", Any],
    sampling_rate: int = 16000,
    format: str = "wav",
) -> dict:
    """Encode audio to a dict with ``data`` and ``format`` keys.

    Inputs are tried in this order: dict that already has both keys (returned
    unchanged), ``Audio`` instance, ``data:audio/...`` URI, path of an
    existing local file, string starting with ``http``, object with a
    ``shape`` attribute (only when ``soundfile`` is importable), raw bytes.

    Args:
        audio: The audio input in any of the forms listed above.
        sampling_rate: Sample rate used only for array input.
        format: Format name used for array and raw-bytes input.

    Returns:
        A dict ``{"data": <base64 str>, "format": <format name>}``.

    Raises:
        ValueError: If the data URI has no ``,`` separator or the input
            matches none of the supported forms.
    """
    if isinstance(audio, dict) and "data" in audio and "format" in audio:
        return audio

    if isinstance(audio, Audio):
        return {"data": audio.data, "format": audio.audio_format}

    if isinstance(audio, str) and audio.startswith("data:audio/"):
        try:
            header, b64data = audio.split(",", 1)
            # header looks like "data:audio/<subtype>[;base64]".
            audio_format = _normalize_audio_format(header.split(":", 1)[1])
        except (ValueError, IndexError) as e:
            raise ValueError(f"Malformed audio data URI: {e}") from e
        return {"data": b64data, "format": audio_format}

    if isinstance(audio, str) and os.path.isfile(audio):
        a = Audio.from_file(audio)
        return {"data": a.data, "format": a.audio_format}

    if isinstance(audio, str) and audio.startswith("http"):
        a = Audio.from_url(audio)
        return {"data": a.data, "format": a.audio_format}

    if SF_AVAILABLE and hasattr(audio, "shape"):
        a = Audio.from_array(audio, sampling_rate=sampling_rate, format=format)
        return {"data": a.data, "format": a.audio_format}

    if isinstance(audio, bytes):
        encoded = base64.b64encode(audio).decode("utf-8")
        return {"data": encoded, "format": format}

    raise ValueError(f"Unsupported type for encode_audio: {type(audio)}")
0 commit comments