diff --git a/CHANGELOG.md b/CHANGELOG.md index a3d9105af3..84361325db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - πŸš€ Add new SOTA video Anomaly detection module FUVAS - πŸš€ Add VAD dataset by @abc-125 in https://github.com/open-edge-platform/anomalib/pull/2603 - πŸš€ Add Tiled Ensemble for V2 by @blaz-r in https://github.com/open-edge-platform/anomalib/pull/2660 +- πŸš€ Add Tabular datamodule by @manuelkonrad in https://github.com/openvinotoolkit/anomalib/pull/2713 ### Removed diff --git a/docs/source/markdown/guides/reference/data/datamodules/image.md b/docs/source/markdown/guides/reference/data/datamodules/image.md index 8dfb26efd0..4950a3c54a 100644 --- a/docs/source/markdown/guides/reference/data/datamodules/image.md +++ b/docs/source/markdown/guides/reference/data/datamodules/image.md @@ -28,6 +28,13 @@ Dataset format compatible with Intel Getiβ„’. Custom folder-based dataset organization. ::: +:::{grid-item-card} Tabular +:link: anomalib.data.datamodules.image.Tabular +:link-type: doc + +Custom tabular dataset. +::: + :::{grid-item-card} Kolektor :link: anomalib.data.datamodules.image.Kolektor :link-type: doc diff --git a/docs/source/markdown/guides/reference/data/datamodules/image/tabular.md b/docs/source/markdown/guides/reference/data/datamodules/image/tabular.md new file mode 100644 index 0000000000..1a52f7ab0b --- /dev/null +++ b/docs/source/markdown/guides/reference/data/datamodules/image/tabular.md @@ -0,0 +1,7 @@ +# Tabular Datamodule + +```{eval-rst} +.. automodule:: anomalib.data.datamodules.image.tabular + :members: + :show-inheritance: +``` diff --git a/examples/configs/data/tabular.yaml b/examples/configs/data/tabular.yaml new file mode 100644 index 0000000000..70f23c8758 --- /dev/null +++ b/examples/configs/data/tabular.yaml @@ -0,0 +1,73 @@ +class_path: anomalib.data.Tabular +init_args: + name: bottle + root: "datasets/MVTecAD/bottle" + train_batch_size: 32 + eval_batch_size: 32 + num_workers: 8 + test_split_mode: from_dir + test_split_ratio: 0.2 + val_split_mode: same_as_test + val_split_ratio: 0.5 + seed: null + samples: + - image_path: train/good/000.png + label_index: 0 + mask_path: "" + split: train + - image_path: train/good/001.png + label_index: 0 + mask_path: "" + split: train + - image_path: train/good/002.png + label_index: 0 + mask_path: "" + split: train + - image_path: train/good/003.png + label_index: 0 + mask_path: "" + split: train + - image_path: train/good/004.png + label_index: 0 + mask_path: "" + split: train + - image_path: test/broken_large/000.png + label_index: 1 + mask_path: ground_truth/broken_large/000_mask.png + split: test + - image_path: test/broken_large/002.png + label_index: 1 + mask_path: ground_truth/broken_large/002_mask.png + split: test + - image_path: test/broken_large/004.png + label_index: 1 + mask_path: ground_truth/broken_large/004_mask.png + split: test + - image_path: test/good/000.png + label_index: 0 + mask_path: "" + split: test + - image_path: test/good/001.png + label_index: 0 + mask_path: "" + split: test + - image_path: test/good/003.png + label_index: 0 + mask_path: "" + split: test + - image_path: test/broken_large/001.png + label_index: 1 + mask_path: ground_truth/broken_large/001_mask.png + split: test + - image_path: test/broken_large/003.png + label_index: 1 + mask_path: ground_truth/broken_large/003_mask.png + split: test + - image_path: test/good/002.png + label_index: 0 + mask_path: "" + split: test + - 
image_path: test/good/004.png + label_index: 0 + mask_path: "" + split: test diff --git a/examples/notebooks/100_datamodules/105_tabular.ipynb b/examples/notebooks/100_datamodules/105_tabular.ipynb new file mode 100644 index 0000000000..6e947af6c2 --- /dev/null +++ b/examples/notebooks/100_datamodules/105_tabular.ipynb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e38748d6c72f0c5a115c1e925b6012c72be1cf2529f03d508f533ca11634033c +size 9777 diff --git a/src/anomalib/data/__init__.py b/src/anomalib/data/__init__.py index baa8b2d78e..435f86e841 100644 --- a/src/anomalib/data/__init__.py +++ b/src/anomalib/data/__init__.py @@ -61,6 +61,7 @@ MVTecAD2, MVTecLOCO, RealIAD, + Tabular, Visa, ) from .datamodules.video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat @@ -75,6 +76,7 @@ KolektorDataset, MVTecADDataset, MVTecLOCODataset, + TabularDataset, VADDataset, VisaDataset, ) @@ -181,6 +183,7 @@ def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule "MVTecAD2", "MVTecLOCO", "RealIAD", + "Tabular", "VAD", "Visa", # Video Data Modules @@ -196,6 +199,7 @@ def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule "KolektorDataset", "MVTecADDataset", "MVTecLOCODataset", + "TabularDataset", "VADDataset", "VisaDataset", "AvenueDataset", diff --git a/src/anomalib/data/datamodules/__init__.py b/src/anomalib/data/datamodules/__init__.py index 24b1a8be25..3bcccc12d4 100644 --- a/src/anomalib/data/datamodules/__init__.py +++ b/src/anomalib/data/datamodules/__init__.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 from .depth import Folder3D, MVTec3D -from .image import VAD, BTech, Datumaro, Folder, Kolektor, MVTec, MVTecAD, Visa +from .image import VAD, BTech, Datumaro, Folder, Kolektor, MVTec, MVTecAD, Tabular, Visa from .video import Avenue, ShanghaiTech, UCSDped __all__ = [ @@ -16,6 +16,7 @@ "Kolektor", "MVTec", # Include MVTec for backward compatibility "MVTecAD", + "Tabular", "VAD", "Visa", "Avenue", diff --git a/src/anomalib/data/datamodules/image/__init__.py b/src/anomalib/data/datamodules/image/__init__.py index c5fb17b59a..d2a0b74929 100644 --- a/src/anomalib/data/datamodules/image/__init__.py +++ b/src/anomalib/data/datamodules/image/__init__.py @@ -10,6 +10,7 @@ - ``MVTecAD``: MVTec Anomaly Detection Dataset - ``MVTecAD2``: MVTec Anomaly Detection Dataset 2 - ``MVTecLOCO``: MVTec LOCO Dataset with logical and structural anomalies +- ``Tabular``: Custom tabular dataset with image paths and labels - ``VAD``: Valeo Anomaly Detection Dataset - ``Visa``: Visual Anomaly Dataset @@ -36,6 +37,7 @@ from .mvtecad import MVTec, MVTecAD from .mvtecad2 import MVTecAD2 from .realiad import RealIAD +from .tabular import Tabular from .vad import VAD from .visa import Visa @@ -54,6 +56,7 @@ class ImageDataFormat(str, Enum): - ``MVTEC_AD_2``: MVTec AD 2 Dataset - ``MVTEC_3D``: MVTec 3D AD Dataset - ``MVTEC_LOCO``: MVTec LOCO Dataset + - ``TABULAR``: Custom Tabular Dataset - ``REALIAD``: Real-IAD Dataset - ``VAD``: Valeo Anomaly Detection Dataset - ``VISA``: Visual Anomaly Dataset @@ -69,6 +72,7 @@ class ImageDataFormat(str, Enum): MVTEC_3D = "mvtec_3d" MVTEC_LOCO = "mvtec_loco" REAL_IAD = "realiad" + TABULAR = "tabular" VAD = "vad" VISA = "visa" @@ -83,6 +87,7 @@ class ImageDataFormat(str, Enum): "MVTecAD2", "MVTecLOCO", "RealIAD", + "Tabular", "VAD", "Visa", ] diff --git a/src/anomalib/data/datamodules/image/tabular.py b/src/anomalib/data/datamodules/image/tabular.py new file mode 100644 index 0000000000..15d525e40b --- /dev/null 
+++ b/src/anomalib/data/datamodules/image/tabular.py
@@ -0,0 +1,238 @@
+"""Custom Tabular Data Module.
+
+This module creates a custom Lightning DataModule from a table or tabular file
+containing image paths and labels.
+
+Example:
+    Create a Tabular datamodule::
+
+        >>> from anomalib.data import Tabular
+        >>> samples = {
+        ...     "image_path": ["images/image1.png", "images/image2.png", "images/image3.png", ... ],
+        ...     "label_index": [LabelName.NORMAL, LabelName.NORMAL, LabelName.ABNORMAL, ... ],
+        ...     "split": [Split.TRAIN, Split.TRAIN, Split.TEST, ... ],
+        ... }
+        >>> datamodule = Tabular(
+        ...     name="custom",
+        ...     samples=samples,
+        ...     root="./datasets/custom",
+        ... )
+"""
+
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from pathlib import Path
+
+import pandas as pd
+from torchvision.transforms.v2 import Transform
+
+from anomalib.data.datamodules.base.image import AnomalibDataModule
+from anomalib.data.datasets.image.tabular import TabularDataset
+from anomalib.data.utils import Split, TestSplitMode, ValSplitMode
+
+
+class Tabular(AnomalibDataModule):
+    """Tabular DataModule.
+
+    Args:
+        name (str): Name of the dataset. Used for logging/saving.
+        samples (dict | list | DataFrame): Pandas ``DataFrame`` or compatible ``list``
+            or ``dict`` containing the dataset information.
+        root (str | Path | None): Root folder containing normal and abnormal
+            directories. Defaults to ``None``.
+        normal_split_ratio (float): Ratio to split normal training images for
+            test set when no normal test images exist.
+            Defaults to ``0.2``.
+        train_batch_size (int): Training batch size.
+            Defaults to ``32``.
+        eval_batch_size (int): Validation/test batch size.
+            Defaults to ``32``.
+        num_workers (int): Number of workers for data loading.
+            Defaults to ``8``.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
+            Defaults to ``None``.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
+            Defaults to ``None``.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
+            Defaults to ``None``.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided. Defaults to ``None``.
+        test_split_mode (TestSplitMode): Method to obtain test subset.
+            Defaults to ``TestSplitMode.FROM_DIR``.
+        test_split_ratio (float): Fraction of train images for testing.
+            Defaults to ``0.2``.
+        val_split_mode (ValSplitMode): Method to obtain validation subset.
+            Defaults to ``ValSplitMode.FROM_TEST``.
+        val_split_ratio (float): Fraction of images for validation.
+            Defaults to ``0.5``.
+        seed (int | None): Random seed for splitting.
+            Defaults to ``None``.
+
+    Example:
+        Create and setup a tabular datamodule::
+
+            >>> from anomalib.data import Tabular
+            >>> samples = {
+            ...     "image_path": ["images/image1.png", "images/image2.png", "images/image3.png", ... ],
+            ...     "label_index": [LabelName.NORMAL, LabelName.NORMAL, LabelName.ABNORMAL, ... ],
+            ...     "split": [Split.TRAIN, Split.TRAIN, Split.TEST, ... ],
+            ... }
+            >>> datamodule = Tabular(
+            ...     name="custom",
+            ...     samples=samples,
+            ...     root="./datasets/custom",
+            ... )
+            >>> datamodule.setup()
+
+        Get a batch from train dataloader::
+
+            >>> batch = next(iter(datamodule.train_dataloader()))
+            >>> batch.keys()
+            dict_keys(['image', 'label', 'mask', 'image_path', 'mask_path'])
+
+        Get a batch from test dataloader::
+
+            >>> batch = next(iter(datamodule.test_dataloader()))
+            >>> batch.keys()
+            dict_keys(['image', 'label', 'mask', 'image_path', 'mask_path'])
+    """
+
+    def __init__(
+        self,
+        name: str,
+        samples: dict | list | pd.DataFrame,
+        root: str | Path | None = None,
+        normal_split_ratio: float = 0.2,
+        train_batch_size: int = 32,
+        eval_batch_size: int = 32,
+        num_workers: int = 8,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
+        test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
+        test_split_ratio: float = 0.2,
+        val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST,
+        val_split_ratio: float = 0.5,
+        seed: int | None = None,
+    ) -> None:
+        self._name = name
+        self.root = root
+        self._unprocessed_samples = samples
+        test_split_mode = TestSplitMode(test_split_mode)
+        val_split_mode = ValSplitMode(val_split_mode)
+        super().__init__(
+            train_batch_size=train_batch_size,
+            eval_batch_size=eval_batch_size,
+            num_workers=num_workers,
+            train_augmentations=train_augmentations,
+            val_augmentations=val_augmentations,
+            test_augmentations=test_augmentations,
+            augmentations=augmentations,
+            test_split_mode=test_split_mode,
+            test_split_ratio=test_split_ratio,
+            val_split_mode=val_split_mode,
+            val_split_ratio=val_split_ratio,
+            seed=seed,
+        )
+
+        self.normal_split_ratio = normal_split_ratio
+
+    def _setup(self, _stage: str | None = None) -> None:
+        self.train_data = TabularDataset(
+            name=self.name,
+            samples=self._unprocessed_samples,
+            split=Split.TRAIN,
+            root=self.root,
+        )
+
+        self.test_data = TabularDataset(
+            name=self.name,
+            samples=self._unprocessed_samples,
+            split=Split.TEST,
+            root=self.root,
+        )
+
+    @property
+    def name(self) -> str:
+        """Get name of the datamodule.
+
+        Returns:
+            Name of the datamodule.
+        """
+        return self._name
+
+    @classmethod
+    def from_file(
+        cls: type["Tabular"],
+        name: str,
+        file_path: str | Path,
+        file_format: str | None = None,
+        pd_kwargs: dict | None = None,
+        **kwargs,
+    ) -> "Tabular":
+        """Create Tabular Datamodule from file.
+
+        Args:
+            name (str): Name of the dataset. This is used to name the datamodule,
+                especially when logging/saving.
+            file_path (str | Path): Path to tabular file containing the dataset
+                information.
+            file_format (str | None): File format supported by a pd.read_* method, such
+                as ``csv``, ``parquet`` or ``json``.
+                Defaults to ``None`` (inferred from file suffix).
+            pd_kwargs (dict | None): Keyword argument dictionary for the pd.read_* method.
+                Defaults to ``None``.
+            kwargs (dict): Additional keyword arguments for the Tabular Datamodule class.
+
+        Returns:
+            Tabular: Tabular Datamodule
+
+        Example:
+            Prepare a tabular file (such as ``samples.csv`` or ``samples.parquet``) with the
+            following columns: ``image_path`` (absolute or relative to ``root``), ``label_index``
+            (``0`` for normal, ``1`` for anomalous samples), and ``split`` (``train`` or ``test``).
+            For segmentation tasks, also include a ``mask_path`` column.
+
+            From this file, create and setup a tabular datamodule::
+
+                >>> from anomalib.data import Tabular
+                >>> datamodule = Tabular.from_file(
+                ...     name="custom",
+                ...     file_path="./samples.csv",
+                ...     
root="./datasets/custom", + ... ) + >>> datamodule.setup() + + Get a batch from train dataloader:: + + >>> batch = next(iter(datamodule.train_dataloader())) + >>> batch.keys() + dict_keys(['image', 'label', 'mask', 'image_path', 'mask_path']) + + Get a batch from test dataloader:: + + >>> batch = next(iter(datamodule.test_dataloader())) + >>> batch.keys() + dict_keys(['image', 'label', 'mask', 'image_path', 'mask_path']) + """ + # Check if file exists + if not Path(file_path).is_file(): + msg = f"File not found: '{file_path}'" + raise FileNotFoundError(msg) + + # Infer file_format and check if supported + file_format = file_format or Path(file_path).suffix[1:] + if not file_format: + msg = f"File format not specified and could not be inferred from file name: '{Path(file_path).name}'" + raise ValueError(msg) + read_func = getattr(pd, f"read_{file_format}", None) + if read_func is None: + msg = f"Unsupported file format: '{file_format}'" + raise ValueError(msg) + + # Read the file and return Tabular dataset + pd_kwargs = pd_kwargs or {} + samples = read_func(file_path, **pd_kwargs) + return cls(name, samples, **kwargs) diff --git a/src/anomalib/data/datasets/__init__.py b/src/anomalib/data/datasets/__init__.py index e6a53c40f2..62d67f4f91 100644 --- a/src/anomalib/data/datasets/__init__.py +++ b/src/anomalib/data/datasets/__init__.py @@ -17,6 +17,7 @@ - ``FolderDataset``: Custom dataset from folder structure - ``KolektorDataset``: Kolektor surface defect dataset - ``MVTecADDataset``: MVTec AD dataset with industrial objects + - ``TabularDataset``: Custom tabular dataset with image paths and labels - ``VAD``: Valeo Anomaly Detection Dataset - ``VisaDataset``: Visual Anomaly dataset @@ -45,6 +46,7 @@ FolderDataset, KolektorDataset, MVTecADDataset, + TabularDataset, VADDataset, VisaDataset, ) @@ -64,6 +66,7 @@ "FolderDataset", "KolektorDataset", "MVTecADDataset", + "TabularDataset", "VADDataset", "VisaDataset", # Video diff --git a/src/anomalib/data/datasets/image/__init__.py b/src/anomalib/data/datasets/image/__init__.py index d7934672c1..b205671241 100644 --- a/src/anomalib/data/datasets/image/__init__.py +++ b/src/anomalib/data/datasets/image/__init__.py @@ -9,6 +9,7 @@ - ``KolektorDataset``: Kolektor surface defect dataset - ``MVTecADDataset``: MVTec AD dataset with industrial objects - ``MVTecLOCODataset``: MVTec LOCO dataset with logical and structural anomalies +- ``TabularDataset``: Custom tabular dataset with image paths and labels - ``VAD``: Valeo Anomaly Detection Dataset - ``VisaDataset``: Visual Anomaly dataset @@ -32,6 +33,7 @@ from .mvtecad import MVTecADDataset, MVTecDataset from .mvtecad2 import MVTecAD2Dataset from .realiad import RealIADDataset +from .tabular import TabularDataset from .vad import VADDataset from .visa import VisaDataset @@ -45,6 +47,7 @@ "MVTecAD2Dataset", "MVTecLOCODataset", "RealIADDataset", + "TabularDataset", "VADDataset", "VisaDataset", ] diff --git a/src/anomalib/data/datasets/image/tabular.py b/src/anomalib/data/datasets/image/tabular.py new file mode 100644 index 0000000000..4641ae9185 --- /dev/null +++ b/src/anomalib/data/datasets/image/tabular.py @@ -0,0 +1,297 @@ +"""Custom Tabular Dataset. + +This module provides a custom PyTorch Dataset implementation for loading +images using a selection of paths and labels defined in a table or tabular file. +It does not require a specific folder structure and allows subsampling and +relabeling without moving files. The dataset supports both classification and +segmentation tasks. 
+
+The table should contain columns for ``image_path``, ``label_index``, ``split``,
+and optionally ``mask_path`` for segmentation tasks.
+
+Example:
+    >>> from anomalib.data.datasets import TabularDataset
+    >>> samples = {
+    ...     "image_path": ["images/image1.png", "images/image2.png", "images/image3.png", ... ],
+    ...     "label_index": [LabelName.NORMAL, LabelName.NORMAL, LabelName.ABNORMAL, ... ],
+    ...     "split": [Split.TRAIN, Split.TRAIN, Split.TEST, ... ],
+    ... }
+    >>> dataset = TabularDataset(
+    ...     name="custom",
+    ...     samples=samples,
+    ...     root="./datasets/custom",
+    ... )
+"""
+
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from pathlib import Path
+
+from pandas import DataFrame
+from torchvision.transforms.v2 import Transform
+
+from anomalib.data.datasets.base.image import AnomalibDataset
+from anomalib.data.errors import MisMatchError
+from anomalib.data.utils import DirType, LabelName, Split
+
+
+class TabularDataset(AnomalibDataset):
+    """Dataset class for loading images from paths and labels defined in a table.
+
+    Args:
+        name (str): Name of the dataset. Used for logging/saving.
+        samples (dict | list | DataFrame): Pandas ``DataFrame`` or compatible ``list``
+            or ``dict`` containing the dataset information.
+        augmentations (Transform | None, optional): Augmentations to apply to the images.
+            Defaults to ``None``.
+        root (str | Path | None, optional): Root directory of the dataset.
+            Defaults to ``None``.
+        split (str | Split | None, optional): Dataset split to load.
+            Choose from ``Split.FULL``, ``Split.TRAIN``, ``Split.TEST``.
+            Defaults to ``None``.
+
+    Examples:
+        Create a classification dataset:
+
+        >>> from torchvision.transforms.v2 import Resize
+        >>> from anomalib.data.datasets import TabularDataset
+        >>> augmentations = Resize(
+        ...     size=(256, 256),
+        ...     antialias=True,
+        ... )
+        >>> samples = {
+        ...     "image_path": ["images/image1.png", "images/image2.png", "images/image3.png", ... ],
+        ...     "label_index": [LabelName.NORMAL, LabelName.NORMAL, LabelName.ABNORMAL, ... ],
+        ...     "split": [Split.TRAIN, Split.TRAIN, Split.TEST, ... ],
+        ... }
+        >>> dataset = TabularDataset(
+        ...     name="custom",
+        ...     samples=samples,
+        ...     root="./datasets/custom",
+        ...     augmentations=augmentations
+        ... )
+
+        Create a segmentation dataset:
+
+        >>> samples = {
+        ...     "image_path": ["images/image1.png", "images/image2.png", "images/image3.png", ... ],
+        ...     "label_index": [LabelName.NORMAL, LabelName.NORMAL, LabelName.ABNORMAL, ... ],
+        ...     "split": [Split.TRAIN, Split.TRAIN, Split.TEST, ... ],
+        ...     "mask_path": ["masks/mask1.png", "masks/mask2.png", "masks/mask3.png", ... ],
+        ... }
+        >>> dataset = TabularDataset(
+        ...     name="custom",
+        ...     samples=samples,
+        ...     root="./datasets/custom",
+        ...     augmentations=augmentations
+        ... )
+    """
+
+    def __init__(
+        self,
+        name: str,
+        samples: dict | list | DataFrame,
+        augmentations: Transform | None = None,
+        root: str | Path | None = None,
+        split: str | Split | None = None,
+    ) -> None:
+        super().__init__(augmentations=augmentations)
+
+        self._name = name
+        self.split = split
+        self.root = root
+        self.samples = make_tabular_dataset(
+            samples=samples,
+            root=self.root,
+            split=self.split,
+        )
+
+    @property
+    def name(self) -> str:
+        """Get dataset name.
+ + Returns: + str: Name of the dataset + """ + return self._name + + +def make_tabular_dataset( + samples: dict | list | DataFrame, + root: str | Path | None = None, + split: str | Split | None = None, +) -> DataFrame: + """Create a dataset from a table of image paths and labels. + + Args: + samples (dict | list | DataFrame): Pandas ``DataFrame`` or compatible + ``list`` or ``dict`` containing the dataset information. + root (str | Path | None, optional): Root directory of the dataset. + Defaults to ``None``. + split (str | Split | None, optional): Dataset split to load. + Choose from ``Split.FULL``, ``Split.TRAIN``, ``Split.TEST``. + Defaults to ``None``. + + Returns: + DataFrame: Dataset samples with columns for image paths, labels, splits + and mask paths (for segmentation). + + Examples: + Create a classification dataset: + >>> samples = { + ... "image_path": ["images/00.png", "images/01.png", "images/02.png", ... ], + ... "label_index": [LabelName.NORMAL, LabelName.NORMAL, LabelName.NORMAL, ... ], + ... "split": [Split.TRAIN, Split.TRAIN, Split.TRAIN, ... ], + ... } + >>> tabular_df = make_tabular_dataset( + ... samples=samples, + ... root="./datasets/custom", + ... split=Split.TRAIN, + ... ) + >>> tabular_df.head() + image_path label label_index mask_path split + 0 ./datasets/custom/images/00.png DirType.NORMAL 0 Split.TRAIN + 1 ./datasets/custom/images/01.png DirType.NORMAL 0 Split.TRAIN + 2 ./datasets/custom/images/02.png DirType.NORMAL 0 Split.TRAIN + 3 ./datasets/custom/images/03.png DirType.NORMAL 0 Split.TRAIN + 4 ./datasets/custom/images/04.png DirType.NORMAL 0 Split.TRAIN + """ + ###################### + ### Pre-processing ### + ###################### + + # Convert to pandas DataFrame if dictionary or list is given + if isinstance(samples, dict | list): + samples = DataFrame(samples) + if "image_path" not in samples.columns: + msg = "The samples table must contain an 'image_path' column." + raise ValueError(msg) + samples = samples.sort_values(by="image_path", ignore_index=True) + + ########################### + ### Add missing columns ### + ########################### + + # Adding missing columns successively: + # The user can provide one or more of columns 'label_index', 'label', and 'split'. + # The missing columns will be inferred from the provided columns by predefined rules. + + if "label_index" in samples.columns: + samples.label_index = samples.label_index.astype("Int64") + + columns_present = [col in samples.columns for col in ["label_index", "label", "split"]] + + # all columns missing + if columns_present == [ + False, # label_index + False, # label + False, # split + ]: + msg = "The samples table must contain at least one of 'label_index', 'label' or 'split' columns." 
+ raise ValueError(msg) + + # label_index missing (split can be present or missing, therefore only first two values are checked) + if columns_present[:2] == [ + False, # label_index + True, # label + ]: + label_to_label_index = { + DirType.ABNORMAL: LabelName.ABNORMAL, + DirType.NORMAL: LabelName.NORMAL, + DirType.NORMAL_TEST: LabelName.NORMAL, + } + samples["label_index"] = samples["label"].map(label_to_label_index).astype("Int64") + + # label_index and label missing + elif columns_present == [ + False, # label_index + False, # label + True, # split + ]: + split_to_label_index = { + Split.TRAIN: LabelName.NORMAL, + Split.TEST: LabelName.ABNORMAL, + } + samples["label_index"] = samples["split"].map(split_to_label_index).astype("Int64") + + # label and split missing + elif columns_present == [ + True, # label_index + False, # label + False, # split + ]: + label_index_to_label = { + LabelName.ABNORMAL: DirType.ABNORMAL, + LabelName.NORMAL: DirType.NORMAL, + } + samples["label"] = samples["label_index"].map(label_index_to_label) + + # reevaluate columns_present in case a column was added in the previous control flow + columns_present = [col in samples.columns for col in ["label_index", "label", "split"]] + # label missing + if columns_present == [ + True, # label_index + False, # label + True, # split + ]: + samples["label"] = samples.apply( + lambda x: DirType.NORMAL + if (x["label_index"] == LabelName.NORMAL) and (x["split"] == Split.TRAIN) + else ( + DirType.NORMAL_TEST + if x["label_index"] == LabelName.NORMAL and x["split"] == Split.TEST + else (DirType.ABNORMAL if x["label_index"] == LabelName.ABNORMAL else None) + ), + axis=1, + ) + # split missing + elif columns_present == [ + True, # label_index + True, # label + False, # split + ]: + label_to_split = { + DirType.NORMAL: Split.TRAIN, + DirType.ABNORMAL: Split.TEST, + DirType.NORMAL_TEST: Split.TEST, + } + samples["split"] = samples["label"].map(label_to_split) + + # Add mask_path column if not exists + if "mask_path" not in samples.columns: + samples["mask_path"] = "" + + ####################### + ### Post-processing ### + ####################### + + # Add root to paths + samples["mask_path"] = samples["mask_path"].fillna("") + if root: + samples["image_path"] = samples["image_path"].map(lambda x: Path(root, x)) + samples.loc[ + samples["mask_path"] != "", + "mask_path", + ] = samples.loc[samples["mask_path"] != "", "mask_path"].map(lambda x: Path(root, x)) + samples = samples.astype({"image_path": "str", "mask_path": "str", "label": "str"}) + + # Check if anomalous samples are in training set + if ((samples.label_index == LabelName.ABNORMAL) & (samples.split == Split.TRAIN)).any(): + msg = "Training set must not contain anomalous samples." + raise MisMatchError(msg) + + # Check for None or NaN values + if samples.isna().any().any(): + msg = "The samples table contains None or NaN values." + raise ValueError(msg) + + # Infer the task type + samples.attrs["task"] = "classification" if (samples["mask_path"] == "").all() else "segmentation" + + # Get the dataframe for the split. + if split: + samples = samples[samples.split == split] + samples = samples.reset_index(drop=True) + + return samples diff --git a/tests/conftest.py b/tests/conftest.py index 5069f06378..30db7b1875 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -56,9 +56,9 @@ def dataset_path(project_path: Path) -> Path: # 1. Create the dummy image datasets. for data_format in list(ImageDataFormat): - # Do not generate a dummy dataset for folder datasets. 
-        # We could use one of these datasets to test the folders datasets.
-        if not data_format.value.startswith("folder"):
+        # Do not generate a dummy dataset for folder and tabular datasets.
+        # We could use one of these datasets to test the folder and tabular datasets.
+        if not data_format.value.startswith(("folder", "tabular")):
             dataset_generator = DummyImageDatasetGenerator(data_format=data_format, root=_dataset_path)
             dataset_generator.generate_dataset()
diff --git a/tests/helpers/data.py b/tests/helpers/data.py
index c66f07e67d..aa71ab83e1 100644
--- a/tests/helpers/data.py
+++ b/tests/helpers/data.py
@@ -411,6 +411,10 @@ def _generate_dummy_folder_dataset(self) -> None:
             mask_filename = mask_dir / image_filename.name
             self.image_generator.generate_image(label, image_filename, mask_filename)
 
+    def _generate_dummy_tabular_dataset(self) -> None:
+        """Generate dummy folder structure for tabular dataset in a temporary directory."""
+        self._generate_dummy_folder_dataset()
+
     def _generate_dummy_btech_dataset(self) -> None:
         """Generate dummy BeanTech dataset in directory using the same convention as BeanTech AD."""
         # BeanTech AD follows the same convention as MVTec AD.
diff --git a/tests/unit/data/datamodule/image/test_tabular.py b/tests/unit/data/datamodule/image/test_tabular.py
new file mode 100644
index 0000000000..757743987c
--- /dev/null
+++ b/tests/unit/data/datamodule/image/test_tabular.py
@@ -0,0 +1,105 @@
+"""Unit Tests - Tabular Datamodule."""
+
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import tempfile
+from pathlib import Path
+
+import pandas as pd
+import pytest
+from torchvision.transforms.v2 import Resize
+
+from anomalib.data import Folder, Tabular
+from tests.unit.data.datamodule.base.image import _TestAnomalibImageDatamodule
+
+
+class TestTabular(_TestAnomalibImageDatamodule):
+    """Tabular Datamodule Unit Tests."""
+
+    @staticmethod
+    def get_samples_dataframe(dataset_path: Path) -> pd.DataFrame:
+        """Create samples DataFrame using the Folder datamodule."""
+        _folder_datamodule = Folder(
+            name="dummy",
+            root=dataset_path / "mvtecad" / "dummy",
+            normal_dir="train/good",
+            abnormal_dir="test/bad",
+            normal_test_dir="test/good",
+            mask_dir="ground_truth/bad",
+            train_batch_size=4,
+            eval_batch_size=4,
+            num_workers=0,
+        )
+        _folder_datamodule.setup()
+        return pd.concat([
+            _folder_datamodule.train_data.samples,
+            _folder_datamodule.test_data.samples,
+            _folder_datamodule.val_data.samples,
+        ])
+
+    @pytest.fixture(
+        params=[
+            None,
+            ["label"],
+            ["label_index"],
+            ["split"],
+            ["mask_path"],
+        ],
+    )
+    @staticmethod
+    def columns_to_drop(request: pytest.FixtureRequest) -> list[str] | None:
+        """Return the columns to be dropped from the samples dataframe."""
+        return request.param
+
+    @pytest.fixture()
+    @staticmethod
+    def datamodule(dataset_path: Path, columns_to_drop: list | None) -> Tabular:
+        """Create and return a Tabular datamodule."""
+        _samples = TestTabular.get_samples_dataframe(dataset_path)
+        if columns_to_drop:
+            _samples = _samples.drop(columns_to_drop, axis="columns")
+        _datamodule = Tabular(
+            name="dummy",
+            samples=_samples,
+            train_batch_size=4,
+            eval_batch_size=4,
+            num_workers=0,
+            augmentations=Resize((256, 256)),
+        )
+        _datamodule.setup()
+        return _datamodule
+
+    @pytest.fixture()
+    @staticmethod
+    def fxt_data_config_path() -> str:
+        """Return the path to the test data config."""
+        return "examples/configs/data/tabular.yaml"
+
+
+class TestTabularFromFile(TestTabular):
+    """Tabular Datamodule Unit Tests for 
alternative constructor. + + Tests for the Datamodule creation from file. + """ + + @pytest.fixture() + @staticmethod + def datamodule(dataset_path: Path) -> Tabular: + """Create and return a Tabular datamodule.""" + _samples = TestTabular.get_samples_dataframe(dataset_path) + with tempfile.NamedTemporaryFile(suffix=".csv") as samples_file: + _samples.to_csv(samples_file) + samples_file.seek(0) + + _datamodule = Tabular.from_file( + name="dummy", + file_path=samples_file.name, + train_batch_size=4, + eval_batch_size=4, + num_workers=0, + augmentations=Resize((256, 256)), + ) + _datamodule.setup() + + return _datamodule
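
---

Review note: a quick end-to-end sketch of the two constructors this diff adds. The dataset name, paths, and file name below are hypothetical; the label/split conventions (``0``/``1`` for normal/anomalous, ``train``/``test``) follow the docstrings above.

```python
# Minimal usage sketch of the new Tabular datamodule (hypothetical paths).
import pandas as pd

from anomalib.data import Tabular

samples = pd.DataFrame({
    "image_path": ["train/good/000.png", "train/good/001.png", "test/bad/000.png", "test/bad/001.png"],
    "label_index": [0, 0, 1, 1],  # 0 = normal, 1 = anomalous
    "split": ["train", "train", "test", "test"],
})

# From an in-memory table:
datamodule = Tabular(name="custom", samples=samples, root="datasets/custom")
datamodule.setup()

# Or from a tabular file: the format is inferred from the file suffix and
# dispatched to the matching pandas reader (pd.read_csv here).
samples.to_csv("samples.csv", index=False)
datamodule = Tabular.from_file(name="custom", file_path="samples.csv", root="datasets/custom")
datamodule.setup()
```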
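The column-inference rules in ``make_tabular_dataset`` mean a table only needs one of ``label``, ``label_index``, or ``split``; the missing columns are derived. A sketch of the label-only case, assuming the import path from this diff:

```python
# Column-inference sketch: only 'label' is given, the rest is derived.
from anomalib.data.datasets.image.tabular import make_tabular_dataset

df = make_tabular_dataset(
    samples={
        "image_path": ["good.png", "defect.png"],
        "label": ["normal", "abnormal"],
    },
)
# 'label_index' is mapped from 'label' (normal -> 0, abnormal -> 1), and
# 'split' is derived from 'label' (normal -> train, abnormal -> test).
print(df[["image_path", "label", "label_index", "split"]])
```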
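The validation at the end of ``make_tabular_dataset`` rejects tables that place anomalous samples in the training split; a sketch of that failure mode:

```python
# Anomalous samples in the train split are rejected with MisMatchError.
from anomalib.data.datasets.image.tabular import make_tabular_dataset
from anomalib.data.errors import MisMatchError

try:
    make_tabular_dataset(
        samples={
            "image_path": ["defect.png"],
            "label_index": [1],   # anomalous
            "split": ["train"],   # anomalies are not allowed in train
        },
    )
except MisMatchError as err:
    print(err)  # Training set must not contain anomalous samples.
```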