Skip to content

Commit 34cf77a

Browse files
authored
Composite Sampler (#250)
* feat(CompositeSampler): add composite sampler to confopt * feat(CompositeProfile): add composite profile integrate composite sampler to confopt * test(test_sampler): add tests for composite sampler * refactor(CompositeProfile): update assertion message
1 parent 9945bf5 commit 34cf77a

File tree

8 files changed

+282
-13
lines changed

8 files changed

+282
-13
lines changed

src/confopt/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class SamplerType(Enum):
2121
GDAS = "gdas"
2222
SNAS = "snas"
2323
REINMAX = "reinmax"
24+
COMPOSITE = "composite"
2425

2526
def __str__(self) -> str:
2627
return self.value

src/confopt/oneshot/archsampler/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .base_sampler import BaseSampler
2+
from .composite_sampler import CompositeSampler
23
from .darts.sampler import DARTSSampler
34
from .drnas.sampler import DRNASSampler
45
from .gdas.sampler import GDASSampler
@@ -12,4 +13,5 @@
1213
"GDASSampler",
1314
"SNASSampler",
1415
"ReinMaxSampler",
16+
"CompositeSampler",
1517
]
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from __future__ import annotations
2+
3+
import torch
4+
5+
from confopt.oneshot.archsampler import BaseSampler
6+
from confopt.oneshot.base import OneShotComponent
7+
8+
9+
class CompositeSampler(OneShotComponent):
10+
def __init__(
11+
self,
12+
arch_samplers: list[BaseSampler],
13+
arch_parameters: list[torch.Tensor],
14+
) -> None:
15+
super().__init__()
16+
self.arch_samplers = arch_samplers
17+
self.arch_parameters = arch_parameters
18+
19+
# get sample frequency from the samplers
20+
self.sample_frequency = arch_samplers[0].sample_frequency
21+
for sampler in arch_samplers:
22+
assert (
23+
self.sample_frequency == sampler.sample_frequency
24+
), "All samplers must have the same sample frequency"
25+
26+
def sample(self, alpha: torch.Tensor) -> torch.Tensor:
27+
sampled_alphas = alpha
28+
for sampler in self.arch_samplers:
29+
sampled_alphas = sampler.sample(sampled_alphas)
30+
31+
return sampled_alphas
32+
33+
def new_epoch(self) -> None:
34+
super().new_epoch()
35+
for sampler in self.arch_samplers:
36+
sampler.new_epoch()
37+
38+
def new_step(self) -> None:
39+
super().new_step()
40+
for sampler in self.arch_samplers:
41+
sampler.new_step()

src/confopt/profile/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .base import BaseProfile
22
from .profiles import (
3+
CompositeProfile,
34
DARTSProfile,
45
DiscreteProfile,
56
DRNASProfile,
@@ -16,4 +17,5 @@
1617
"SNASProfile",
1718
"DiscreteProfile",
1819
"ReinMaxProfile",
20+
"CompositeProfile",
1921
]

src/confopt/profile/profiles.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from abc import ABC
44
from typing import Any
55

6+
from typing_extensions import override
7+
68
from confopt.enums import SamplerType, SearchSpaceType
79
from confopt.searchspace.darts.core.genotypes import DARTSGenotype
810
from confopt.utils import get_num_classes
@@ -402,6 +404,147 @@ def configure_sampler(self, **kwargs) -> None: # type: ignore
402404
super().configure_sampler(**kwargs)
403405

404406

