diff --git a/python/mirascope/llm/content/audio.py b/python/mirascope/llm/content/audio.py index 0e233a3f4..bfae35c33 100644 --- a/python/mirascope/llm/content/audio.py +++ b/python/mirascope/llm/content/audio.py @@ -1,7 +1,11 @@ """The `Audio` content class.""" +import base64 from dataclasses import dataclass -from typing import Literal +from pathlib import Path +from typing import Literal, get_args + +import httpx AudioMimeType = Literal[ "audio/wav", @@ -12,6 +16,35 @@ "audio/flac", ] +MIME_TYPES = get_args(AudioMimeType) + +# Maximum audio size in bytes (25MB) +MAX_AUDIO_SIZE = 25 * 1024 * 1024 + + +def infer_audio_type(audio_data: bytes) -> AudioMimeType: + """Get the MIME type of an audio file from its raw bytes. + + Raises: + ValueError: If the audio type cannot be determined or data is too small + """ + if len(audio_data) < 12: + raise ValueError("Audio data too small to determine type (minimum 12 bytes)") + + if audio_data.startswith(b"RIFF") and audio_data[8:12] == b"WAVE": + return "audio/wav" + elif audio_data.startswith(b"ID3") or audio_data.startswith(b"\xff\xfb"): + return "audio/mp3" + elif audio_data.startswith(b"FORM") and audio_data[8:12] == b"AIFF": + return "audio/aiff" + elif audio_data.startswith(b"\xff\xf1") or audio_data.startswith(b"\xff\xf9"): + return "audio/aac" + elif audio_data.startswith(b"OggS"): + return "audio/ogg" + elif audio_data.startswith(b"fLaC"): + return "audio/flac" + raise ValueError("Unsupported audio type") + @dataclass(kw_only=True) class Base64AudioSource: @@ -26,6 +59,34 @@ class Base64AudioSource: """The mime type of the audio (e.g. audio/mp3).""" +def _process_audio_bytes(data: bytes, max_size: int) -> Base64AudioSource: + """Validate and process audio bytes into a Base64AudioSource. + + Args: + data: Raw audio bytes + max_size: Maximum allowed size in bytes + + Returns: + A Base64AudioSource with validated and encoded data + + Raises: + ValueError: If data size exceeds max_size + """ + size = len(data) + if size > max_size: + raise ValueError( + f"Audio size ({size} bytes) exceeds maximum allowed size ({max_size} bytes)" + ) + + mime_type = infer_audio_type(data) + encoded_data = base64.b64encode(data).decode("utf-8") + return Base64AudioSource( + type="base64_audio_source", + data=encoded_data, + mime_type=mime_type, + ) + + @dataclass(kw_only=True) class Audio: """Audio content for a message. @@ -38,29 +99,75 @@ class Audio: source: Base64AudioSource @classmethod - def from_url( - cls, - url: str, - ) -> "Audio": - """Create an `Audio` from a URL.""" - raise NotImplementedError + def download(cls, url: str, *, max_size: int = MAX_AUDIO_SIZE) -> "Audio": + """Download and encode an audio file from a URL. + + Args: + url: The URL of the audio file to download + max_size: Maximum allowed audio size in bytes (default: 25MB) + + Returns: + An `Audio` with a `Base64AudioSource` + + Raises: + ValueError: If the downloaded audio exceeds max_size + """ + response = httpx.get(url, follow_redirects=True) + response.raise_for_status() + return cls(source=_process_audio_bytes(response.content, max_size)) @classmethod - def from_file( - cls, - file_path: str, - *, - mime_type: AudioMimeType | None, + async def download_async( + cls, url: str, *, max_size: int = MAX_AUDIO_SIZE ) -> "Audio": - """Create an `Audio` from a file path.""" - raise NotImplementedError + """Asynchronously download and encode an audio file from a URL. + + Args: + url: The URL of the audio file to download + max_size: Maximum allowed audio size in bytes (default: 25MB) + + Returns: + An `Audio` with a `Base64AudioSource` + + Raises: + ValueError: If the downloaded audio exceeds max_size + """ + async with httpx.AsyncClient() as client: + response = await client.get(url, follow_redirects=True) + response.raise_for_status() + return cls(source=_process_audio_bytes(response.content, max_size)) @classmethod - def from_bytes( - cls, - data: bytes, - *, - mime_type: AudioMimeType | None, - ) -> "Audio": - """Create an `Audio` from raw bytes.""" - raise NotImplementedError + def from_file(cls, file_path: str, *, max_size: int = MAX_AUDIO_SIZE) -> "Audio": + """Create an `Audio` from a file path. + + Args: + file_path: Path to the audio file + max_size: Maximum allowed audio size in bytes (default: 25MB) + + Raises: + FileNotFoundError: If the file does not exist + ValueError: If the file size exceeds max_size + """ + path = Path(file_path) + file_size = path.stat().st_size + if file_size > max_size: + raise ValueError( + f"Audio file size ({file_size} bytes) exceeds maximum allowed size ({max_size} bytes)" + ) + with open(path, "rb") as f: + audio_bytes = f.read() + return cls(source=_process_audio_bytes(audio_bytes, max_size)) + + @classmethod + def from_bytes(cls, data: bytes, *, max_size: int = MAX_AUDIO_SIZE) -> "Audio": + """Create an `Audio` from raw bytes. + + Args: + data: Raw audio bytes + max_size: Maximum allowed audio size in bytes (default: 25MB) + + Raises: + ValueError: If the data size exceeds max_size + """ + return cls(source=_process_audio_bytes(data, max_size)) diff --git a/python/tests/llm/content/test_audio_content.py b/python/tests/llm/content/test_audio_content.py new file mode 100644 index 000000000..14c0b8954 --- /dev/null +++ b/python/tests/llm/content/test_audio_content.py @@ -0,0 +1,265 @@ +"""Tests for Audio content class.""" + +import tempfile +from pathlib import Path + +import httpx +import pytest +from pytest_httpserver import HTTPServer + +from mirascope.llm.content.audio import MAX_AUDIO_SIZE, Audio, Base64AudioSource + + +@pytest.fixture +def audio_data() -> dict[str, bytes]: + """Provide sample audio data for different formats (minimum 12 bytes each).""" + return { + "wav": b"RIFF\x00\x00\x00\x00WAVE", + "mp3_id3": b"ID3\x03\x00\x00\x00\x00\x00\x00\x00\x00", + "mp3_frame": b"\xff\xfb\x90\x00\x00\x00\x00\x00\x00\x00\x00\x00", + "aiff": b"FORM\x00\x00\x00\x00AIFF", + "aac_f1": b"\xff\xf1\x50\x80\x00\x00\x00\x00\x00\x00\x00\x00", + "aac_f9": b"\xff\xf9\x50\x80\x00\x00\x00\x00\x00\x00\x00\x00", + "ogg": b"OggS\x00\x02\x00\x00\x00\x00\x00\x00", + "flac": b"fLaC\x00\x00\x00\x22\x00\x00\x00\x00", + "unsupported": b"random unsupported data", + } + + +AUDIO_FORMAT_TESTS = [ + ("wav", "audio/wav"), + ("mp3_id3", "audio/mp3"), + ("mp3_frame", "audio/mp3"), + ("aiff", "audio/aiff"), + ("aac_f1", "audio/aac"), + ("aac_f9", "audio/aac"), + ("ogg", "audio/ogg"), + ("flac", "audio/flac"), +] + + +class TestAudioDownload: + """Tests for Audio.download class method.""" + + @pytest.mark.parametrize("format_name,expected_mime", AUDIO_FORMAT_TESTS) + def test_download_detects_mime_type_from_magic_bytes( + self, + httpserver: HTTPServer, + audio_data: dict, + format_name: str, + expected_mime: str, + ) -> None: + """Test that download() detects MIME type from magic bytes.""" + httpserver.expect_request("/audio").respond_with_data(audio_data[format_name]) + url = httpserver.url_for("/audio") + + audio = Audio.download(url) + + assert isinstance(audio.source, Base64AudioSource) + assert audio.source.type == "base64_audio_source" + assert audio.source.mime_type == expected_mime + assert len(audio.source.data) > 0 + + def test_download_follows_redirects( + self, httpserver: HTTPServer, audio_data: dict + ) -> None: + """Test that download follows redirects.""" + httpserver.expect_request("/redirect").respond_with_data( + "", status=302, headers={"Location": "/audio"} + ) + httpserver.expect_request("/audio").respond_with_data(audio_data["wav"]) + url = httpserver.url_for("/redirect") + + audio = Audio.download(url) + + assert isinstance(audio.source, Base64AudioSource) + assert audio.source.mime_type == "audio/wav" + + def test_download_raises_on_unsupported_format( + self, httpserver: HTTPServer, audio_data: dict + ) -> None: + """Test that download raises ValueError for unsupported audio format.""" + httpserver.expect_request("/audio").respond_with_data(audio_data["unsupported"]) + url = httpserver.url_for("/audio") + + with pytest.raises(ValueError, match="Unsupported audio type"): + Audio.download(url) + + def test_download_raises_on_http_error(self, httpserver: HTTPServer) -> None: + """Test that download raises on HTTP errors.""" + httpserver.expect_request("/audio").respond_with_data("", status=404) + url = httpserver.url_for("/audio") + + with pytest.raises(httpx.HTTPStatusError): + Audio.download(url) + + def test_download_enforces_size_limit(self, httpserver: HTTPServer) -> None: + """Test that download enforces the size limit.""" + large_data = b"RIFF\x00\x00\x00\x00WAVE" + b"x" * (MAX_AUDIO_SIZE + 1) + httpserver.expect_request("/audio").respond_with_data(large_data) + url = httpserver.url_for("/audio") + + with pytest.raises(ValueError, match="exceeds maximum allowed size"): + Audio.download(url) + + def test_download_respects_custom_size_limit( + self, httpserver: HTTPServer, audio_data: dict + ) -> None: + """Test that download respects custom max_size parameter.""" + httpserver.expect_request("/audio").respond_with_data(audio_data["wav"]) + url = httpserver.url_for("/audio") + + with pytest.raises(ValueError, match="exceeds maximum allowed size"): + Audio.download(url, max_size=10) + + +class TestAudioDownloadAsync: + """Tests for Audio.download_async class method.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "format_name,expected_mime", + [ + ("wav", "audio/wav"), + ("mp3_id3", "audio/mp3"), + ("aiff", "audio/aiff"), + ("ogg", "audio/ogg"), + ], + ) + async def test_download_async_detects_mime_type( + self, + httpserver: HTTPServer, + audio_data: dict, + format_name: str, + expected_mime: str, + ) -> None: + """Test that download_async() detects MIME type from magic bytes.""" + httpserver.expect_request("/audio").respond_with_data(audio_data[format_name]) + url = httpserver.url_for("/audio") + + audio = await Audio.download_async(url) + + assert isinstance(audio.source, Base64AudioSource) + assert audio.source.mime_type == expected_mime + + @pytest.mark.asyncio + async def test_download_async_enforces_size_limit( + self, httpserver: HTTPServer + ) -> None: + """Test that download_async enforces the size limit.""" + large_data = b"RIFF\x00\x00\x00\x00WAVE" + b"x" * (MAX_AUDIO_SIZE + 1) + httpserver.expect_request("/audio").respond_with_data(large_data) + url = httpserver.url_for("/audio") + + with pytest.raises(ValueError, match="exceeds maximum allowed size"): + await Audio.download_async(url) + + @pytest.mark.asyncio + async def test_download_async_raises_on_unsupported_format( + self, httpserver: HTTPServer, audio_data: dict + ) -> None: + """Test that download_async raises ValueError for unsupported format.""" + httpserver.expect_request("/audio").respond_with_data(audio_data["unsupported"]) + url = httpserver.url_for("/audio") + + with pytest.raises(ValueError, match="Unsupported audio type"): + await Audio.download_async(url) + + +class TestAudioFromFile: + """Tests for Audio.from_file class method.""" + + @pytest.mark.parametrize("data_key,expected_mime", AUDIO_FORMAT_TESTS) + def test_from_file_detects_mime_type( + self, audio_data: dict, data_key: str, expected_mime: str + ) -> None: + """Test that from_file detects MIME type from magic bytes (not extension).""" + with tempfile.NamedTemporaryFile(suffix=".unknown", delete=False) as f: + f.write(audio_data[data_key]) + temp_path = f.name + + try: + audio = Audio.from_file(temp_path) + + assert isinstance(audio.source, Base64AudioSource) + assert audio.source.mime_type == expected_mime + finally: + Path(temp_path).unlink() + + def test_from_file_raises_on_unsupported_format(self, audio_data: dict) -> None: + """Test that from_file raises ValueError for unsupported format.""" + with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f: + f.write(audio_data["unsupported"]) + temp_path = f.name + + try: + with pytest.raises(ValueError, match="Unsupported audio type"): + Audio.from_file(temp_path) + finally: + Path(temp_path).unlink() + + def test_from_file_not_found(self) -> None: + """Test FileNotFoundError is raised for non-existent file.""" + with pytest.raises(FileNotFoundError): + Audio.from_file("/nonexistent/path/to/file.mp3") + + def test_from_file_enforces_size_limit(self, audio_data: dict) -> None: + """Test that from_file enforces the size limit.""" + large_data = b"RIFF\x00\x00\x00\x00WAVE" + b"x" * (MAX_AUDIO_SIZE + 1) + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: + f.write(large_data) + temp_path = f.name + + try: + with pytest.raises(ValueError, match="exceeds maximum allowed size"): + Audio.from_file(temp_path) + finally: + Path(temp_path).unlink() + + def test_from_file_respects_custom_size_limit(self, audio_data: dict) -> None: + """Test that from_file respects custom max_size parameter.""" + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: + f.write(audio_data["wav"]) + temp_path = f.name + + try: + with pytest.raises(ValueError, match="exceeds maximum allowed size"): + Audio.from_file(temp_path, max_size=10) + finally: + Path(temp_path).unlink() + + +class TestAudioFromBytes: + """Tests for Audio.from_bytes class method.""" + + @pytest.mark.parametrize("data_key,expected_mime", AUDIO_FORMAT_TESTS) + def test_from_bytes_detects_mime_type( + self, audio_data: dict, data_key: str, expected_mime: str + ) -> None: + """Test that from_bytes detects MIME type from magic bytes.""" + audio = Audio.from_bytes(audio_data[data_key]) + + assert isinstance(audio.source, Base64AudioSource) + assert audio.source.mime_type == expected_mime + + def test_from_bytes_raises_on_unsupported_format(self, audio_data: dict) -> None: + """Test that from_bytes raises ValueError for unsupported format.""" + with pytest.raises(ValueError, match="Unsupported audio type"): + Audio.from_bytes(audio_data["unsupported"]) + + def test_from_bytes_raises_on_data_too_small(self) -> None: + """Test that from_bytes raises ValueError for data that's too small.""" + with pytest.raises(ValueError, match="Audio data too small to determine type"): + Audio.from_bytes(b"short") + + def test_from_bytes_enforces_size_limit(self) -> None: + """Test that from_bytes enforces the size limit.""" + large_data = b"RIFF\x00\x00\x00\x00WAVE" + b"x" * (MAX_AUDIO_SIZE + 1) + + with pytest.raises(ValueError, match="exceeds maximum allowed size"): + Audio.from_bytes(large_data) + + def test_from_bytes_respects_custom_size_limit(self, audio_data: dict) -> None: + """Test that from_bytes respects custom max_size parameter.""" + with pytest.raises(ValueError, match="exceeds maximum allowed size"): + Audio.from_bytes(audio_data["wav"], max_size=10)