Skip to content

Commit c8a05c4

Browse files
committed
feat: implement llm.Audio
1 parent ac1e275 commit c8a05c4

File tree

2 files changed

+394
-22
lines changed

2 files changed

+394
-22
lines changed
Lines changed: 129 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""The `Audio` content class."""
22

3+
import base64
34
from dataclasses import dataclass
4-
from typing import Literal
5+
from pathlib import Path
6+
from typing import Literal, get_args
7+
8+
import httpx
59

610
AudioMimeType = Literal[
711
"audio/wav",
@@ -12,6 +16,35 @@
1216
"audio/flac",
1317
]
1418

19+
MIME_TYPES = get_args(AudioMimeType)
20+
21+
# Maximum audio size in bytes (25 MiB = 25 * 1024 * 1024)
22+
MAX_AUDIO_SIZE = 25 * 1024 * 1024
23+
24+
25+
def infer_audio_type(audio_data: bytes) -> "AudioMimeType":
    """Infer the MIME type of an audio file from its leading magic bytes.

    Only the first 12 bytes are inspected; the payload is never decoded.

    Args:
        audio_data: Raw audio bytes (at least 12 bytes).

    Returns:
        One of "audio/wav", "audio/mp3", "audio/aiff", "audio/aac",
        "audio/ogg" or "audio/flac".

    Raises:
        ValueError: If fewer than 12 bytes are supplied, or if no known
            signature matches.
    """
    if len(audio_data) < 12:
        raise ValueError("Audio data too small to determine type (minimum 12 bytes)")

    # RIFF and IFF ("FORM") containers store their form type at offset 8.
    form_type = audio_data[8:12]

    if audio_data.startswith(b"RIFF") and form_type == b"WAVE":
        return "audio/wav"
    # An ID3v2 tag, or an MPEG Layer III frame sync:
    #   FF FB / FF FA -> MPEG-1 Layer III (without / with CRC)
    #   FF F3 / FF F2 -> MPEG-2 Layer III (without / with CRC)
    if audio_data.startswith(
        (b"ID3", b"\xff\xfb", b"\xff\xfa", b"\xff\xf3", b"\xff\xf2")
    ):
        return "audio/mp3"
    if audio_data.startswith(b"FORM") and form_type == b"AIFF":
        return "audio/aiff"
    # ADTS AAC sync word (layer bits are 00, so it cannot collide with MP3):
    #   FF F1 / FF F0 -> MPEG-4 (without / with CRC)
    #   FF F9 / FF F8 -> MPEG-2 (without / with CRC)
    if audio_data.startswith((b"\xff\xf1", b"\xff\xf0", b"\xff\xf9", b"\xff\xf8")):
        return "audio/aac"
    if audio_data.startswith(b"OggS"):
        return "audio/ogg"
    if audio_data.startswith(b"fLaC"):
        return "audio/flac"
    raise ValueError("Unsupported audio type")
47+
1548

1649
@dataclass(kw_only=True)
1750
class Base64AudioSource:
@@ -26,6 +59,34 @@ class Base64AudioSource:
2659
"""The mime type of the audio (e.g. audio/mp3)."""
2760

2861

62+
def _process_audio_bytes(data: bytes, max_size: int) -> Base64AudioSource:
    """Turn raw audio bytes into a base64-encoded `Base64AudioSource`.

    The size limit is enforced first, then the MIME type is inferred from the
    bytes, and finally the payload is base64-encoded.

    Args:
        data: Raw audio bytes
        max_size: Largest accepted payload, in bytes

    Returns:
        A `Base64AudioSource` holding the encoded payload and inferred MIME type

    Raises:
        ValueError: If `data` is larger than `max_size`, or if
            `infer_audio_type` cannot identify the format
    """
    byte_count = len(data)
    if byte_count > max_size:
        message = f"Audio size ({byte_count} bytes) exceeds maximum allowed size ({max_size} bytes)"
        raise ValueError(message)

    detected_type = infer_audio_type(data)
    payload = base64.b64encode(data).decode("utf-8")
    return Base64AudioSource(
        type="base64_audio_source", data=payload, mime_type=detected_type
    )
88+
89+
2990
@dataclass(kw_only=True)
3091
class Audio:
3192
"""Audio content for a message.
@@ -38,29 +99,75 @@ class Audio:
3899
source: Base64AudioSource
39100

