Skip to content

Composite Sampler #250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/confopt/oneshot/archsampler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base_sampler import BaseSampler
from .composite_sampler import CompositeSampler
from .darts.sampler import DARTSSampler
from .drnas.sampler import DRNASSampler
from .gdas.sampler import GDASSampler
Expand All @@ -12,4 +13,5 @@
"GDASSampler",
"SNASSampler",
"ReinMaxSampler",
"CompositeSampler",
]
41 changes: 41 additions & 0 deletions src/confopt/oneshot/archsampler/composite_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

import torch

from confopt.oneshot.archsampler import BaseSampler
from confopt.oneshot.base import OneShotComponent


class CompositeSampler(OneShotComponent):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why inherit from OneShotComponent? Why not inherit from BaseSampler?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was in 2 minds with it as well, but for me, functionality wise- I don't see it as a child of BaseSampler, since it takes a list of BaseSampler within its initialisation which i thought is a special case. So I thought a OneShotComponent is a better option for it. Let me know your opinion about it.

def __init__(
self,
arch_samplers: list[BaseSampler],
arch_parameters: list[torch.Tensor],
) -> None:
super().__init__()
self.arch_samplers = arch_samplers
self.arch_parameters = arch_parameters

# get sample frequency from the samplers
self.sample_frequency = arch_samplers[0].sample_frequency
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still use sampling frequency in the code?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, we never change it to epoch in our experiments, but we have option to do that. I'm not sure how handy this can come later, but I think it doesn't hurt to leave it there

for sampler in arch_samplers:
assert (
self.sample_frequency == sampler.sample_frequency
), "All the sampler must have the same sample frequency"

def sample(self, alpha: torch.Tensor) -> torch.Tensor:
sampled_alphas = alpha
for sampler in self.arch_samplers:
sampled_alphas = sampler.sample(sampled_alphas)

return sampled_alphas

def new_epoch(self) -> None:
super().new_epoch()
for sampler in self.arch_samplers:
sampler.new_epoch()

def new_step(self) -> None:
super().new_step()
for sampler in self.arch_samplers:
sampler.new_step()