Skip to content

Commit ea7e6bf

Browse files
committed
fix bugs with cifar100
1 parent 489f3b6 commit ea7e6bf

File tree

6 files changed

+25
-11
lines changed

6 files changed

+25
-11
lines changed

code/dependencies/darts_classification_module.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@ def __init__(
99
learning_rate: float = 0.001,
1010
weight_decay: float = 0.,
1111
auxiliary_loss_weight: float = 0.4,
12-
max_epochs: int = 600
12+
max_epochs: int = 600,
13+
num_classes: int = 10
1314
):
1415
self.auxiliary_loss_weight = auxiliary_loss_weight
1516
# Training length will be used in LR scheduler
1617
self.max_epochs = max_epochs
17-
super().__init__(learning_rate=learning_rate, weight_decay=weight_decay, export_onnx=False, num_classes=10)
18+
super().__init__(learning_rate=learning_rate, weight_decay=weight_decay, export_onnx=False, num_classes=num_classes)
1819

1920
def configure_optimizers(self):
2021
"""Customized optimizer with momentum, as well as a scheduler."""

code/output/ensemble_results.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
Ensemble Top-1 Accuracy: 10.01%
2-
Ensemble ECE: 0.0114
1+
Ensemble Top-1 Accuracy: 0.96%
2+
Ensemble ECE: 0.0015
33
Number of models: 2
4-
Model 1 Accuracy: 9.90%
5-
Model 2 Accuracy: 10.02%
4+
Model 1 Accuracy: 1.00%
5+
Model 2 Accuracy: 0.90%

code/surrogate_hp.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"dataset_path": "third_dataset/",
55
"device": "cpu",
66
"developer_mode": true,
7-
"n_models": 1300,
7+
"n_models": 300,
88

99
"upper_margin": 0.75,
1010
"lower_margin": 0.25,
@@ -36,14 +36,14 @@
3636
"n_models_in_pool": 128,
3737
"n_models_to_generate": 4096,
3838
"batch_size_inference": 4096,
39-
"min_accuracy_for_pool": 0.85,
39+
"min_accuracy_for_pool": 0.01,
4040
"plot_tsne": false,
4141
"best_models_save_path": "best_models/",
4242

4343
"n_epochs_final": 1,
4444
"lr_final": 0.025,
4545
"batch_size_final": 96,
46-
"dataset_name": "CIFAR10",
46+
"dataset_name": "FashionMNIST",
4747
"final_dataset_path": "final_dataset/",
4848
"output_path": "output/",
4949
"width": 4,
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.

code/train_models.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,13 @@ def get_data_loaders(self, batch_size: int = None) -> Tuple[DataLoader, DataLoad
101101

102102
if self.config.dataset_name.lower() == "cifar10":
103103
dataset_cls = CIFAR10
104+
self.num_classes = 10
104105
elif self.config.dataset_name.lower() == "cifar100":
105106
dataset_cls = CIFAR100
107+
self.num_classes = 100
106108
elif self.config.dataset_name.lower() == "fashionmnist":
107109
dataset_cls = FashionMNIST
110+
self.num_classes = 10
108111

109112
train_data = nni.trace(dataset_cls)(
110113
root=self.config.final_dataset_path,
@@ -158,11 +161,20 @@ def train_model(
158161
Train a single model defined by architecture and return the trained model.
159162
"""
160163
try:
164+
if self.config.dataset_name.lower() == "cifar10":
165+
dataset = "cifar"
166+
elif self.config.dataset_name.lower() == "cifar100":
167+
dataset = "cifar100"
168+
elif self.config.dataset_name.lower() == "fashionmnist":
169+
dataset = "cifar"
170+
else:
171+
raise ValueError(f"Unknown dataset: {self.config.dataset_name}")
172+
161173
with model_context(architecture):
162174
model = DartsSpace(
163175
width=self.config.width,
164176
num_cells=self.config.num_cells,
165-
dataset="cifar",
177+
dataset=dataset,
166178
)
167179

168180
model.to(self.device)
@@ -176,6 +188,7 @@ def train_model(
176188
weight_decay=3e-4,
177189
auxiliary_loss_weight=0.4,
178190
max_epochs=self.config.n_epochs_final,
191+
num_classes=self.num_classes,
179192
),
180193
trainer=Trainer(
181194
gradient_clip_val=5.0,
@@ -289,7 +302,7 @@ def evaluate_ensemble(self, test_loader):
289302
avg_output /= len(valid_models)
290303
ensemble_probs = avg_output.softmax(dim=1)
291304
confidences, preds_ens = ensemble_probs.max(1)
292-
correct_ens_batch = (preds_ens == labels)
305+
correct_ens_batch = preds_ens == labels
293306
correct_ensemble += correct_ens_batch.sum().item()
294307

295308
confidences = confidences.cpu().float()

0 commit comments

Comments
 (0)