Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions src/flowMC/Sampler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Optional

import jax.numpy as jnp
import tqdm
from jaxtyping import Array, Float, PRNGKeyArray
from typing import Optional

from flowMC.strategy.base import Strategy
from flowMC.resource.base import Resource
from flowMC.resource_strategy_bundle.base import ResourceStrategyBundle
from flowMC.strategy.base import Strategy


class Sampler:
Expand All @@ -29,7 +31,7 @@ class Sampler:
rng_key: PRNGKeyArray
resources: dict[str, Resource]
strategies: dict[str, Strategy]
strategy_order: Optional[list[str]]
strategy_order: Optional[list[tuple[str, str]]]

# Logging hyperparameters
verbose: bool = False
Expand All @@ -43,7 +45,7 @@ def __init__(
rng_key: PRNGKeyArray,
resources: None | dict[str, Resource] = None,
strategies: None | dict[str, Strategy] = None,
strategy_order: None | list[str] = None,
strategy_order: None | list[tuple[str, str]] = None,
resource_strategy_bundles: None | ResourceStrategyBundle = None,
**kwargs,
):
Expand Down Expand Up @@ -93,19 +95,35 @@ def sample(self, initial_position: Float[Array, "n_chains n_dim"], data: dict):
rng_key = self.rng_key
last_step = initial_position
assert isinstance(self.strategy_order, list)
for strategy in self.strategy_order:

if not all(
isinstance(item, (tuple, list)) and len(item) == 2
for item in self.strategy_order
):
raise TypeError(
"strategy_order must be a list of 2-tuples: (strategy_name, phase)."
)

for strategy, _ in self.strategy_order:
if strategy not in self.strategies:
raise ValueError(
f"Invalid strategy name '{strategy}' provided. "
f"Available strategies are: {list(self.strategies.keys())}."
)
(
rng_key,
self.resources,
last_step,
) = self.strategies[
strategy
](rng_key, self.resources, last_step, data)

with tqdm.tqdm(
self.strategy_order,
total=len(self.strategy_order),
) as strategy_order_bar:
for strategy, phase in strategy_order_bar:
strategy_order_bar.set_description(phase)
(
rng_key,
self.resources,
last_step,
) = self.strategies[
strategy
](rng_key, self.resources, last_step, data)

# TODO: Implement quick access and summary functions that operates on buffer

Expand Down
37 changes: 20 additions & 17 deletions src/flowMC/resource_strategy_bundle/RQSpline_MALA.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
from typing import Callable

import equinox as eqx
import jax
from jaxtyping import Array, Float, PRNGKeyArray
import equinox as eqx

from flowMC.resource.base import Resource
from flowMC.resource.buffers import Buffer
from flowMC.resource.states import State
from flowMC.resource.logPDF import LogPDF
from flowMC.resource.kernel.MALA import MALA
from flowMC.resource.kernel.NF_proposal import NFProposal
from flowMC.resource.logPDF import LogPDF
from flowMC.resource.model.nf_model.rqSpline import MaskedCouplingRQSpline
from flowMC.resource.optimizer import Optimizer
from flowMC.resource.states import State
from flowMC.resource_strategy_bundle.base import Phase, ResourceStrategyBundle
from flowMC.strategy.lambda_function import Lambda
from flowMC.strategy.take_steps import TakeSerialSteps, TakeGroupSteps
from flowMC.strategy.take_steps import TakeGroupSteps, TakeSerialSteps
from flowMC.strategy.train_model import TrainModel
from flowMC.strategy.update_state import UpdateState
from flowMC.resource_strategy_bundle.base import ResourceStrategyBundle


class RQSpline_MALA_Bundle(ResourceStrategyBundle):
Expand Down Expand Up @@ -249,25 +249,28 @@ def update_model(
}