407+
class CompositeProfile(BaseProfile, ABC):
408+
SAMPLER_TYPE = SamplerType.COMPOSITE
409+
410+
def __init__(
411+
self,
412+
searchspace_type: str | SearchSpaceType,
413+
samplers: list[str | SamplerType],
414+
epochs: int,
415+
# GDAS configs
416+
tau_min: float = 0.1,
417+
tau_max: float = 10,
418+
# SNAS configs
419+
temp_init: float = 1.0,
420+
temp_min: float = 0.03,
421+
temp_annealing: bool = True,
422+
**kwargs: Any,
423+
) -> None:
424+
self.samplers = []
425+
for sampler in samplers:
426+
if isinstance(sampler, str):
427+
sampler = SamplerType(sampler)
428+
self.samplers.append(sampler)
429+
430+
self.tau_min = tau_min
431+
self.tau_max = tau_max
432+
433+
self.temp_init = temp_init
434+
self.temp_min = temp_min
435+
self.temp_annealing = temp_annealing
436+
437+
super().__init__( # type: ignore
438+
self.SAMPLER_TYPE,
439+
searchspace_type,
440+
epochs,
441+
**kwargs,
442+
)
443+
444+
def _initialize_sampler_config(self) -> None:
445+
"""Initializes the sampler configuration for Composite samplers.
446+
447+
The sampler configuration includes the sample frequency and the architecture
448+
combine function.
449+
450+
Args:
451+
None
452+
453+
Returns:
454+
None
455+
"""
456+
self.sampler_config: dict[int, dict] = {} # type: ignore
457+
for i, sampler in enumerate(self.samplers):
458+
if sampler in [SamplerType.DARTS, SamplerType.DRNAS]:
459+
config = {
460+
"sampler_type": sampler,
461+
"sample_frequency": self.sampler_sample_frequency,
462+
"arch_combine_fn": self.sampler_arch_combine_fn,
463+
}
464+
elif sampler in [SamplerType.GDAS, SamplerType.REINMAX]:
465+
config = {
466+
"sampler_type": sampler,
467+
"sample_frequency": self.sampler_sample_frequency,
468+
"arch_combine_fn": self.sampler_arch_combine_fn,
469+
"tau_min": self.tau_min,
470+
"tau_max": self.tau_max,
471+
}
472+
elif sampler == SamplerType.SNAS:
473+
config = {
474+
"sampler_type": sampler,
475+
"sample_frequency": self.sampler_sample_frequency,
476+
"arch_combine_fn": self.sampler_arch_combine_fn,
477+
"temp_init": self.temp_init,
478+
"temp_min": self.temp_min,
479+
"temp_annealing": self.temp_annealing,
480+
"total_epochs": self.epochs,
481+
}
482+
else:
483+
raise AttributeError(f"Illegal sampler type {sampler} provided!")
484+
485+
self.sampler_config[i] = config
486+
487+
@override
488+
def configure_sampler( # type: ignore[override]
489+
self, sampler_config_map: dict[int, dict]
490+
) -> None:
491+
"""Configures the sampler settings based on the provided configurations.
492+
493+
Args:
494+
sampler_config_map (dict[int, dict]): A dictionary where each key is an \
495+
integer representing the order of the sampler (zero-indexed), and \
496+
each value is a dictionary containing the configuration parameters \
497+
for that sampler.
498+
499+
The inner configuration dictionary can contain different sets of keys
500+
depending on the type of sampler being configured. The available keys
501+
include:
502+
503+
Generic Configurations:
504+
- sample_frequency (str): The rate at which samples should be taken.
505+
- arch_combine_fn (str): Function to combine architectures.
506+
For FairDARTS, set this to 'sigmoid'. Default is 'default'.
507+
508+
GDAS-specific Configurations:
509+
- tau_min (float): Minimum temperature for sampling.
510+
- tau_max (float): Maximum temperature for sampling.
511+
512+
SNAS-specific Configurations:
513+
- temp_init (float): Initial temperature for sampling.
514+
- temp_min (float): Minimum temperature for sampling.
515+
- temp_annealing (bool): Whether to apply temperature annealing.
516+
- total_epochs (int): Total number of training epochs.
517+
518+
The specific keys required in the dictionary depend on the type of
519+
sampler being used.
520+
Please make sure that the sample frequency of all the configurations are same.
521+
Each configuration is validated, and an error is raised if unknown keys are
522+
provided.
523+
524+
Raises:
525+
ValueError: If an unrecognized configuration key is detected.
526+
527+
Returns:
528+
None
529+
"""
530+
assert self.sampler_config is not None
531+
532+
for idx in sampler_config_map:
533+
for config_key in sampler_config_map[idx]:
534+
assert idx in self.sampler_config
535+
exists = False
536+
sampler_type = self.sampler_config[idx]["sampler_type"]
537+
if config_key in self.sampler_config[idx]:
538+
exists = True
539+
self.sampler_config[idx][config_key] = sampler_config_map[idx][
540+
config_key
541+
]
542+
assert exists, (
543+
f"{config_key} is not a valid configuration for {sampler_type}",
544+
"sampler inside composite sampler",
545+
)
546+
547+
405548
class DiscreteProfile:
406549
def __init__(
407550
self,

src/confopt/train/experiment.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
)
4141
from confopt.oneshot.archsampler import (
4242
BaseSampler,
43+
CompositeSampler,
4344
DARTSSampler,
4445
DRNASSampler,
4546
GDASSampler,
@@ -376,17 +377,36 @@ def _set_sampler(
376377
config: dict,
377378
) -> None:
378379
arch_params = self.search_space.arch_parameters
379-
self.sampler: BaseSampler | None = None
380-
if sampler == SamplerType.DARTS:
381-
self.sampler = DARTSSampler(**config, arch_parameters=arch_params)
382-
elif sampler == SamplerType.DRNAS:
383-
self.sampler = DRNASSampler(**config, arch_parameters=arch_params)
384-
elif sampler == SamplerType.GDAS:
385-
self.sampler = GDASSampler(**config, arch_parameters=arch_params)
386-
elif sampler == SamplerType.SNAS:
387-
self.sampler = SNASSampler(**config, arch_parameters=arch_params)
388-
elif sampler == SamplerType.REINMAX:
389-
self.sampler = ReinMaxSampler(**config, arch_parameters=arch_params)
380+
self.sampler: BaseSampler | CompositeSampler | None = None
381+
382+
def _get_sampler_class(sampler: SamplerType) -> Callable:
383+
if sampler == SamplerType.DARTS:
384+
return DARTSSampler
385+
if sampler == SamplerType.DRNAS:
386+
return DRNASSampler
387+
if sampler == SamplerType.GDAS:
388+
return GDASSampler
389+
if sampler == SamplerType.SNAS:
390+
return SNASSampler
391+
if sampler == SamplerType.REINMAX:
392+
return ReinMaxSampler
393+
394+
raise ValueError(f"Illegal sampler {sampler} provided")
395+
396+
if sampler == SamplerType.COMPOSITE:
397+
sub_samplers: list[BaseSampler] = []
398+
for _, sampler_config in config.items():
399+
sampler_type = sampler_config["sampler_type"]
400+
del sampler_config["sampler_type"]
401+
sampler_component = _get_sampler_class(sampler_type)(
402+
**sampler_config, arch_parameters=arch_params
403+
)
404+
sub_samplers.append(sampler_component)
405+
self.sampler = CompositeSampler(sub_samplers, arch_parameters=arch_params)
406+
else:
407+
self.sampler = _get_sampler_class(sampler)(
408+
**config, arch_parameters=arch_params
409+
)
390410

391411
def _set_perturbator(
392412
self,

src/confopt/train/search_space_handler.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33
import torch
44

5-
from confopt.oneshot.archsampler import BaseSampler, DARTSSampler, GDASSampler
5+
from confopt.oneshot.archsampler import (
6+
BaseSampler,
7+
CompositeSampler,
8+
DARTSSampler,
9+
GDASSampler,
10+
)
611
from confopt.oneshot.dropout import Dropout
712
from confopt.oneshot.dynamic_exploration import DynamicAttentionExplorer
813
from confopt.oneshot.lora_toggler import LoRAToggler
@@ -24,7 +29,7 @@
2429
class SearchSpaceHandler:
2530
def __init__(
2631
self,
27-
sampler: BaseSampler,
32+
sampler: BaseSampler | CompositeSampler,
2833
edge_normalization: bool = False,
2934
partial_connector: PartialConnector | None = None,
3035
perturbation: BasePerturbator | None = None,

tests/test_sampler.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
DRNASSampler,
99
GDASSampler,
1010
ReinMaxSampler,
11+
CompositeSampler,
1112
)
1213

1314

@@ -111,5 +112,59 @@ def test_post_sample_fn_sigmoid(self) -> None:
111112
_test_arch_combine_fn_default(sampler, alphas)
112113

113114

115+
class TestCompositeSampler(unittest.TestCase):
116+
def test_post_sample_fn_default(self) -> None:
117+
alphas = [torch.randn(14, 8), torch.randn(14, 8)]
118+
inner_samplers = [
119+
DRNASSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"),
120+
GDASSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"),
121+
DARTSSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"),
122+
]
123+
sampler = CompositeSampler(inner_samplers, alphas)
124+
_test_arch_combine_fn_default(sampler, alphas)
125+
126+
def test_post_sample_fn_sigmoid(self) -> None:
127+
alphas = [torch.randn(14, 8), torch.randn(14, 8)]
128+
inner_samplers = [
129+
DRNASSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"),
130+
GDASSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"),
131+
DARTSSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"),
132+
]
133+
sampler = CompositeSampler(inner_samplers, alphas)
134+
_test_arch_combine_fn_sigmoid(sampler, alphas)
135+
136+
def test_post_sample_fn_mixed(self) -> None:
137+
alphas = [torch.randn(14, 8), torch.randn(14, 8)]
138+
139+
# Case 1
140+
inner_samplers = [
141+
DRNASSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"),
142+
GDASSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"),
143+
DARTSSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"),
144+
]
145+
sampler = CompositeSampler(inner_samplers, alphas)
146+
# The last sampler's combine function matters,
147+
# and if its GDAS/Reinmax, it would get ignored
148+
_test_arch_combine_fn_default(sampler, alphas)
149+
150+
# Case 2
151+
inner_samplers = [
152+
DRNASSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"),
153+
DARTSSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"),
154+
GDASSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"),
155+
]
156+
sampler = CompositeSampler(inner_samplers, alphas)
157+
_test_arch_combine_fn_default(sampler, alphas)
158+
159+
# Case 3
160+
inner_samplers = [
161+
DRNASSampler(alphas, sample_frequency="epoch", arch_combine_fn="default"),
162+
GDASSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"),
163+
DARTSSampler(alphas, sample_frequency="epoch", arch_combine_fn="sigmoid"),
164+
]
165+
sampler = CompositeSampler(inner_samplers, alphas)
166+
_test_arch_combine_fn_sigmoid(sampler, alphas)
167+
168+
114169
if __name__ == "__main__":
115170
unittest.main()

0 commit comments

Comments
 (0)