diff --git a/src/flowMC/Sampler.py b/src/flowMC/Sampler.py index add96f2..d27d701 100644 --- a/src/flowMC/Sampler.py +++ b/src/flowMC/Sampler.py @@ -1,10 +1,12 @@ +from typing import Optional + import jax.numpy as jnp +import tqdm from jaxtyping import Array, Float, PRNGKeyArray -from typing import Optional -from flowMC.strategy.base import Strategy from flowMC.resource.base import Resource from flowMC.resource_strategy_bundle.base import ResourceStrategyBundle +from flowMC.strategy.base import Strategy class Sampler: @@ -29,7 +31,7 @@ class Sampler: rng_key: PRNGKeyArray resources: dict[str, Resource] strategies: dict[str, Strategy] - strategy_order: Optional[list[str]] + strategy_order: Optional[list[tuple[str, str]]] # Logging hyperparameters verbose: bool = False @@ -43,7 +45,7 @@ def __init__( rng_key: PRNGKeyArray, resources: None | dict[str, Resource] = None, strategies: None | dict[str, Strategy] = None, - strategy_order: None | list[str] = None, + strategy_order: None | list[tuple[str, str]] = None, resource_strategy_bundles: None | ResourceStrategyBundle = None, **kwargs, ): @@ -93,19 +95,35 @@ def sample(self, initial_position: Float[Array, "n_chains n_dim"], data: dict): rng_key = self.rng_key last_step = initial_position assert isinstance(self.strategy_order, list) - for strategy in self.strategy_order: + + if not all( + isinstance(item, (tuple, list)) and len(item) == 2 + for item in self.strategy_order + ): + raise TypeError( + "strategy_order must be a list of 2-tuples: (strategy_name, phase)." + ) + + for strategy, _ in self.strategy_order: if strategy not in self.strategies: raise ValueError( f"Invalid strategy name '{strategy}' provided. " f"Available strategies are: {list(self.strategies.keys())}." ) - ( - rng_key, - self.resources, - last_step, - ) = self.strategies[ - strategy - ](rng_key, self.resources, last_step, data) + + with tqdm.tqdm( + self.strategy_order, + total=len(self.strategy_order), + ) as strategy_order_bar: + for strategy, phase in strategy_order_bar: + strategy_order_bar.set_description(phase) + ( + rng_key, + self.resources, + last_step, + ) = self.strategies[ + strategy + ](rng_key, self.resources, last_step, data) # TODO: Implement quick access and summary functions that operates on buffer diff --git a/src/flowMC/resource_strategy_bundle/RQSpline_MALA.py b/src/flowMC/resource_strategy_bundle/RQSpline_MALA.py index 0af672b..42eef6d 100644 --- a/src/flowMC/resource_strategy_bundle/RQSpline_MALA.py +++ b/src/flowMC/resource_strategy_bundle/RQSpline_MALA.py @@ -1,22 +1,22 @@ from typing import Callable +import equinox as eqx import jax from jaxtyping import Array, Float, PRNGKeyArray -import equinox as eqx from flowMC.resource.base import Resource from flowMC.resource.buffers import Buffer -from flowMC.resource.states import State -from flowMC.resource.logPDF import LogPDF from flowMC.resource.kernel.MALA import MALA from flowMC.resource.kernel.NF_proposal import NFProposal +from flowMC.resource.logPDF import LogPDF from flowMC.resource.model.nf_model.rqSpline import MaskedCouplingRQSpline from flowMC.resource.optimizer import Optimizer +from flowMC.resource.states import State +from flowMC.resource_strategy_bundle.base import Phase, ResourceStrategyBundle from flowMC.strategy.lambda_function import Lambda -from flowMC.strategy.take_steps import TakeSerialSteps, TakeGroupSteps +from flowMC.strategy.take_steps import TakeGroupSteps, TakeSerialSteps from flowMC.strategy.train_model import TrainModel from flowMC.strategy.update_state import UpdateState -from flowMC.resource_strategy_bundle.base import ResourceStrategyBundle class RQSpline_MALA_Bundle(ResourceStrategyBundle): @@ -249,25 +249,28 @@ def update_model( } training_phase = [ - "local_stepper", - "update_global_step", - "model_trainer", - "update_model", - "global_stepper", - "update_local_step", + ("local_stepper", Phase.TRAINING.value), + ("update_global_step", Phase.TRAINING.value), + ("model_trainer", Phase.TRAINING.value), + ("update_model", Phase.TRAINING.value), + ("global_stepper", Phase.TRAINING.value), + ("update_local_step", Phase.TRAINING.value), ] production_phase = [ - "local_stepper", - "update_global_step", - "global_stepper", - "update_local_step", + ("local_stepper", Phase.PRODUCTION.value), + ("update_global_step", Phase.PRODUCTION.value), + ("global_stepper", Phase.PRODUCTION.value), + ("update_local_step", Phase.PRODUCTION.value), ] + strategy_order = [] + for _ in range(n_training_loops): strategy_order.extend(training_phase) - strategy_order.append("reset_steppers") - strategy_order.append("update_state") + strategy_order.append(("reset_steppers", Phase.INTERMEDIATE.value)) + strategy_order.append(("update_state", Phase.INTERMEDIATE.value)) + for _ in range(n_production_loops): strategy_order.extend(production_phase) diff --git a/src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py b/src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py index af97302..8b2414e 100644 --- a/src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py +++ b/src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py @@ -1,25 +1,24 @@ from typing import Callable +import equinox as eqx import jax import jax.numpy as jnp from jaxtyping import Array, Float, PRNGKeyArray -import equinox as eqx from flowMC.resource.base import Resource from flowMC.resource.buffers import Buffer -from flowMC.resource.states import State -from flowMC.resource.logPDF import LogPDF, TemperedPDF from flowMC.resource.kernel.MALA import MALA from flowMC.resource.kernel.NF_proposal import NFProposal +from flowMC.resource.logPDF import LogPDF, TemperedPDF from flowMC.resource.model.nf_model.rqSpline import MaskedCouplingRQSpline from flowMC.resource.optimizer import Optimizer +from flowMC.resource.states import State +from flowMC.resource_strategy_bundle.base import Phase, ResourceStrategyBundle from flowMC.strategy.lambda_function import Lambda -from flowMC.strategy.take_steps import TakeSerialSteps, TakeGroupSteps +from flowMC.strategy.parallel_tempering import ParallelTempering +from flowMC.strategy.take_steps import TakeGroupSteps, TakeSerialSteps from flowMC.strategy.train_model import TrainModel from flowMC.strategy.update_state import UpdateState -from flowMC.strategy.parallel_tempering import ParallelTempering - -from flowMC.resource_strategy_bundle.base import ResourceStrategyBundle class RQSpline_MALA_PT_Bundle(ResourceStrategyBundle): @@ -317,27 +316,30 @@ def initialize_tempered_positions( } training_phase = [ - "parallel_tempering", - "local_stepper", - "update_global_step", - "model_trainer", - "update_model", - "global_stepper", - "update_local_step", + ("parallel_tempering", Phase.TRAINING.value), + ("local_stepper", Phase.TRAINING.value), + ("update_global_step", Phase.TRAINING.value), + ("model_trainer", Phase.TRAINING.value), + ("update_model", Phase.TRAINING.value), + ("global_stepper", Phase.TRAINING.value), + ("update_local_step", Phase.TRAINING.value), ] production_phase = [ - "parallel_tempering", - "local_stepper", - "update_global_step", - "global_stepper", - "update_local_step", + ("parallel_tempering", Phase.PRODUCTION.value), + ("local_stepper", Phase.PRODUCTION.value), + ("update_global_step", Phase.PRODUCTION.value), + ("global_stepper", Phase.PRODUCTION.value), + ("update_local_step", Phase.PRODUCTION.value), ] - strategy_order = ["initialize_tempered_positions"] + + strategy_order = [("initialize_tempered_positions", Phase.INITIALIZATION.value)] + for _ in range(n_training_loops): strategy_order.extend(training_phase) - strategy_order.append("reset_steppers") - strategy_order.append("update_state") + strategy_order.append(("reset_steppers", Phase.INTERMEDIATE.value)) + strategy_order.append(("update_state", Phase.INTERMEDIATE.value)) + for _ in range(n_production_loops): strategy_order.extend(production_phase) diff --git a/src/flowMC/resource_strategy_bundle/base.py b/src/flowMC/resource_strategy_bundle/base.py index edce9f1..f223261 100644 --- a/src/flowMC/resource_strategy_bundle/base.py +++ b/src/flowMC/resource_strategy_bundle/base.py @@ -1,9 +1,19 @@ +import enum from abc import ABC from flowMC.resource.base import Resource from flowMC.strategy.base import Strategy +class Phase(enum.Enum): + """Enumeration for the different phases of the resource-strategy bundle.""" + + INITIALIZATION = "initialization" + INTERMEDIATE = "intermediate" + PRODUCTION = "production" + TRAINING = "training" + + class ResourceStrategyBundle(ABC): """Resource-Strategy Bundle is aim to be the highest level of abstraction in the flowMC library. @@ -14,4 +24,4 @@ class ResourceStrategyBundle(ABC): resources: dict[str, Resource] strategies: dict[str, Strategy] - strategy_order: list[str] + strategy_order: list[tuple[str, str]] diff --git a/test/integration/test_HMC.py b/test/integration/test_HMC.py index 5bd2723..7e94c80 100644 --- a/test/integration/test_HMC.py +++ b/test/integration/test_HMC.py @@ -3,12 +3,12 @@ from jax.scipy.special import logsumexp from jaxtyping import Array, Float -from flowMC.resource.kernel.HMC import HMC -from flowMC.strategy.take_steps import TakeSerialSteps from flowMC.resource.buffers import Buffer -from flowMC.resource.states import State +from flowMC.resource.kernel.HMC import HMC from flowMC.resource.logPDF import LogPDF +from flowMC.resource.states import State from flowMC.Sampler import Sampler +from flowMC.strategy.take_steps import TakeSerialSteps def dual_moon_pe(x: Float[Array, " n_dims"], data: dict): @@ -82,7 +82,7 @@ def dual_moon_pe(x: Float[Array, " n_dims"], data: dict): rng_key=rng_key, resources=resource, strategies={"take_steps": strategy}, - strategy_order=["take_steps"], + strategy_order=[("take_steps", "testing_phase")], ) nf_sampler.sample(initial_position, data) diff --git a/test/integration/test_MALA.py b/test/integration/test_MALA.py index 2242437..0640e17 100644 --- a/test/integration/test_MALA.py +++ b/test/integration/test_MALA.py @@ -3,12 +3,12 @@ from jax.scipy.special import logsumexp from jaxtyping import Array, Float -from flowMC.resource.kernel.MALA import MALA -from flowMC.strategy.take_steps import TakeSerialSteps from flowMC.resource.buffers import Buffer -from flowMC.resource.states import State +from flowMC.resource.kernel.MALA import MALA from flowMC.resource.logPDF import LogPDF +from flowMC.resource.states import State from flowMC.Sampler import Sampler +from flowMC.strategy.take_steps import TakeSerialSteps def dual_moon_pe(x: Float[Array, "n_dims"], data: dict): @@ -74,7 +74,7 @@ def dual_moon_pe(x: Float[Array, "n_dims"], data: dict): rng_key=rng_key, resources=resource, strategies={"take_steps": strategy}, - strategy_order=["take_steps"], + strategy_order=[("take_steps", "testing_phase")], ) nf_sampler.sample(initial_position, data) diff --git a/test/integration/test_RWMCMC.py b/test/integration/test_RWMCMC.py index dccee84..f523dc1 100644 --- a/test/integration/test_RWMCMC.py +++ b/test/integration/test_RWMCMC.py @@ -3,12 +3,12 @@ from jax.scipy.special import logsumexp from jaxtyping import Array, Float -from flowMC.resource.kernel.Gaussian_random_walk import GaussianRandomWalk -from flowMC.strategy.take_steps import TakeSerialSteps from flowMC.resource.buffers import Buffer -from flowMC.resource.states import State +from flowMC.resource.kernel.Gaussian_random_walk import GaussianRandomWalk from flowMC.resource.logPDF import LogPDF +from flowMC.resource.states import State from flowMC.Sampler import Sampler +from flowMC.strategy.take_steps import TakeSerialSteps def dual_moon_pe(x: Float[Array, " n_dims"], data: dict): @@ -81,7 +81,7 @@ def dual_moon_pe(x: Float[Array, " n_dims"], data: dict): rng_key=rng_key, resources=resource, strategies={"take_steps": strategy}, - strategy_order=["take_steps"], + strategy_order=[("take_steps", "testing_phase")], ) nf_sampler.sample(initial_position, data)