Skip to content

feat: allow mime_type to be guessed for ByteStream #9573

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions haystack/components/converters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,22 @@
from haystack.dataclasses import ByteStream


def get_bytestream_from_source(source: Union[str, Path, ByteStream]) -> ByteStream:
def get_bytestream_from_source(source: Union[str, Path, ByteStream], guess_mime_type: bool = False) -> ByteStream:
"""
Creates a ByteStream object from a source.

:param source:
A source to convert to a ByteStream. Can be a string (path to a file), a Path object, or a ByteStream.
:param guess_mime_type:
Whether to guess the mime type from the file.
:return:
A ByteStream object.
"""

if isinstance(source, ByteStream):
return source
if isinstance(source, (str, Path)):
bs = ByteStream.from_file_path(Path(source))
bs = ByteStream.from_file_path(Path(source), guess_mime_type=guess_mime_type)
bs.meta["file_path"] = str(source)
return bs
raise ValueError(f"Unsupported source type {type(source)}")
Expand Down
24 changes: 1 addition & 23 deletions haystack/components/routers/file_type_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,6 @@
from haystack.components.converters.utils import get_bytestream_from_source, normalize_metadata
from haystack.dataclasses import ByteStream

CUSTOM_MIMETYPES = {
# we add markdown because it is not added by the mimetypes module
# see https://github.com/python/cpython/pull/17995
".md": "text/markdown",
".markdown": "text/markdown",
# we add msg because it is not added by the mimetypes module
".msg": "application/vnd.ms-outlook",
}


