Skip to content

Commit 1bc7aa9

Browse files
authored
Refactor profile (#251)
* refactor(Profiles): change searchspace_type to trainer_preset * fix(scripts): fix the profiles of all scripts following refactor refactor changes searchspace_type to trainer_preset
1 parent 34cf77a commit 1bc7aa9

20 files changed

+108
-94
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ We define modular, differentiable NAS components within our library. Below is a
4949
```python
5050
from confopt.profile import DARTSProfile
5151
from confopt.train import Experiment
52-
from confopt.enums import SearchSpaceType, DatasetType
52+
from confopt.enums import TrainerPresetType, SearchSpaceType, DatasetType
5353

5454
profile = DARTSProfile(
55-
searchspace_type=SearchSpaceType.DARTS,
55+
trainer_preset=TrainerPresetType.DARTS,
5656
epochs=3
5757
)
5858

docs/source/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ Below is a snippet that demonstrates how we run a vanilla-DARTS experiment.
5353
```python
5454
from confopt.profile import DARTSProfile
5555
from confopt.train import Experiment
56-
from confopt.enums import SearchSpaceType, DatasetType
56+
from confopt.enums import TrainerPresetType, SearchSpaceType, DatasetType
5757

5858
profile = DARTSProfile(
59-
searchspace_type=SearchSpaceType.DARTS,
59+
trainer_preset=TrainerPresetType.DARTS,
6060
epochs=3
6161
)
6262

examples/demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from confopt.profile import DARTSProfile, DRNASProfile, GDASProfile, ReinMaxProfile
44
from confopt.train import Experiment
5-
from confopt.enums import DatasetType, SearchSpaceType
5+
from confopt.enums import DatasetType, TrainerPresetType, SearchSpaceType
66

77
if __name__ == "__main__":
88
searchspace = SearchSpaceType.DARTS
@@ -19,7 +19,7 @@
1919
}
2020

2121
profile = DRNASProfile(
22-
searchspace_type=searchspace,
22+
trainer_preset=TrainerPresetType.DARTS,
2323
epochs=10,
2424
oles=True,
2525
calc_gm_score=True,

examples/demo_advanced.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
from confopt.profile import GDASProfile
44
from confopt.train import Experiment
5-
from confopt.enums import DatasetType, SearchSpaceType
5+
from confopt.enums import DatasetType, TrainerPresetType, SearchSpaceType
66

77
if __name__ == "__main__":
88
search_space = SearchSpaceType.DARTS
99

1010
profile = GDASProfile(
11-
searchspace_type=search_space,
11+
trainer_preset=TrainerPresetType.DARTS,
1212
epochs=10,
1313
perturbation="random",
1414
entangle_op_weights=True,

examples/demo_light.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from confopt.profile import GDASProfile
44
from confopt.train import Experiment
5-
from confopt.enums import SearchSpaceType, DatasetType
5+
from confopt.enums import SearchSpaceType, TrainerPresetType, DatasetType
66

77
if __name__ == "__main__":
88
profile = GDASProfile(
9-
searchspace_type=SearchSpaceType.DARTS,
9+
trainer_preset=TrainerPresetType.DARTS,
1010
epochs=3,
1111
)
1212

examples/demo_taskonomy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
if __name__ == "__main__":
99
domain = "class_object"
1010
profile = GDASProfile(
11-
searchspace_type="tnb101", epochs=3, searchspace_domain=domain
11+
trainer_preset="tnb101", epochs=3, searchspace_domain=domain
1212
)
1313
profile.configure_searchspace(num_classes=get_num_classes("taskonomy", domain))
1414
experiment = Experiment(

examples/example_synthetic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def get_profile(args: argparse.Namespace) -> Callable: # type: ignore
7878
searchspace = SearchSpaceType.BABYDARTS
7979
dataset = DatasetType.SYNTHETIC
8080

81-
profile = get_profile(args)(searchspace_type=searchspace, epochs=args.search_epochs)
81+
profile = get_profile(args)(trainer_preset=searchspace.value, epochs=args.search_epochs)
8282

8383
profile.configure_synthetic_dataset(
8484
signal_width=args.signal_width,

examples/examples_lora.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
seed = 100
1313

1414
profile = DARTSProfile(
15-
searchspace_type=searchspace,
15+
trainer_preset="darts",
1616
is_partial_connection=True,
1717
perturbation="random",
1818
sampler_sample_frequency="step",

examples/experiment_drnas.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from confopt.profile import DiscreteProfile, DRNASProfile
99
from confopt.train import Experiment
10-
from confopt.enums import DatasetType, SearchSpaceType
10+
from confopt.enums import DatasetType, SearchSpaceType, TrainerPresetType
1111

1212
dataset_size = {
1313
"cifar10": 10,
@@ -65,7 +65,7 @@ def read_args() -> argparse.Namespace:
6565

6666
# Sampler and Perturbator have different sample_frequency
6767
profile = DRNASProfile(
68-
searchspace_type=searchspace,
68+
trainer_preset=TrainerPresetType(args.searchspace),
6969
is_partial_connection=args.searchspace == "darts",
7070
epochs=args.search_epochs,
7171
sampler_sample_frequency="step",
@@ -95,7 +95,7 @@ def read_args() -> argparse.Namespace:
9595
}
9696
profile.configure_trainer(**train_config)
9797
discrete_profile = DiscreteProfile(
98-
searchspace_type=searchspace, epochs=args.eval_epochs, train_portion=0.9
98+
trainer_preset=args.searchspace, epochs=args.eval_epochs, train_portion=0.9
9999
)
100100
discrete_profile.configure_trainer(batch_size=64)
101101

examples/experiment_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from confopt.profile import GDASProfile
44
from confopt.train import Experiment
5-
from confopt.enums import DatasetType, SearchSpaceType
5+
from confopt.enums import DatasetType, SearchSpaceType, TrainerPresetType
66

77
if __name__ == "__main__":
88
searchspace = SearchSpaceType("nb201")
@@ -11,7 +11,7 @@
1111

1212
# Sampler and Perturbator have different sample_frequency
1313
profile = GDASProfile(
14-
searchspace_type=searchspace,
14+
trainer_preset=TrainerPresetType("nb201"),
1515
is_partial_connection=True,
1616
perturbation="random",
1717
sampler_sample_frequency="step",

0 commit comments

Comments
 (0)