diff --git a/torch_uncertainty/datasets/assets/__init__.py b/torch_uncertainty/datasets/assets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/torch_uncertainty/datasets/assets/frost1.jpg b/torch_uncertainty/datasets/assets/frost1.jpg new file mode 100644 index 00000000..4745a0ab Binary files /dev/null and b/torch_uncertainty/datasets/assets/frost1.jpg differ diff --git a/torch_uncertainty/datasets/assets/frost2.jpg b/torch_uncertainty/datasets/assets/frost2.jpg new file mode 100644 index 00000000..1e7694dc Binary files /dev/null and b/torch_uncertainty/datasets/assets/frost2.jpg differ diff --git a/torch_uncertainty/datasets/assets/frost3.jpg b/torch_uncertainty/datasets/assets/frost3.jpg new file mode 100644 index 00000000..f8b0c413 Binary files /dev/null and b/torch_uncertainty/datasets/assets/frost3.jpg differ diff --git a/torch_uncertainty/datasets/assets/frost4.jpg b/torch_uncertainty/datasets/assets/frost4.jpg new file mode 100644 index 00000000..95dc9056 Binary files /dev/null and b/torch_uncertainty/datasets/assets/frost4.jpg differ diff --git a/torch_uncertainty/datasets/assets/frost5.jpg b/torch_uncertainty/datasets/assets/frost5.jpg new file mode 100644 index 00000000..14e5d58e Binary files /dev/null and b/torch_uncertainty/datasets/assets/frost5.jpg differ diff --git a/torch_uncertainty/datasets/frost.py b/torch_uncertainty/datasets/frost.py index 6aa069bb..e0ad660c 100644 --- a/torch_uncertainty/datasets/frost.py +++ b/torch_uncertainty/datasets/frost.py @@ -1,78 +1,36 @@ -import logging from collections.abc import Callable +from importlib.abc import Traversable +from importlib.resources import files from pathlib import Path from typing import Any from PIL import Image from torchvision.datasets import VisionDataset -from torchvision.datasets.utils import ( - check_integrity, - download_and_extract_archive, -) -def pil_loader(path: Path) -> Image.Image: - # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) +def pil_loader(path: Path | Traversable) -> Image.Image: with path.open("rb") as f: img = Image.open(f) return img.convert("RGB") -class FrostImages(VisionDataset): # TODO: Use ImageFolder - url = "https://zenodo.org/records/10438904/files/frost.zip" - zip_md5 = "d82f29f620d43a68e71e34b28f7c35cb" - filename = "frost.zip" - samples = [ - "frost1.png", - "frost2.png", - "frost3.jpg", - "frost4.jpg", - "frost5.jpg", - ] +FROST_ASSETS_MOD = "torch_uncertainty.datasets.assets" + +class FrostImages(VisionDataset): def __init__( self, - root: str | Path, - transform: Callable[..., Any] | None, + transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, - download: bool = False, ) -> None: - self.root = Path(root) - - if download: - self.download() - - if not self._check_integrity(): - raise RuntimeError( - "Dataset not found or corrupted. You can use download=True to download it." - ) - super().__init__( - self.root / "frost", + FROST_ASSETS_MOD, transform=transform, target_transform=target_transform, ) self.loader = pil_loader - - def _check_integrity(self) -> bool: - fpath = self.root / self.filename - return check_integrity( - fpath, - self.zip_md5, - ) - - def download(self) -> None: - if self._check_integrity(): - logging.info("Files already downloaded and verified") - return - - download_and_extract_archive( - self.url, - download_root=self.root, - filename=self.filename, - md5=self.zip_md5, - ) - logging.info("Downloaded %s to %s.", self.filename, self.root) + sample_path = files(FROST_ASSETS_MOD) + self.samples = [sample_path.joinpath(f"frost{i}.jpg") for i in range(1, 6)] def __getitem__(self, index: int) -> Any: """Get the samples of the dataset. @@ -83,8 +41,7 @@ def __getitem__(self, index: int) -> Any: Returns: tuple: (sample, target) where target is class_index of the target class. """ - path = self.root / self.samples[index] - sample = self.loader(path) + sample = self.loader(self.samples[index]) if self.transform is not None: sample = self.transform(sample) return sample