Skip to content

Commit f891253

Browse files
committed
Add CIFAR10 dataset and input/output shape getters.
1 parent fe59889 commit f891253

File tree

6 files changed

+176
-2
lines changed

6 files changed

+176
-2
lines changed

src/fflib/utils/data/cifar10.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import torch
2+
3+
from torch.nn.functional import one_hot
4+
from torch.utils.data import DataLoader, random_split
5+
from torchvision.datasets import CIFAR10 # type: ignore
6+
from torchvision.transforms import Compose, ToTensor, Lambda # type: ignore
7+
8+
from fflib.utils.data import FFDataProcessor
9+
from fflib.interfaces.iff import IFF
10+
11+
from enum import Enum
12+
from typing import Tuple, Dict, Callable, Any
13+
14+
15+
class NegativeGenerator(Enum):
16+
INVERSE = 1
17+
RANDOM = 2
18+
HIGHEST_INCORRECT = 3
19+
20+
21+
class FFCIFAR10(FFDataProcessor):
22+
def __init__(
23+
self,
24+
batch_size: int,
25+
validation_split: float | None,
26+
download: bool = True,
27+
path: str = "./data",
28+
image_transform: Callable[..., Any] = Compose([ToTensor(), Lambda(torch.flatten)]),
29+
train_kwargs: Dict[str, Any] = {},
30+
test_kwargs: Dict[str, Any] = {},
31+
negative_generator: NegativeGenerator = NegativeGenerator.INVERSE,
32+
use: float = 1.0,
33+
):
34+
35+
assert isinstance(batch_size, int)
36+
assert batch_size > 0
37+
self.batch_size = batch_size
38+
if "batch_size" not in train_kwargs:
39+
train_kwargs["batch_size"] = self.batch_size
40+
if "batch_size" not in test_kwargs:
41+
test_kwargs["batch_size"] = self.batch_size
42+
43+
train_kwargs["shuffle"] = True
44+
45+
assert use >= 0.0 and use <= 1.0
46+
47+
self.validation_split = validation_split
48+
self.download = download
49+
self.path = path
50+
self.image_transform = image_transform
51+
self.train_kwargs = train_kwargs
52+
self.test_kwargs = test_kwargs
53+
self.negative_generator = negative_generator
54+
self.use = use
55+
56+
self.train_dataset = CIFAR10(
57+
self.path, train=True, download=self.download, transform=self.image_transform
58+
)
59+
self.test_dataset = CIFAR10(
60+
self.path, train=False, download=self.download, transform=self.image_transform
61+
)
62+
self.test_loader = DataLoader(self.test_dataset, **self.test_kwargs)
63+
64+
dataset_size = len(self.train_dataset)
65+
used_dataset_size = int(dataset_size * self.use)
66+
not_used_dataset_size = dataset_size - used_dataset_size
67+
68+
# In case a validation split is given
69+
if self.validation_split:
70+
# Determine the sizes of training and validation sets
71+
val_size = int(self.validation_split * used_dataset_size)
72+
train_size = used_dataset_size - val_size
73+
74+
# Split dataset into train and validation sets
75+
train_dataset, val_dataset, _ = random_split(
76+
self.train_dataset, [train_size, val_size, not_used_dataset_size]
77+
)
78+
79+
# Create data loaders for train and validation
80+
self.train_loader = DataLoader(train_dataset, **self.train_kwargs)
81+
self.val_loader = DataLoader(val_dataset, **self.test_kwargs)
82+
83+
assert len(self.train_loader) + len(self.val_loader) <= used_dataset_size
84+
return
85+
86+
train_dataset, _ = random_split(
87+
self.train_dataset, [used_dataset_size, not_used_dataset_size]
88+
)
89+
self.train_loader = DataLoader(train_dataset, **self.train_kwargs)
90+
91+
def get_input_shape(self) -> torch.Size:
92+
return torch.Size((32 * 32 * 3,))
93+
94+
def get_output_shape(self) -> torch.Size:
95+
return torch.Size((10,))
96+
97+
def get_train_loader(self) -> DataLoader[Any]:
98+
return self.train_loader
99+
100+
def get_val_loader(self) -> DataLoader[Any]:
101+
return self.val_loader
102+
103+
def get_test_loader(self) -> DataLoader[Any]:
104+
return self.test_loader
105+
106+
def get_all_loaders(self) -> Dict[str, DataLoader[Any]]:
107+
return {
108+
"train": self.get_train_loader(),
109+
"val": self.get_val_loader(),
110+
"test": self.get_test_loader(),
111+
}
112+
113+
def encode_output(self, y: torch.Tensor) -> torch.Tensor:
114+
return one_hot(y, num_classes=10).float()
115+
116+
def combine_to_input(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
117+
return torch.cat((x, y), 1)
118+
119+
def generate_negative(
120+
self,
121+
x: torch.Tensor,
122+
y: torch.Tensor,
123+
net: IFF,
124+
) -> Tuple[torch.Tensor, torch.Tensor]:
125+
126+
if self.negative_generator == NegativeGenerator.HIGHEST_INCORRECT:
127+
raise NotImplementedError()
128+
129+
if self.negative_generator == NegativeGenerator.INVERSE:
130+
y_hot = 1 - one_hot(y, num_classes=10).float()
131+
return x, y_hot
132+
133+
rnd = torch.rand((x.shape[0], 10), device=x.device)
134+
rnd[torch.arange(x.shape[0]), y] = 0
135+
y_new = rnd.argmax(1)
136+
y_hot = one_hot(y_new, num_classes=10).float()
137+
return x, y_hot

src/fflib/utils/data/dataprocessor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ def check_splits(splits: Tuple[float, float, float] | Tuple[float, float]) -> No
1515
assert all(0 <= s <= 1 for s in splits)
1616
assert len(splits) in [2, 3]
1717

18+
@abstractmethod
19+
def get_input_shape(self) -> torch.Size:
20+
pass
21+
22+
@abstractmethod
23+
def get_output_shape(self) -> torch.Size:
24+
pass
25+
26+
def numel(self) -> int:
27+
return self.get_input_shape().numel() + self.get_output_shape().numel()
28+
1829
@abstractmethod
1930
def get_train_loader(self) -> DataLoader[Any]:
2031
pass

src/fflib/utils/data/datasets.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from fflib.utils.data.mnist import FFMNIST
22
from fflib.utils.data.fashion_mnist import FFFashionMNIST
33
from fflib.utils.data.xor import FFXOR
4+
from fflib.utils.data.cifar10 import FFCIFAR10
45

5-
from typing import Dict, Any
6+
from typing import Any
67

78

89
def CreateDatasetFromName(
@@ -11,7 +12,7 @@ def CreateDatasetFromName(
1112
validation_split: float,
1213
use: float = 1.0,
1314
**kwargs: Any,
14-
) -> FFMNIST | FFFashionMNIST | FFXOR | None:
15+
) -> FFMNIST | FFFashionMNIST | FFXOR | FFCIFAR10 | None:
1516
"""Create a Dataset object via a name of the dataset.
1617
This is used to dynamically set the dataset from CLI arguments.
1718
@@ -41,6 +42,13 @@ def CreateDatasetFromName(
4142
use=use,
4243
**kwargs,
4344
)
45+
elif name == "cifar10":
46+
return FFCIFAR10(
47+
batch_size,
48+
validation_split,
49+
use=use,
50+
**kwargs,
51+
)
4452
elif name == "xor":
4553
return FFXOR(
4654
batch_size,

src/fflib/utils/data/fashion_mnist.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ def __init__(
8888
)
8989
self.train_loader = DataLoader(train_dataset, **self.train_kwargs)
9090

91+
def get_input_shape(self) -> torch.Size:
92+
return torch.Size((28 * 28,))
93+
94+
def get_output_shape(self) -> torch.Size:
95+
return torch.Size((10,))
96+
9197
def get_train_loader(self) -> DataLoader[Any]:
9298
return self.train_loader
9399

src/fflib/utils/data/mnist.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ def __init__(
8888
)
8989
self.train_loader = DataLoader(train_dataset, **self.train_kwargs)
9090

91+
def get_input_shape(self) -> torch.Size:
92+
return torch.Size((28 * 28,))
93+
94+
def get_output_shape(self) -> torch.Size:
95+
return torch.Size((10,))
96+
9197
def get_train_loader(self) -> DataLoader[Any]:
9298
return self.train_loader
9399

src/fflib/utils/data/xor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ def __init__(
4848
self.test_loader = DataLoader(XORDataset(size), **self.test_kwargs)
4949
self.val_loader = DataLoader(XORDataset(size), **self.test_kwargs)
5050

51+
def get_input_shape(self) -> torch.Size:
52+
return torch.Size((2,))
53+
54+
def get_output_shape(self) -> torch.Size:
55+
return torch.Size((2,))
56+
5157
def get_train_loader(self) -> DataLoader[Any]:
5258
return self.train_loader
5359

0 commit comments

Comments
 (0)