Skip to content

Commit a8d5b8e

Browse files
committed
feat: implement llm.Audio
1 parent 11e4866 commit a8d5b8e

File tree

2 files changed

+424
-22
lines changed

2 files changed

+424
-22
lines changed
Lines changed: 159 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""The `Audio` content class."""
22

3+
import base64
34
from dataclasses import dataclass
4-
from typing import Literal
5+
from pathlib import Path
6+
from typing import Literal, get_args
7+
8+
import httpx
59

610
AudioMimeType = Literal[
711
"audio/wav",
@@ -12,6 +16,35 @@
1216
"audio/flac",
1317
]
1418

19+
MIME_TYPES = get_args(AudioMimeType)
20+
21+
# Maximum audio size in bytes (25MB)
22+
MAX_AUDIO_SIZE = 25 * 1024 * 1024
23+
24+
25+
def infer_audio_type(audio_data: bytes) -> AudioMimeType:
26+
"""Get the MIME type of an audio file from its raw bytes.
27+
28+
Raises:
29+
ValueError: If the audio type cannot be determined or data is too small
30+
"""
31+
if len(audio_data) < 12:
32+
raise ValueError("Audio data too small to determine type (minimum 12 bytes)")
33+
34+
if audio_data.startswith(b"RIFF") and audio_data[8:12] == b"WAVE":
35+
return "audio/wav"
36+
elif audio_data.startswith(b"ID3") or audio_data.startswith(b"\xff\xfb"):
37+
return "audio/mp3"
38+
elif audio_data.startswith(b"FORM") and audio_data[8:12] == b"AIFF":
39+
return "audio/aiff"
40+
elif audio_data.startswith(b"\xff\xf1") or audio_data.startswith(b"\xff\xf9"):
41+
return "audio/aac"
42+
elif audio_data.startswith(b"OggS"):
43+
return "audio/ogg"
44+
elif audio_data.startswith(b"fLaC"):
45+
return "audio/flac"
46+
raise ValueError("Unsupported audio type")
47+
1548

1649
@dataclass(kw_only=True)
1750
class Base64AudioSource:
@@ -38,29 +71,133 @@ class Audio:
3871
source: Base64AudioSource
3972

4073
@classmethod
41-
def from_url(
42-
cls,
43-
url: str,
44-
) -> "Audio":
45-
"""Create an `Audio` from a URL."""
46-
raise NotImplementedError
74+
def download(cls, url: str, *, max_size: int = MAX_AUDIO_SIZE) -> "Audio":
75+
"""Download and encode an audio file from a URL.
76+
77+
Args:
78+
url: The URL of the audio file to download
79+
max_size: Maximum allowed audio size in bytes (default: 25MB)
80+
81+
Returns:
82+
An `Audio` with a `Base64AudioSource`
83+
84+
Raises:
85+
ValueError: If the downloaded audio exceeds max_size
86+
"""
87+
response = httpx.get(url, follow_redirects=True)
88+
response.raise_for_status()
89+
90+
content_length = len(response.content)
91+
if content_length > max_size:
92+
raise ValueError(
93+
f"Audio size ({content_length} bytes) exceeds maximum allowed size ({max_size} bytes)"
94+
)
95+
96+
mime_type = infer_audio_type(response.content)
97+
data = base64.b64encode(response.content).decode("utf-8")
98+
99+
return cls(
100+
source=Base64AudioSource(
101+
type="base64_audio_source",
102+
data=data,
103+
mime_type=mime_type,
104+
)
105+
)
47106

48107
@classmethod
49-
def from_file(
50-
cls,
51-
file_path: str,
52-
*,
53-
mime_type: AudioMimeType | None,
108+
async def download_async(
109+
cls, url: str, *, max_size: int = MAX_AUDIO_SIZE
54110
) -> "Audio":
55-
"""Create an `Audio` from a file path."""
56-
raise NotImplementedError
111+
"""Asynchronously download and encode an audio file from a URL.
112+
113+
Args:
114+
url: The URL of the audio file to download
115+
max_size: Maximum allowed audio size in bytes (default: 25MB)
116+
117+
Returns:
118+
An `Audio` with a `Base64AudioSource`
119+
120+
Raises:
121+
ValueError: If the downloaded audio exceeds max_size
122+
"""
123+
async with httpx.AsyncClient() as client:
124+
response = await client.get(url, follow_redirects=True)
125+
response.raise_for_status()
126+
127+
content_length = len(response.content)
128+
if content_length > max_size:
129+
raise ValueError(
130+
f"Audio size ({content_length} bytes) exceeds maximum allowed size ({max_size} bytes)"
131+
)
132+
133+
mime_type = infer_audio_type(response.content)
134+
data = base64.b64encode(response.content).decode("utf-8")
135+
136+
return cls(
137+
source=Base64AudioSource(
138+
type="base64_audio_source",
139+
data=data,
140+
mime_type=mime_type,
141+
)
142+
)
57143

58144
@classmethod
59-
def from_bytes(
60-
cls,
61-
data: bytes,
62-
*,
63-
mime_type: AudioMimeType | None,
64-
) -> "Audio":
65-
"""Create an `Audio` from raw bytes."""
66-
raise NotImplementedError
145+
def from_file(cls, file_path: str, *, max_size: int = MAX_AUDIO_SIZE) -> "Audio":
146+
"""Create an `Audio` from a file path.
147+
148+
Args:
149+
file_path: Path to the audio file
150+
max_size: Maximum allowed audio size in bytes (default: 25MB)
151+
152+
Raises:
153+
FileNotFoundError: If the file does not exist
154+
ValueError: If the file size exceeds max_size
155+
"""
156+
path = Path(file_path)
157+
file_size = path.stat().st_size
158+
if file_size > max_size:
159+
raise ValueError(
160+
f"Audio file size ({file_size} bytes) exceeds maximum allowed size ({max_size} bytes)"
161+
)
162+
163+
with open(path, "rb") as f:
164+
audio_bytes = f.read()
165+
166+
mime_type = infer_audio_type(audio_bytes)
167+
data = base64.b64encode(audio_bytes).decode("utf-8")
168+
169+
return cls(
170+
source=Base64AudioSource(
171+
type="base64_audio_source",
172+
data=data,
173+
mime_type=mime_type,
174+
)
175+
)
176+
177+
@classmethod
178+
def from_bytes(cls, data: bytes, *, max_size: int = MAX_AUDIO_SIZE) -> "Audio":
179+
"""Create an `Audio` from raw bytes.
180+
181+
Args:
182+
data: Raw audio bytes
183+
max_size: Maximum allowed audio size in bytes (default: 25MB)
184+
185+
Raises:
186+
ValueError: If the data size exceeds max_size
187+
"""
188+
data_size = len(data)
189+
if data_size > max_size:
190+
raise ValueError(
191+
f"Audio data size ({data_size} bytes) exceeds maximum allowed size ({max_size} bytes)"
192+
)
193+
194+
mime_type = infer_audio_type(data)
195+
encoded_data = base64.b64encode(data).decode("utf-8")
196+
197+
return cls(
198+
source=Base64AudioSource(
199+
type="base64_audio_source",
200+
data=encoded_data,
201+
mime_type=mime_type,
202+
)
203+
)

0 commit comments

Comments
 (0)