@component
class FileTypeRouter:
Expand Down Expand Up @@ -149,7 +140,7 @@ def run(
source = Path(source)

if isinstance(source, Path):
mime_type = self._get_mime_type(source)
mime_type = ByteStream._guess_mime_type(source)
elif isinstance(source, ByteStream):
mime_type = source.mime_type
else:
Expand All @@ -171,16 +162,3 @@ def run(
mime_types["unclassified"].append(source)

return dict(mime_types)

def _get_mime_type(self, path: Path) -> Optional[str]:
    """
    Resolve the MIME type of a file from its path.

    :param path: The file path whose MIME type should be resolved.

    :returns: The `CUSTOM_MIMETYPES` entry for the lower-cased suffix if there is one, otherwise the type
        guessed by the `mimetypes` module, or `None` if neither yields a result.
    """
    suffix = path.suffix.lower()
    guessed, _encoding = mimetypes.guess_type(path.as_posix())
    # a custom mapping for the suffix wins; otherwise fall back to the mimetypes guess
    if suffix in CUSTOM_MIMETYPES:
        return CUSTOM_MIMETYPES[suffix]
    return guessed
33 changes: 32 additions & 1 deletion haystack/dataclasses/byte_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import mimetypes
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional
Expand Down Expand Up @@ -32,15 +33,22 @@ def to_file(self, destination_path: Path) -> None:

@classmethod
def from_file_path(
cls, filepath: Path, mime_type: Optional[str] = None, meta: Optional[Dict[str, Any]] = None
cls,
filepath: Path,
mime_type: Optional[str] = None,
meta: Optional[Dict[str, Any]] = None,
guess_mime_type: bool = False,
) -> "ByteStream":
"""
Create a ByteStream from the contents read from a file.

:param filepath: A valid path to a file.
:param mime_type: The mime type of the file.
:param meta: Additional metadata to be stored with the ByteStream.
:param guess_mime_type: Whether to guess the mime type from the file.
"""
if not mime_type and guess_mime_type:
mime_type = cls._guess_mime_type(filepath)
with open(filepath, "rb") as fd:
return cls(data=fd.read(), mime_type=mime_type, meta=meta or {})

Expand Down Expand Up @@ -100,3 +108,26 @@ def from_dict(cls, data: Dict[str, Any]) -> "ByteStream":
:returns: A ByteStream instance.
"""
return ByteStream(data=bytes(data["data"]), meta=data.get("meta", {}), mime_type=data.get("mime_type"))

@staticmethod
def _guess_mime_type(path: Path) -> Optional[str]:
    """
    Guess the MIME type of a file from its path.

    :param path: The file path to get the MIME type for.

    :returns: The MIME type for the path, or `None` if the MIME type cannot be determined.
    """
    # Extensions mapped here because the mimetypes module does not add them:
    # markdown, see https://github.com/python/cpython/pull/17995, and Outlook msg files.
    extension_overrides = {
        ".md": "text/markdown",
        ".markdown": "text/markdown",
        ".msg": "application/vnd.ms-outlook",
    }

    suffix = path.suffix.lower()
    guessed, _encoding = mimetypes.guess_type(path.as_posix())
    # an override for the suffix wins; otherwise fall back to the mimetypes guess
    if suffix in extension_overrides:
        return extension_overrides[suffix]
    return guessed
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
Add `guess_mime_type` parameter to `ByteStream.from_file_path()`
40 changes: 39 additions & 1 deletion test/components/converters/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import pytest

from haystack.components.converters.utils import normalize_metadata
from haystack.components.converters.utils import get_bytestream_from_source, normalize_metadata
from haystack.dataclasses import ByteStream


def test_normalize_metadata_None():
Expand Down Expand Up @@ -32,3 +33,40 @@ def test_normalize_metadata_list_of_wrong_size():
def test_normalize_metadata_other_type():
    """A tuple is neither None, a dictionary, nor a list of dictionaries, so it must be rejected."""
    tuple_meta = ({"a": 1},)
    with pytest.raises(ValueError, match="meta must be either None, a dictionary or a list of dictionaries."):
        normalize_metadata(tuple_meta, sources_count=1)


def test_get_bytestream_from_path_object(tmp_path):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition to the test cases here, I suggest we add a test case for FileTypeRouter in test/components/routers/test_file_router.py checking that the parameter additional_mimetypes still works as expected. We're doing mimetypes.add_type(mime, ext) there, so we need to check that get_bytestream_from_source takes these additional mimetypes into account, for example: {"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx"}.

bytes_ = b"hello world"
source = tmp_path / "test.txt"
source.write_bytes(bytes_)

bs = get_bytestream_from_source(source, guess_mime_type=True)

assert isinstance(bs, ByteStream)
assert bs.data == bytes_
assert bs.mime_type == "text/plain"
assert bs.meta["file_path"].endswith("test.txt")


def test_get_bytestream_from_string_path(tmp_path):
    """A string path is read into a ByteStream with a guessed mime type and a `file_path` meta entry."""
    payload = b"hello world"
    text_file = tmp_path / "test.txt"
    text_file.write_bytes(payload)

    stream = get_bytestream_from_source(str(text_file), guess_mime_type=True)

    assert isinstance(stream, ByteStream)
    assert stream.data == payload
    assert stream.mime_type == "text/plain"
    assert stream.meta["file_path"].endswith("test.txt")


def test_get_bytestream_from_source_invalid_type():
    """A source that is not a str, Path, or ByteStream is rejected with a ValueError."""
    unsupported_source = 123
    with pytest.raises(ValueError, match="Unsupported source type"):
        get_bytestream_from_source(unsupported_source)


def test_get_bytestream_from_source_bytestream_passthrough():
    """A ByteStream source is handed back as the very same object."""
    original = ByteStream(data=b"spam", mime_type="text/custom", meta={"spam": "eggs"})
    assert get_bytestream_from_source(original) is original
51 changes: 51 additions & 0 deletions test/components/routers/test_file_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
# SPDX-License-Identifier: Apache-2.0

import io
import mimetypes
import sys
from pathlib import PosixPath
from unittest.mock import mock_open, patch

import pytest
from packaging import version

import haystack
from haystack import Pipeline
from haystack.components.converters import PyPDFToDocument, TextFileToDocument
from haystack.components.routers.file_type_router import FileTypeRouter
Expand Down Expand Up @@ -395,3 +399,50 @@ def test_pipeline_with_converters(self, test_files_path):

assert output["text_file_converter"]["documents"][0].meta["meta_field_1"] == "meta_value_1"
assert output["pypdf_converter"]["documents"][0].meta["meta_field_2"] == "meta_value_2"

def test_additional_mimetypes_integration(self, tmp_path):
    """
    Test that a file is routed by a MIME type registered through `additional_mimetypes`.
    """
    custom_mime_type = "application/x-spam"
    custom_extension = ".spam"
    # the extension already carries its leading dot, so the file is "test.spam", not "test..spam"
    test_file = tmp_path / f"test{custom_extension}"
    test_file.touch()

    # confirm that the mimetypes module doesn't know about this mime type by default
    assert custom_mime_type not in mimetypes.types_map.values()

    # make haystack aware of the custom mime type
    router = FileTypeRouter(
        mime_types=[custom_mime_type], additional_mimetypes={custom_mime_type: custom_extension}
    )
    mappings = router.run(sources=[test_file])

    # assert the file was classified under the custom mime type
    assert custom_mime_type in mappings
    assert test_file in mappings[custom_mime_type]

@pytest.mark.skipif(
    version.parse(haystack.__version__) >= version.parse("2.17.0"),
    reason="https://github.com/deepset-ai/haystack/pull/9573#issuecomment-3045237341",
)
def test_non_existent_file(self):
    """
    Test conditional FileNotFoundError behavior in FileTypeRouter.

    In Haystack versions prior to 2.17.0, `FileTypeRouter` does not raise an error
    when a non-existent file is passed without `meta`. However, it raises a
    FileNotFoundError when the same file is passed with `meta` supplied.

    This inconsistent behavior is slated to change in 2.17.0.
    See: https://github.com/deepset-ai/haystack/pull/9573#issuecomment-3045237341
    """
    # Imported locally: `Path` resolves to the concrete path class of the running platform,
    # whereas `PosixPath` cannot be instantiated on Windows.
    from pathlib import Path

    router = FileTypeRouter(mime_types=[r"text/plain"])

    # No meta - does not raise error
    result = router.run(sources=["non_existent.txt"])
    assert result == {"text/plain": [Path("non_existent.txt")]}

    # With meta - raises FileNotFoundError
    with pytest.raises(FileNotFoundError):
        router.run(sources=["non_existent.txt"], meta={"spam": "eggs"})
46 changes: 46 additions & 0 deletions test/dataclasses/test_byte_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,52 @@ def test_from_file_path(tmp_path, request):
assert b.meta == {"foo": "bar"}


@pytest.mark.parametrize(
    "file_path, expected_mime_types",
    [
        ("spam.jpeg", {"image/jpeg"}),
        ("spam.jpg", {"image/jpeg"}),
        ("spam.png", {"image/png"}),
        ("spam.gif", {"image/gif"}),
        ("spam.svg", {"image/svg+xml"}),
        ("spam.js", {"text/javascript", "application/javascript"}),
        ("spam.txt", {"text/plain"}),
        ("spam.html", {"text/html"}),
        ("spam.htm", {"text/html"}),
        ("spam.css", {"text/css"}),
        ("spam.csv", {"text/csv"}),
        ("spam.md", {"text/markdown"}),  # custom mapping
        ("spam.markdown", {"text/markdown"}),  # custom mapping
        ("spam.msg", {"application/vnd.ms-outlook"}),  # custom mapping
        ("spam.pdf", {"application/pdf"}),
        ("spam.xml", {"application/xml", "text/xml"}),
        ("spam.json", {"application/json"}),
        ("spam.doc", {"application/msword"}),
        ("spam.docx", {"application/vnd.openxmlformats-officedocument.wordprocessingml.document"}),
        ("spam.xls", {"application/vnd.ms-excel"}),
        ("spam.xlsx", {"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"}),
        ("spam.ppt", {"application/vnd.ms-powerpoint"}),
        ("spam.pptx", {"application/vnd.openxmlformats-officedocument.presentationml.presentation"}),
    ],
)
def test_from_file_path_guess_mime_type(file_path, expected_mime_types, tmp_path):
    """With `guess_mime_type=True` and no explicit mime type, the type is one of the accepted values."""
    # an empty file is enough: each case differs only in the file name
    empty_file = tmp_path / file_path
    empty_file.touch()

    stream = ByteStream.from_file_path(empty_file, guess_mime_type=True)
    assert stream.mime_type in expected_mime_types


def test_explicit_mime_type_is_not_overwritten_by_guessing(tmp_path):
    """An explicitly passed `mime_type` is kept even when `guess_mime_type=True`."""
    # empty markdown file: its extension maps to a different type than the explicit one below
    markdown_file = tmp_path / "sample.md"
    markdown_file.touch()

    stream = ByteStream.from_file_path(markdown_file, mime_type="text/x-rst", guess_mime_type=True)
    assert stream.mime_type == "text/x-rst"


def test_from_string():
test_string = "Hello, world!"
b = ByteStream.from_string(test_string)
Expand Down
Loading