Skip to content

Commit 9c1523b

Browse files
committed
Add label encoding in MiniImagenet
1 parent 19d1be7 commit 9c1523b

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

datasets.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,25 @@ def __init__(self, root, train=False, valid=False, test=False,
5858
next(reader) # Skip the header
5959
for line in reader:
6060
self._data.append(tuple(line))
61+
self._fit_label_encoding()
6162

6263
def __getitem__(self, index):
6364
filename, label = self._data[index]
6465
image = pil_loader(os.path.join(self.image_folder, filename))
66+
label = self._label_encoder[label]
6567
if self.transform is not None:
6668
image = self.transform(image)
6769
if self.target_transform is not None:
6870
label = self.target_transform(label)
6971

7072
return image, label
7173

74+
def _fit_label_encoding(self):
75+
_, labels = zip(*self._data)
76+
unique_labels = set(labels)
77+
self._label_encoder = dict((label, idx)
78+
for (idx, label) in enumerate(unique_labels))
79+
7280
def _check_exists(self):
7381
return (os.path.exists(self.image_folder)
7482
and os.path.exists(self.split_filename))

prior_miniimagenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def main(args):
102102
best_loss = -1.
103103
for epoch in range(args.num_epochs):
104104
train(train_loader, model, prior, optimizer, args, writer)
105-
loss, _ = test(valid_loader, model, prior, args, writer)
105+
loss = test(valid_loader, model, prior, args, writer)
106106

107107
# reconstruction = generate_samples(fixed_images, model, args)
108108
# grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)

0 commit comments

Comments
 (0)