training_phase = [
"local_stepper",
"update_global_step",
"model_trainer",
"update_model",
"global_stepper",
"update_local_step",
("local_stepper", Phase.TRAINING.value),
("update_global_step", Phase.TRAINING.value),
("model_trainer", Phase.TRAINING.value),
("update_model", Phase.TRAINING.value),
("global_stepper", Phase.TRAINING.value),
("update_local_step", Phase.TRAINING.value),
]
production_phase = [
"local_stepper",
"update_global_step",
"global_stepper",
"update_local_step",
("local_stepper", Phase.PRODUCTION.value),
("update_global_step", Phase.PRODUCTION.value),
("global_stepper", Phase.PRODUCTION.value),
("update_local_step", Phase.PRODUCTION.value),
]

strategy_order = []

for _ in range(n_training_loops):
strategy_order.extend(training_phase)

strategy_order.append("reset_steppers")
strategy_order.append("update_state")
strategy_order.append(("reset_steppers", Phase.INTERMEDIATE.value))
strategy_order.append(("update_state", Phase.INTERMEDIATE.value))

for _ in range(n_production_loops):
strategy_order.extend(production_phase)

Expand Down
46 changes: 24 additions & 22 deletions src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
from typing import Callable

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PRNGKeyArray
import equinox as eqx

from flowMC.resource.base import Resource
from flowMC.resource.buffers import Buffer
from flowMC.resource.states import State
from flowMC.resource.logPDF import LogPDF, TemperedPDF
from flowMC.resource.kernel.MALA import MALA
from flowMC.resource.kernel.NF_proposal import NFProposal
from flowMC.resource.logPDF import LogPDF, TemperedPDF
from flowMC.resource.model.nf_model.rqSpline import MaskedCouplingRQSpline
from flowMC.resource.optimizer import Optimizer
from flowMC.resource.states import State
from flowMC.resource_strategy_bundle.base import Phase, ResourceStrategyBundle
from flowMC.strategy.lambda_function import Lambda
from flowMC.strategy.take_steps import TakeSerialSteps, TakeGroupSteps
from flowMC.strategy.parallel_tempering import ParallelTempering
from flowMC.strategy.take_steps import TakeGroupSteps, TakeSerialSteps
from flowMC.strategy.train_model import TrainModel
from flowMC.strategy.update_state import UpdateState
from flowMC.strategy.parallel_tempering import ParallelTempering

from flowMC.resource_strategy_bundle.base import ResourceStrategyBundle


class RQSpline_MALA_PT_Bundle(ResourceStrategyBundle):
Expand Down Expand Up @@ -317,27 +316,30 @@ def initialize_tempered_positions(
}

training_phase = [
"parallel_tempering",
"local_stepper",
"update_global_step",
"model_trainer",
"update_model",
"global_stepper",
"update_local_step",
("parallel_tempering", Phase.TRAINING.value),
("local_stepper", Phase.TRAINING.value),
("update_global_step", Phase.TRAINING.value),
("model_trainer", Phase.TRAINING.value),
("update_model", Phase.TRAINING.value),
("global_stepper", Phase.TRAINING.value),
("update_local_step", Phase.TRAINING.value),
]
production_phase = [
"parallel_tempering",
"local_stepper",
"update_global_step",
"global_stepper",
"update_local_step",
("parallel_tempering", Phase.PRODUCTION.value),
("local_stepper", Phase.PRODUCTION.value),
("update_global_step", Phase.PRODUCTION.value),
("global_stepper", Phase.PRODUCTION.value),
("update_local_step", Phase.PRODUCTION.value),
]
strategy_order = ["initialize_tempered_positions"]

strategy_order = [("initialize_tempered_positions", Phase.INITIALIZATION.value)]

for _ in range(n_training_loops):
strategy_order.extend(training_phase)

strategy_order.append("reset_steppers")
strategy_order.append("update_state")
strategy_order.append(("reset_steppers", Phase.INTERMEDIATE.value))
strategy_order.append(("update_state", Phase.INTERMEDIATE.value))

for _ in range(n_production_loops):
strategy_order.extend(production_phase)

