|  | 
| 3 | 3 | from abc import ABC | 
| 4 | 4 | from typing import Any | 
| 5 | 5 | 
 | 
|  | 6 | +from typing_extensions import override | 
|  | 7 | + | 
| 6 | 8 | from confopt.enums import SamplerType, SearchSpaceType | 
| 7 | 9 | from confopt.searchspace.darts.core.genotypes import DARTSGenotype | 
| 8 | 10 | from confopt.utils import get_num_classes | 
| @@ -402,6 +404,147 @@ def configure_sampler(self, **kwargs) -> None:  # type: ignore | 
| 402 | 404 |         super().configure_sampler(**kwargs) | 
| 403 | 405 | 
 | 
| 404 | 406 | 
 | 
|  | 407 | +class CompositeProfile(BaseProfile, ABC): | 
|  | 408 | +    SAMPLER_TYPE = SamplerType.COMPOSITE | 
|  | 409 | + | 
|  | 410 | +    def __init__( | 
|  | 411 | +        self, | 
|  | 412 | +        searchspace_type: str | SearchSpaceType, | 
|  | 413 | +        samplers: list[str | SamplerType], | 
|  | 414 | +        epochs: int, | 
|  | 415 | +        # GDAS configs | 
|  | 416 | +        tau_min: float = 0.1, | 
|  | 417 | +        tau_max: float = 10, | 
|  | 418 | +        # SNAS configs | 
|  | 419 | +        temp_init: float = 1.0, | 
|  | 420 | +        temp_min: float = 0.03, | 
|  | 421 | +        temp_annealing: bool = True, | 
|  | 422 | +        **kwargs: Any, | 
|  | 423 | +    ) -> None: | 
|  | 424 | +        self.samplers = [] | 
|  | 425 | +        for sampler in samplers: | 
|  | 426 | +            if isinstance(sampler, str): | 
|  | 427 | +                sampler = SamplerType(sampler) | 
|  | 428 | +            self.samplers.append(sampler) | 
|  | 429 | + | 
|  | 430 | +        self.tau_min = tau_min | 
|  | 431 | +        self.tau_max = tau_max | 
|  | 432 | + | 
|  | 433 | +        self.temp_init = temp_init | 
|  | 434 | +        self.temp_min = temp_min | 
|  | 435 | +        self.temp_annealing = temp_annealing | 
|  | 436 | + | 
|  | 437 | +        super().__init__(  # type: ignore | 
|  | 438 | +            self.SAMPLER_TYPE, | 
|  | 439 | +            searchspace_type, | 
|  | 440 | +            epochs, | 
|  | 441 | +            **kwargs, | 
|  | 442 | +        ) | 
|  | 443 | + | 
|  | 444 | +    def _initialize_sampler_config(self) -> None: | 
|  | 445 | +        """Initializes the sampler configuration for Composite samplers. | 
|  | 446 | +
 | 
|  | 447 | +        The sampler configuration includes the sample frequency and the architecture | 
|  | 448 | +        combine function. | 
|  | 449 | +
 | 
|  | 450 | +        Args: | 
|  | 451 | +            None | 
|  | 452 | +
 | 
|  | 453 | +        Returns: | 
|  | 454 | +            None | 
|  | 455 | +        """ | 
|  | 456 | +        self.sampler_config: dict[int, dict] = {}  # type: ignore | 
|  | 457 | +        for i, sampler in enumerate(self.samplers): | 
|  | 458 | +            if sampler in [SamplerType.DARTS, SamplerType.DRNAS]: | 
|  | 459 | +                config = { | 
|  | 460 | +                    "sampler_type": sampler, | 
|  | 461 | +                    "sample_frequency": self.sampler_sample_frequency, | 
|  | 462 | +                    "arch_combine_fn": self.sampler_arch_combine_fn, | 
|  | 463 | +                } | 
|  | 464 | +            elif sampler in [SamplerType.GDAS, SamplerType.REINMAX]: | 
|  | 465 | +                config = { | 
|  | 466 | +                    "sampler_type": sampler, | 
|  | 467 | +                    "sample_frequency": self.sampler_sample_frequency, | 
|  | 468 | +                    "arch_combine_fn": self.sampler_arch_combine_fn, | 
|  | 469 | +                    "tau_min": self.tau_min, | 
|  | 470 | +                    "tau_max": self.tau_max, | 
|  | 471 | +                } | 
|  | 472 | +            elif sampler == SamplerType.SNAS: | 
|  | 473 | +                config = { | 
|  | 474 | +                    "sampler_type": sampler, | 
|  | 475 | +                    "sample_frequency": self.sampler_sample_frequency, | 
|  | 476 | +                    "arch_combine_fn": self.sampler_arch_combine_fn, | 
|  | 477 | +                    "temp_init": self.temp_init, | 
|  | 478 | +                    "temp_min": self.temp_min, | 
|  | 479 | +                    "temp_annealing": self.temp_annealing, | 
|  | 480 | +                    "total_epochs": self.epochs, | 
|  | 481 | +                } | 
|  | 482 | +            else: | 
|  | 483 | +                raise AttributeError(f"Illegal sampler type {sampler} provided!") | 
|  | 484 | + | 
|  | 485 | +            self.sampler_config[i] = config | 
|  | 486 | + | 
|  | 487 | +    @override | 
|  | 488 | +    def configure_sampler(  # type: ignore[override] | 
|  | 489 | +        self, sampler_config_map: dict[int, dict] | 
|  | 490 | +    ) -> None: | 
|  | 491 | +        """Configures the sampler settings based on the provided configurations. | 
|  | 492 | +
 | 
|  | 493 | +        Args: | 
|  | 494 | +            sampler_config_map (dict[int, dict]): A dictionary where each key is an \ | 
|  | 495 | +                integer representing the order of the sampler (zero-indexed), and \ | 
|  | 496 | +                each value is a dictionary containing the configuration parameters \ | 
|  | 497 | +                for that sampler. | 
|  | 498 | +
 | 
|  | 499 | +            The inner configuration dictionary can contain different sets of keys | 
|  | 500 | +            depending on the type of sampler being configured. The available keys | 
|  | 501 | +            include: | 
|  | 502 | +
 | 
|  | 503 | +                Generic Configurations: | 
|  | 504 | +                    - sample_frequency (str): The rate at which samples should be taken. | 
|  | 505 | +                    - arch_combine_fn (str): Function to combine architectures. | 
|  | 506 | +                      For FairDARTS, set this to 'sigmoid'. Default is 'default'. | 
|  | 507 | +
 | 
|  | 508 | +                GDAS-specific Configurations: | 
|  | 509 | +                    - tau_min (float): Minimum temperature for sampling. | 
|  | 510 | +                    - tau_max (float): Maximum temperature for sampling. | 
|  | 511 | +
 | 
|  | 512 | +                SNAS-specific Configurations: | 
|  | 513 | +                    - temp_init (float): Initial temperature for sampling. | 
|  | 514 | +                    - temp_min (float): Minimum temperature for sampling. | 
|  | 515 | +                    - temp_annealing (bool): Whether to apply temperature annealing. | 
|  | 516 | +                    - total_epochs (int): Total number of training epochs. | 
|  | 517 | +
 | 
|  | 518 | +        The specific keys required in the dictionary depend on the type of | 
|  | 519 | +        sampler being used. | 
|  | 520 | +        Please make sure that the sample frequency of all the configurations are same. | 
|  | 521 | +        Each configuration is validated, and an error is raised if unknown keys are | 
|  | 522 | +        provided. | 
|  | 523 | +
 | 
|  | 524 | +        Raises: | 
|  | 525 | +            ValueError: If an unrecognized configuration key is detected. | 
|  | 526 | +
 | 
|  | 527 | +        Returns: | 
|  | 528 | +            None | 
|  | 529 | +        """ | 
|  | 530 | +        assert self.sampler_config is not None | 
|  | 531 | + | 
|  | 532 | +        for idx in sampler_config_map: | 
|  | 533 | +            for config_key in sampler_config_map[idx]: | 
|  | 534 | +                assert idx in self.sampler_config | 
|  | 535 | +                exists = False | 
|  | 536 | +                sampler_type = self.sampler_config[idx]["sampler_type"] | 
|  | 537 | +                if config_key in self.sampler_config[idx]: | 
|  | 538 | +                    exists = True | 
|  | 539 | +                    self.sampler_config[idx][config_key] = sampler_config_map[idx][ | 
|  | 540 | +                        config_key | 
|  | 541 | +                    ] | 
|  | 542 | +                assert exists, ( | 
|  | 543 | +                    f"{config_key} is not a valid configuration for {sampler_type}", | 
|  | 544 | +                    "sampler inside composite sampler", | 
|  | 545 | +                ) | 
|  | 546 | + | 
|  | 547 | + | 
| 405 | 548 | class DiscreteProfile: | 
| 406 | 549 |     def __init__( | 
| 407 | 550 |         self, | 
|  | 
0 commit comments