40101
@classmethod
41-
def from_url(
42-
cls,
43-
url: str,
44-
) -> "Audio":
45-
"""Create an `Audio` from a URL."""
46-
raise NotImplementedError
102+
def download(cls, url: str, *, max_size: int = MAX_AUDIO_SIZE) -> "Audio":
    """Download and encode an audio file from a URL.

    The response body is streamed and the transfer is abandoned as soon as
    more than `max_size` bytes have arrived, so an oversized response is
    never buffered in full.

    Args:
        url: The URL of the audio file to download
        max_size: Maximum allowed audio size in bytes (default: MAX_AUDIO_SIZE)

    Returns:
        An `Audio` with a `Base64AudioSource`

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status
        ValueError: If the downloaded audio exceeds max_size, or its type
            cannot be inferred from the downloaded bytes
    """
    chunks: list[bytes] = []
    received = 0
    with httpx.stream("GET", url, follow_redirects=True) as response:
        response.raise_for_status()
        for chunk in response.iter_bytes():
            received += len(chunk)
            # Stop reading once the cap is crossed instead of pulling the
            # whole body into memory first.
            if received > max_size:
                raise ValueError(
                    f"Audio size (at least {received} bytes) exceeds maximum allowed size ({max_size} bytes)"
                )
            chunks.append(chunk)
    return cls(source=_process_audio_bytes(b"".join(chunks), max_size))
47118

48119
@classmethod
49-
def from_file(
50-
cls,
51-
file_path: str,
52-
*,
53-
mime_type: AudioMimeType | None,
120+
async def download_async(
    cls, url: str, *, max_size: int = MAX_AUDIO_SIZE
) -> "Audio":
    """Asynchronously download and encode an audio file from a URL.

    The response body is streamed and the transfer is abandoned as soon as
    more than `max_size` bytes have arrived, so an oversized response is
    never buffered in full.

    Args:
        url: The URL of the audio file to download
        max_size: Maximum allowed audio size in bytes (default: MAX_AUDIO_SIZE)

    Returns:
        An `Audio` with a `Base64AudioSource`

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status
        ValueError: If the downloaded audio exceeds max_size, or its type
            cannot be inferred from the downloaded bytes
    """
    chunks: list[bytes] = []
    received = 0
    async with httpx.AsyncClient() as client:
        async with client.stream("GET", url, follow_redirects=True) as response:
            response.raise_for_status()
            async for chunk in response.aiter_bytes():
                received += len(chunk)
                # Stop reading once the cap is crossed instead of pulling
                # the whole body into memory first.
                if received > max_size:
                    raise ValueError(
                        f"Audio size (at least {received} bytes) exceeds maximum allowed size ({max_size} bytes)"
                    )
                chunks.append(chunk)
    return cls(source=_process_audio_bytes(b"".join(chunks), max_size))
57139

58140
@classmethod
59-
def from_bytes(
60-
cls,
61-
data: bytes,
62-
*,
63-
mime_type: AudioMimeType | None,
64-
) -> "Audio":
65-
"""Create an `Audio` from raw bytes."""
66-
raise NotImplementedError
141+
def from_file(cls, file_path: str, *, max_size: int = MAX_AUDIO_SIZE) -> "Audio":
    """Build an `Audio` from an audio file on disk.

    The on-disk size is checked before the file is read, so a file larger
    than `max_size` is rejected without loading its contents.

    Args:
        file_path: Location of the audio file
        max_size: Largest accepted file size in bytes (default: MAX_AUDIO_SIZE)

    Raises:
        FileNotFoundError: If nothing exists at `file_path`
        ValueError: If the file is larger than max_size
    """
    source_path = Path(file_path)
    bytes_on_disk = source_path.stat().st_size
    if bytes_on_disk > max_size:
        raise ValueError(
            f"Audio file size ({bytes_on_disk} bytes) exceeds maximum allowed size ({max_size} bytes)"
        )
    raw_audio = source_path.read_bytes()
    return cls(source=_process_audio_bytes(raw_audio, max_size))
161+
162+
@classmethod
163+
def from_bytes(cls, data: bytes, *, max_size: int = MAX_AUDIO_SIZE) -> "Audio":
    """Build an `Audio` directly from in-memory audio bytes.

    Args:
        data: Raw audio bytes
        max_size: Largest accepted payload in bytes (default: MAX_AUDIO_SIZE)

    Raises:
        ValueError: If `data` is larger than max_size
    """
    encoded_source = _process_audio_bytes(data, max_size)
    return cls(source=encoded_source)

0 commit comments

Comments
 (0)