
Commit b90fb9f

added first draft of tabular datamodule
Signed-off-by: Manuel Konrad <84141230+manuelkonrad@users.noreply.github.com>
1 parent e6da810 commit b90fb9f

13 files changed (+671, -2 lines)

docs/source/markdown/guides/reference/data/datamodules/image.md

Lines changed: 7 additions & 0 deletions
@@ -28,6 +28,13 @@ Dataset format compatible with Intel Geti™.
 Custom folder-based dataset organization.
 :::

+:::{grid-item-card} Tabular
+:link: anomalib.data.datamodules.image.Tabular
+:link-type: doc
+
+Custom tabular dataset.
+:::
+
 :::{grid-item-card} Kolektor
 :link: anomalib.data.datamodules.image.Kolektor
 :link-type: doc

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# Tabular Datamodule
+
+```{eval-rst}
+.. automodule:: anomalib.data.datamodules.image.tabular
+   :members:
+   :show-inheritance:
+```

examples/configs/data/tabular.yaml

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+class_path: anomalib.data.Tabular
+init_args:
+  name: bottle
+  root: "datasets/MVTecAD/bottle"
+  train_batch_size: 32
+  eval_batch_size: 32
+  num_workers: 8
+  test_split_mode: from_dir
+  test_split_ratio: 0.2
+  val_split_mode: same_as_test
+  val_split_ratio: 0.5
+  seed: null
+  samples:
+    - image_path: train/good/000.png
+      label_index: 0
+      mask_path: ""
+      split: train
+    - image_path: train/good/001.png
+      label_index: 0
+      mask_path: ""
+      split: train
+    - image_path: train/good/002.png
+      label_index: 0
+      mask_path: ""
+      split: train
+    - image_path: train/good/003.png
+      label_index: 0
+      mask_path: ""
+      split: train
+    - image_path: train/good/004.png
+      label_index: 0
+      mask_path: ""
+      split: train
+    - image_path: test/broken_large/000.png
+      label_index: 1
+      mask_path: ground_truth/broken_large/000_mask.png
+      split: test
+    - image_path: test/broken_large/002.png
+      label_index: 1
+      mask_path: ground_truth/broken_large/002_mask.png
+      split: test
+    - image_path: test/broken_large/004.png
+      label_index: 1
+      mask_path: ground_truth/broken_large/004_mask.png
+      split: test
+    - image_path: test/good/000.png
+      label_index: 0
+      mask_path: ""
+      split: test
+    - image_path: test/good/001.png
+      label_index: 0
+      mask_path: ""
+      split: test
+    - image_path: test/good/003.png
+      label_index: 0
+      mask_path: ""
+      split: test
+    - image_path: test/broken_large/001.png
+      label_index: 1
+      mask_path: ground_truth/broken_large/001_mask.png
+      split: test
+    - image_path: test/broken_large/003.png
+      label_index: 1
+      mask_path: ground_truth/broken_large/003_mask.png
+      split: test
+    - image_path: test/good/002.png
+      label_index: 0
+      mask_path: ""
+      split: test
+    - image_path: test/good/004.png
+      label_index: 0
+      mask_path: ""
+      split: test
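
For reference, a config in this class_path/init_args form would typically be consumed through anomalib's get_datamodule helper, whose signature appears in the src/anomalib/data/__init__.py hunks below. A minimal sketch, assuming get_datamodule is importable from anomalib.data, OmegaConf is installed, and the referenced MVTec AD bottle images actually exist under the configured root:

    from omegaconf import OmegaConf
    from anomalib.data import get_datamodule

    # Parse the YAML above and let anomalib instantiate anomalib.data.Tabular
    # with the listed init_args (including the inline samples table).
    config = OmegaConf.load("examples/configs/data/tabular.yaml")
    datamodule = get_datamodule(config)
    datamodule.setup()
    batch = next(iter(datamodule.train_dataloader()))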

src/anomalib/data/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -61,6 +61,7 @@
     MVTecAD2,
     MVTecLOCO,
     RealIAD,
+    Tabular,
     Visa,
 )
 from .datamodules.video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat
@@ -75,6 +76,7 @@
     KolektorDataset,
     MVTecADDataset,
     MVTecLOCODataset,
+    TabularDataset,
     VADDataset,
     VisaDataset,
 )
@@ -177,6 +179,7 @@ def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule
     "MVTecAD2",
     "MVTecLOCO",
     "RealIAD",
+    "Tabular",
     "VAD",
     "Visa",
     # Video Data Modules
@@ -192,6 +195,7 @@ def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule
     "KolektorDataset",
     "MVTecADDataset",
     "MVTecLOCODataset",
+    "TabularDataset",
     "VADDataset",
     "VisaDataset",
     "AvenueDataset",

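With these exports in place, both the new datamodule and its underlying dataset become importable directly from anomalib.data. A minimal sketch of constructing the dataset by hand, mirroring the TabularDataset call made in the _setup method of the new tabular.py further down; the sample paths are illustrative, and the dict-style samples argument is assumed to be accepted since the datamodule forwards its own samples argument unchanged:

    from anomalib.data import TabularDataset
    from anomalib.data.utils import Split

    # Illustrative samples table; paths are placeholders relative to `root`.
    samples = {
        "image_path": ["train/good/000.png", "test/broken_large/000.png"],
        "label_index": [0, 1],
        "mask_path": ["", "ground_truth/broken_large/000_mask.png"],
        "split": ["train", "test"],
    }
    train_dataset = TabularDataset(
        name="bottle",
        samples=samples,
        split=Split.TRAIN,
        root="datasets/MVTecAD/bottle",
    )
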
src/anomalib/data/datamodules/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from .depth import Folder3D, MVTec3D
-from .image import VAD, BTech, Datumaro, Folder, Kolektor, MVTec, MVTecAD, Visa
+from .image import VAD, BTech, Datumaro, Folder, Kolektor, MVTec, MVTecAD, Tabular, Visa
 from .video import Avenue, ShanghaiTech, UCSDped

 __all__ = [
@@ -16,6 +16,7 @@
     "Kolektor",
     "MVTec", # Include MVTec for backward compatibility
     "MVTecAD",
+    "Tabular",
     "VAD",
     "Visa",
     "Avenue",

src/anomalib/data/datamodules/image/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -10,6 +10,7 @@
 - ``MVTecAD``: MVTec Anomaly Detection Dataset
 - ``MVTecAD2``: MVTec Anomaly Detection Dataset 2
 - ``MVTecLOCO``: MVTec LOCO Dataset with logical and structural anomalies
+- ``Tabular``: Custom tabular dataset with image paths and labels
 - ``VAD``: Valeo Anomaly Detection Dataset
 - ``Visa``: Visual Anomaly Dataset

@@ -36,6 +37,7 @@
 from .mvtecad import MVTec, MVTecAD
 from .mvtecad2 import MVTecAD2
 from .realiad import RealIAD
+from .tabular import Tabular
 from .vad import VAD
 from .visa import Visa

@@ -54,6 +56,7 @@ class ImageDataFormat(str, Enum):
     - ``MVTEC_AD_2``: MVTec AD 2 Dataset
     - ``MVTEC_3D``: MVTec 3D AD Dataset
     - ``MVTEC_LOCO``: MVTec LOCO Dataset
+    - ``TABULAR``: Custom Tabular Dataset
     - ``REALIAD``: Real-IAD Dataset
     - ``VAD``: Valeo Anomaly Detection Dataset
     - ``VISA``: Visual Anomaly Dataset
@@ -69,6 +72,7 @@ class ImageDataFormat(str, Enum):
     MVTEC_3D = "mvtec_3d"
     MVTEC_LOCO = "mvtec_loco"
     REAL_IAD = "realiad"
+    TABULAR = "tabular"
     VAD = "vad"
     VISA = "visa"

@@ -83,6 +87,7 @@ class ImageDataFormat(str, Enum):
     "MVTecAD2",
     "MVTecLOCO",
     "RealIAD",
+    "Tabular",
     "VAD",
     "Visa",
 ]
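
The new enum member makes the tabular format discoverable like any other image data format. A tiny sanity check, assuming ImageDataFormat is imported from the package shown above (it is defined in this __init__.py):

    from anomalib.data.datamodules.image import ImageDataFormat

    # Value-based Enum lookup resolves the member added above.
    assert ImageDataFormat("tabular") is ImageDataFormat.TABULAR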

src/anomalib/data/datamodules/image/tabular.py

Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
+"""Custom Tabular Data Module.
+
+This script creates a custom Lightning DataModule from a table or tabular file
+containing image paths and labels.
+
+Example:
+    Create a Tabular datamodule::
+
+        >>> from anomalib.data import Tabular
+        >>> samples = {
+        ...     "image_path": ["images/image1.png", "images/image2.png", "images/image3.png", ... ],
+        ...     "label_index": [LabelName.NORMAL, LabelName.NORMAL, LabelName.ABNORMAL, ... ],
+        ...     "split": [Split.TRAIN, Split.TRAIN, Split.TEST, ... ],
+        ... }
+        >>> datamodule = Tabular(
+        ...     name="custom",
+        ...     samples=samples,
+        ...     root="./datasets/custom",
+        ... )
+"""
+
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from pathlib import Path
+from typing import IO
+
+import pandas as pd
+from torchvision.transforms.v2 import Transform
+
+from anomalib.data.datamodules.base.image import AnomalibDataModule
+from anomalib.data.datasets.image.tabular import TabularDataset
+from anomalib.data.utils import Split, TestSplitMode, ValSplitMode
+
+
+class Tabular(AnomalibDataModule):
+    """Tabular DataModule.
+
+    Args:
+        name (str): Name of the dataset. Used for logging/saving.
+        samples (dict | list | DataFrame): Pandas ``DataFrame`` or compatible ``list``
+            or ``dict`` containing the dataset information.
+        root (str | Path | None): Root folder containing normal and abnormal
+            directories. Defaults to ``None``.
+        normal_split_ratio (float): Ratio to split normal training images for
+            test set when no normal test images exist.
+            Defaults to ``0.2``.
+        train_batch_size (int): Training batch size.
+            Defaults to ``32``.
+        eval_batch_size (int): Validation/test batch size.
+            Defaults to ``32``.
+        num_workers (int): Number of workers for data loading.
+            Defaults to ``8``.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
+            Defaults to ``None``.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
+            Defaults to ``None``.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
+            Defaults to ``None``.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided.
+        test_split_mode (TestSplitMode): Method to obtain test subset.
+            Defaults to ``TestSplitMode.FROM_DIR``.
+        test_split_ratio (float): Fraction of train images for testing.
+            Defaults to ``0.2``.
+        val_split_mode (ValSplitMode): Method to obtain validation subset.
+            Defaults to ``ValSplitMode.FROM_TEST``.
+        val_split_ratio (float): Fraction of images for validation.
+            Defaults to ``0.5``.
+        seed (int | None): Random seed for splitting.
+            Defaults to ``None``.
+
+    Example:
+        Create and setup a tabular datamodule::
+
+            >>> from anomalib.data import Tabular
+            >>> samples = {
+            ...     "image_path": ["images/image1.png", "images/image2.png", "images/image3.png", ... ],
+            ...     "label_index": [LabelName.NORMAL, LabelName.NORMAL, LabelName.ABNORMAL, ... ],
+            ...     "split": [Split.TRAIN, Split.TRAIN, Split.TEST, ... ],
+            ... }
+            >>> datamodule = Tabular(
+            ...     name="custom",
+            ...     samples=samples,
+            ...     root="./datasets/custom",
+            ... )
+            >>> datamodule.setup()
+
+        Get a batch from train dataloader::
+
+            >>> batch = next(iter(datamodule.train_dataloader()))
+            >>> batch.keys()
+            dict_keys(['image', 'label', 'mask', 'image_path', 'mask_path'])
+
+        Get a batch from test dataloader::
+
+            >>> batch = next(iter(datamodule.test_dataloader()))
+            >>> batch.keys()
+            dict_keys(['image', 'label', 'mask', 'image_path', 'mask_path'])
+    """
+
+    def __init__(
+        self,
+        name: str,
+        samples: dict | list | pd.DataFrame,
+        root: str | Path | None = None,
+        normal_split_ratio: float = 0.2,
+        train_batch_size: int = 32,
+        eval_batch_size: int = 32,
+        num_workers: int = 8,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
+        test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
+        test_split_ratio: float = 0.2,
+        val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST,
+        val_split_ratio: float = 0.5,
+        seed: int | None = None,
+    ) -> None:
+        self._name = name
+        self.root = root
+        self._unprocessed_samples = samples
+        test_split_mode = TestSplitMode(test_split_mode)
+        val_split_mode = ValSplitMode(val_split_mode)
+        super().__init__(
+            train_batch_size=train_batch_size,
+            eval_batch_size=eval_batch_size,
+            num_workers=num_workers,
+            train_augmentations=train_augmentations,
+            val_augmentations=val_augmentations,
+            test_augmentations=test_augmentations,
+            augmentations=augmentations,
+            test_split_mode=test_split_mode,
+            test_split_ratio=test_split_ratio,
+            val_split_mode=val_split_mode,
+            val_split_ratio=val_split_ratio,
+            seed=seed,
+        )
+
+        self.normal_split_ratio = normal_split_ratio
+
+    def _setup(self, _stage: str | None = None) -> None:
+        self.train_data = TabularDataset(
+            name=self.name,
+            samples=self._unprocessed_samples,
+            split=Split.TRAIN,
+            root=self.root,
+        )
+
+        self.test_data = TabularDataset(
+            name=self.name,
+            samples=self._unprocessed_samples,
+            split=Split.TEST,
+            root=self.root,
+        )
+
+    @property
+    def name(self) -> str:
+        """Get name of the datamodule.
+
+        Returns:
+            Name of the datamodule.
+        """
+        return self._name
+
+    @classmethod
+    def from_file(
+        cls: type["Tabular"],
+        name: str,
+        file_path: str | Path | IO[str] | IO[bytes],
+        file_format: str = "csv",
+        pd_kwargs: dict | None = None,
+        **kwargs,
+    ) -> "Tabular":
+        """Create Tabular Datamodule from file.
+
+        Args:
+            name (str): Name of the dataset. This is used to name the datamodule,
+                especially when logging/saving.
+            file_path (str | Path | file-like): Path or file-like object to tabular
+                file containing the dataset information.
+            file_format (str): File format supported by a pd.read_* method, such
+                as ``csv``, ``parquet`` or ``json``.
+                Defaults to ``csv``.
+            pd_kwargs (dict | None): Keyword argument dictionary for the pd.read_* method.
+                Defaults to ``None``.
+            kwargs (dict): Additional keyword arguments for the Tabular Datamodule class.
+
+        Returns:
+            Tabular: Tabular Datamodule
+        """
+        pd_kwargs = pd_kwargs or {}
+        samples = getattr(pd, f"read_{file_format}")(file_path, **pd_kwargs)
+        return cls(name, samples, **kwargs)
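
To tie the pieces together, here is a minimal sketch of the from_file constructor defined above, driven by a small CSV written on the fly. The column names mirror the examples/configs/data/tabular.yaml table; the samples.csv file name and the image paths are illustrative, and actually iterating the dataloaders requires those images to exist under root:

    import pandas as pd
    from anomalib.data import Tabular

    # Write a tiny samples table to disk (paths are relative to `root`).
    samples = pd.DataFrame({
        "image_path": ["train/good/000.png", "test/broken_large/000.png"],
        "label_index": [0, 1],
        "mask_path": ["", "ground_truth/broken_large/000_mask.png"],
        "split": ["train", "test"],
    })
    samples.to_csv("samples.csv", index=False)

    # from_file dispatches to pd.read_csv here and forwards the resulting
    # DataFrame plus the remaining kwargs to the Tabular constructor.
    datamodule = Tabular.from_file(
        name="bottle",
        file_path="samples.csv",
        file_format="csv",
        pd_kwargs={"keep_default_na": False},  # keep empty mask_path as "" rather than NaN
        root="datasets/MVTecAD/bottle",
    )
    datamodule.setup()
    batch = next(iter(datamodule.test_dataloader()))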
