diff --git a/sbi/inference/trainers/npse/npse.py b/sbi/inference/trainers/npse/npse.py index c8b647e2e..d3c5d05a3 100644 --- a/sbi/inference/trainers/npse/npse.py +++ b/sbi/inference/trainers/npse/npse.py @@ -13,9 +13,7 @@ from sbi import utils as utils from sbi.inference import NeuralInference -from sbi.inference.posteriors import ( - DirectPosterior, -) +from sbi.inference.posteriors import DirectPosterior from sbi.inference.posteriors.score_posterior import ScorePosterior from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator from sbi.neural_nets.factory import posterior_score_nn @@ -304,9 +302,14 @@ def default_calibration_kernel(x): self._neural_net.to(self._device) if isinstance(validation_times, int): - validation_times = torch.linspace( - self._neural_net.t_min, self._neural_net.t_max, validation_times - ) + if hasattr(self._neural_net, "times_schedule"): + validation_times = self._neural_net.times_schedule(validation_times) + else: + validation_times = torch.linspace( + self._neural_net.t_min, + self._neural_net.t_max, + steps=validation_times, + ) assert isinstance( validation_times, Tensor ) # let pyright know validation_times is a Tensor. diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py index 132d44d1d..acacc7d67 100644 --- a/sbi/neural_nets/estimators/score_estimator.py +++ b/sbi/neural_nets/estimators/score_estimator.py @@ -2,13 +2,13 @@ # under the Apache License Version 2.0, see import math -from math import pi +from math import exp, pi from typing import Callable, Optional, Union import torch -from torch import Tensor, nn - from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator +from scipy import stats +from torch import Tensor, nn class ConditionalScoreEstimator(ConditionalVectorFieldEstimator): @@ -42,6 +42,8 @@ def __init__( input_shape: torch.Size, condition_shape: torch.Size, weight_fn: Union[str, Callable] = "max_likelihood", + beta_min: float = 0.01, + beta_max: float = 10.0, mean_0: Union[Tensor, float] = 0.0, std_0: Union[Tensor, float] = 1.0, t_min: float = 1e-3, @@ -59,6 +61,12 @@ def __init__( - "identity": constant weights (1.), - "max_likelihood": weights proportional to the diffusion function, or - a custom function that returns a Callable. + beta_min: starting/minimal value for beta, the variance of the noise + beta_max: ending/maximal value for beta, the variance of the noise + mean_0: expected value of 0th step noise + scale_0: sigma of 0th step noise + t_min: smallest time step + t_max: largest time step """ super().__init__(net, input_shape, condition_shape) @@ -66,7 +74,14 @@ def __init__( # Set lambdas (variance weights) function. self._set_weight_fn(weight_fn) - # Min time for diffusion (0 can be numerically unstable). + # Store device of this module + self.device = net.device if hasattr(net, "device") else torch.device("cpu") + + # Min/max values for noise variance beta + self.beta_min = beta_min + self.beta_max = beta_max + + # Min/max values for limits to time self.t_min = t_min self.t_max = t_max @@ -152,6 +167,7 @@ def loss( input: Input variable i.e. theta. condition: Conditioning variable. times: SDE time variable in [t_min, t_max]. Uniformly sampled if None. + if None, will be filled by calling the time_schedule method control_variate: Whether to use a control variate to reduce the variance of the stochastic loss estimator. control_variate_threshold: Threshold for the control variate. If the std @@ -161,13 +177,12 @@ def loss( MSE between target score and network output, scaled by the weight function. """ - # Sample diffusion times. + # update device if required + self.device = input.device if self.device != input.device else self.device + + # Sample times from the Markov chain, use batch dimension if times is None: - times = ( - torch.rand(input.shape[0], device=input.device) - * (self.t_max - self.t_min) - + self.t_min - ) + times = self.times_schedule(input.shape[0]) # Sample noise. eps = torch.randn_like(input) @@ -312,6 +327,58 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: """ raise NotImplementedError + def noise_schedule(self, times: Tensor) -> Tensor: + """ + Generate a beta schedule in stochastic differential equations (SDEs). + This will be used for sampling. + + This method acts as a fallback in case derivative classes do not + implement it on their own. It calculates a linear beta schedule defined + by the input `times`, which represent the normalized time steps t ∈ [0, 1]. + + We implement a linear noise schedule here. + + Args: + times (Tensor): + SDE times in [0, 1]. This tensor will be regenerated from + self.times_schedule + + Returns: + Tensor: Generated beta schedule at a given time. + + """ + return self.beta_min + (self.beta_max - self.beta_min) * times + + def times_schedule( + self, num_samples: int, t_min: float = None, t_max: float = None + ) -> Tensor: + """ + Time samples for evaluating the diffusion model. + Perform uniform sampling of time variables within the range [t_min, t_max]. + The `times` tensor will be put on the same device as the stored network. + + We implement a uniformly sampled time stepping here. + + Args: + num_samples (int): Number of samples to generate. + t_min (float, optional): The minimum time value. Defaults to self.t_min. + t_max (float, optional): The maximum time value. Defaults to self.t_max. + + Returns: + Tensor: A tensor of sampled time variables scaled and shifted to the + range [0,1]. + + """ + t_min = self.t_min if isinstance(t_min, type(None)) else t_min + t_max = self.t_max if isinstance(t_max, type(None)) else t_max + + times = torch.rand(num_samples, device=self.device) * (t_max - t_min) + t_min + + # t_min and t_max need to be part of the sequence + times[0,...] = t_min + times[-1,...] = t_max + return torch.Tensor(sorted(times)) + def _set_weight_fn(self, weight_fn: Union[str, Callable]): """Set the weight function. @@ -353,10 +420,9 @@ def __init__( mean_0: Union[Tensor, float] = 0.0, std_0: Union[Tensor, float] = 1.0, t_min: float = 1e-5, - t_max: float = 1.0, + t_max: float = 1.0 ) -> None: - self.beta_min = beta_min - self.beta_max = beta_max + super().__init__( net, input_shape, @@ -364,6 +430,8 @@ def __init__( mean_0=mean_0, std_0=std_0, weight_fn=weight_fn, + beta_min=beta_min, + beta_max=beta_max, t_min=t_min, t_max=t_max, ) @@ -399,17 +467,6 @@ def std_fn(self, times: Tensor) -> Tensor: std = std.unsqueeze(-1) return torch.sqrt(std) - def _beta_schedule(self, times: Tensor) -> Tensor: - """Linear beta schedule for mean scaling in variance preserving SDEs. - - Args: - times: SDE time variable in [0,1]. - - Returns: - Beta schedule at a given time. - """ - return self.beta_min + (self.beta_max - self.beta_min) * times - def drift_fn(self, input: Tensor, times: Tensor) -> Tensor: """Drift function for variance preserving SDEs. @@ -420,7 +477,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor: Returns: Drift function at a given time. """ - phi = -0.5 * self._beta_schedule(times) + phi = -0.5 * self.noise_schedule(times) while len(phi.shape) < len(input.shape): phi = phi.unsqueeze(-1) return phi * input @@ -435,7 +492,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: Returns: Drift function at a given time. """ - g = torch.sqrt(self._beta_schedule(times)) + g = torch.sqrt(self.noise_schedule(times)) while len(g.shape) < len(input.shape): g = g.unsqueeze(-1) return g @@ -457,13 +514,13 @@ def __init__( t_min: float = 1e-2, t_max: float = 1.0, ) -> None: - self.beta_min = beta_min - self.beta_max = beta_max super().__init__( net, input_shape, condition_shape, weight_fn=weight_fn, + beta_min=beta_min, + beta_max=beta_max, mean_0=mean_0, std_0=std_0, t_min=t_min, @@ -501,18 +558,6 @@ def std_fn(self, times: Tensor) -> Tensor: std = std.unsqueeze(-1) return std - def _beta_schedule(self, times: Tensor) -> Tensor: - """Linear beta schedule for mean scaling in sub-variance preserving SDEs. - (Same as for variance preserving SDEs.) - - Args: - times: SDE time variable in [0,1]. - - Returns: - Beta schedule at a given time. - """ - return self.beta_min + (self.beta_max - self.beta_min) * times - def drift_fn(self, input: Tensor, times: Tensor) -> Tensor: """Drift function for sub-variance preserving SDEs. @@ -523,7 +568,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor: Returns: Drift function at a given time. """ - phi = -0.5 * self._beta_schedule(times) + phi = -0.5 * self.noise_schedule(times) while len(phi.shape) < len(input.shape): phi = phi.unsqueeze(-1) @@ -542,7 +587,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: """ g = torch.sqrt( torch.abs( - self._beta_schedule(times) + self.noise_schedule(times) * ( 1 - torch.exp( @@ -612,17 +657,6 @@ def std_fn(self, times: Tensor) -> Tensor: std = std.unsqueeze(-1) return std - def _sigma_schedule(self, times: Tensor) -> Tensor: - """Geometric sigma schedule for variance exploding SDEs. - - Args: - times: SDE time variable in [0,1]. - - Returns: - Sigma schedule at a given time. - """ - return self.sigma_min * (self.sigma_max / self.sigma_min) ** times - def drift_fn(self, input: Tensor, times: Tensor) -> Tensor: """Drift function for variance exploding SDEs. @@ -645,9 +679,9 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: Returns: Diffusion function at a given time. """ - g = self._sigma_schedule(times) * math.sqrt( - (2 * math.log(self.sigma_max / self.sigma_min)) - ) + sigma_scale = self.sigma_max / self.sigma_min + sigmas = self.sigma_min * (sigma_scale) ** times + g = sigmas * math.sqrt((2 * math.log(sigma_scale))) while len(g.shape) < len(input.shape): g = g.unsqueeze(-1) diff --git a/sbi/neural_nets/net_builders/score_nets.py b/sbi/neural_nets/net_builders/score_nets.py index 274c0b2c7..5780163c7 100644 --- a/sbi/neural_nets/net_builders/score_nets.py +++ b/sbi/neural_nets/net_builders/score_nets.py @@ -2,17 +2,13 @@ import torch import torch.nn as nn -from torch import Tensor - from sbi.neural_nets.estimators.score_estimator import ( - ConditionalScoreEstimator, - GaussianFourierTimeEmbedding, - SubVPScoreEstimator, - VEScoreEstimator, - VPScoreEstimator, -) -from sbi.utils.sbiutils import standardizing_net, z_score_parser, z_standardization + ConditionalScoreEstimator, GaussianFourierTimeEmbedding, + SubVPScoreEstimator, VEScoreEstimator, VPScoreEstimator) +from sbi.utils.sbiutils import (standardizing_net, z_score_parser, + z_standardization) from sbi.utils.user_input_checks import check_data_device +from torch import Tensor class EmbedInputs(nn.Module): diff --git a/tests/bm_test.py b/tests/bm_test.py index e38fd9659..28296cded 100644 --- a/tests/bm_test.py +++ b/tests/bm_test.py @@ -4,7 +4,6 @@ import pytest import torch from pytest_harvest import ResultsBag - from sbi.inference import FMPE, NLE, NPE, NPSE, NRE from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.trainers.npe import NPE_C diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py index 4a526e93c..b7afc5255 100644 --- a/tests/score_estimator_test.py +++ b/tests/score_estimator_test.py @@ -7,8 +7,9 @@ import pytest import torch - from sbi.neural_nets.embedding_nets import CNNEmbedding +from sbi.neural_nets.estimators.score_estimator import ( + ConditionalScoreEstimator, VPScoreEstimator) from sbi.neural_nets.net_builders import build_score_estimator @@ -144,3 +145,41 @@ def _build_score_estimator_and_tensors( ) condition = condition return score_estimator, inputs, condition + + +def test_times_schedule(): + id_net = torch.nn.Identity() + inpt_shape = (4,) + cond_shape = (4,) + + with pytest.raises(NotImplementedError): + ConditionalScoreEstimator(id_net, inpt_shape, cond_shape) + + vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape) + exp = vpse.device + times = vpse.times_schedule(10) + obs = times.device + + assert exp == obs + assert times.shape == torch.Size((10,)) + + assert times[0, ...] >= vpse.t_min + assert times[-1, ...] <= vpse.t_max + + +def test_noise_schedule(): + id_net = torch.nn.Identity() + inpt_shape = (4,) + cond_shape = (4,) + + vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape) + exp = vpse.device + times = vpse.times_schedule(10) + noise = vpse.noise_schedule(times) + obs = noise.device + + assert exp == obs + assert noise.shape == torch.Size((10,)) + + assert noise[0, ...] >= vpse.beta_min + assert noise[-1, ...] <= vpse.beta_max