Expand Down
12 changes: 11 additions & 1 deletion src/flowMC/resource_strategy_bundle/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import enum
from abc import ABC

from flowMC.resource.base import Resource
from flowMC.strategy.base import Strategy


class Phase(enum.Enum):
"""Enumeration for the different phases of the resource-strategy bundle."""

INITIALIZATION = "initialization"
INTERMEDIATE = "intermediate"
PRODUCTION = "production"
TRAINING = "training"


class ResourceStrategyBundle(ABC):
"""Resource-Strategy Bundle is aim to be the highest level of abstraction in the
flowMC library.
Expand All @@ -14,4 +24,4 @@ class ResourceStrategyBundle(ABC):

resources: dict[str, Resource]
strategies: dict[str, Strategy]
strategy_order: list[str]
strategy_order: list[tuple[str, str]]
8 changes: 4 additions & 4 deletions test/integration/test_HMC.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from jax.scipy.special import logsumexp
from jaxtyping import Array, Float

from flowMC.resource.kernel.HMC import HMC
from flowMC.strategy.take_steps import TakeSerialSteps
from flowMC.resource.buffers import Buffer
from flowMC.resource.states import State
from flowMC.resource.kernel.HMC import HMC
from flowMC.resource.logPDF import LogPDF
from flowMC.resource.states import State
from flowMC.Sampler import Sampler
from flowMC.strategy.take_steps import TakeSerialSteps


def dual_moon_pe(x: Float[Array, " n_dims"], data: dict):
Expand Down Expand Up @@ -82,7 +82,7 @@ def dual_moon_pe(x: Float[Array, " n_dims"], data: dict):
rng_key=rng_key,
resources=resource,
strategies={"take_steps": strategy},
strategy_order=["take_steps"],
strategy_order=[("take_steps", "testing_phase")],
)

nf_sampler.sample(initial_position, data)
8 changes: 4 additions & 4 deletions test/integration/test_MALA.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from jax.scipy.special import logsumexp
from jaxtyping import Array, Float

from flowMC.resource.kernel.MALA import MALA
from flowMC.strategy.take_steps import TakeSerialSteps
from flowMC.resource.buffers import Buffer
from flowMC.resource.states import State
from flowMC.resource.kernel.MALA import MALA
from flowMC.resource.logPDF import LogPDF
from flowMC.resource.states import State
from flowMC.Sampler import Sampler
from flowMC.strategy.take_steps import TakeSerialSteps


def dual_moon_pe(x: Float[Array, "n_dims"], data: dict):
Expand Down Expand Up @@ -74,7 +74,7 @@ def dual_moon_pe(x: Float[Array, "n_dims"], data: dict):
rng_key=rng_key,
resources=resource,
strategies={"take_steps": strategy},
strategy_order=["take_steps"],
strategy_order=[("take_steps", "testing_phase")],
)

nf_sampler.sample(initial_position, data)
8 changes: 4 additions & 4 deletions test/integration/test_RWMCMC.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from jax.scipy.special import logsumexp
from jaxtyping import Array, Float

from flowMC.resource.kernel.Gaussian_random_walk import GaussianRandomWalk
from flowMC.strategy.take_steps import TakeSerialSteps
from flowMC.resource.buffers import Buffer
from flowMC.resource.states import State
from flowMC.resource.kernel.Gaussian_random_walk import GaussianRandomWalk
from flowMC.resource.logPDF import LogPDF
from flowMC.resource.states import State
from flowMC.Sampler import Sampler
from flowMC.strategy.take_steps import TakeSerialSteps


def dual_moon_pe(x: Float[Array, " n_dims"], data: dict):
Expand Down Expand Up @@ -81,7 +81,7 @@ def dual_moon_pe(x: Float[Array, " n_dims"], data: dict):
rng_key=rng_key,
resources=resource,
strategies={"take_steps": strategy},
strategy_order=["take_steps"],
strategy_order=[("take_steps", "testing_phase")],
)

nf_sampler.sample(initial_position, data)
Loading