From abc67424d212e0a0d9ce6c47ccce2325346d1744 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Tue, 18 Mar 2025 16:21:13 +0100
Subject: [PATCH 01/18] refactoring noise_schedule and times_schedule into
 base class

- created noise_schedule method to be overridden by derived classes
- created times_schedule method to be overridden by derived classes
- created test for times_schedule
- improved docstrings
---
 sbi/neural_nets/estimators/score_estimator.py | 130 +++++++++++-------
 tests/score_estimator_test.py                 |  23 +++-
 2 files changed, 99 insertions(+), 54 deletions(-)

diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index 132d44d1d..a941e19af 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -6,9 +6,8 @@
 from typing import Callable, Optional, Union
 
 import torch
-from torch import Tensor, nn
-
 from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
+from torch import Tensor, nn
 
 
 class ConditionalScoreEstimator(ConditionalVectorFieldEstimator):
@@ -42,10 +41,12 @@ def __init__(
         input_shape: torch.Size,
         condition_shape: torch.Size,
         weight_fn: Union[str, Callable] = "max_likelihood",
+        beta_min: float = 0.01,
+        beta_max: float = 10.0,
         mean_0: Union[Tensor, float] = 0.0,
         std_0: Union[Tensor, float] = 1.0,
         t_min: float = 1e-3,
-        t_max: float = 1.0,
+        t_max: float = 1.0
     ) -> None:
         r"""Score estimator class that estimates the conditional score function,
         i.e., gradient of the density p(xt|x0).
@@ -59,6 +60,12 @@
             - "identity": constant weights (1.),
             - "max_likelihood": weights proportional to the diffusion function, or
             - a custom function that returns a Callable.
+            beta_min: starting/minimal value for beta, the variance of the noise
+            beta_max: ending/maximal value for beta, the variance of the noise
+            mean_0: expected value of the noise at step 0
+            std_0: standard deviation of the noise at step 0
+            t_min: smallest time step
+            t_max: largest time step
         """
         super().__init__(net, input_shape, condition_shape)
 
@@ -66,7 +73,14 @@
         # Set lambdas (variance weights) function.
         self._set_weight_fn(weight_fn)
 
-        # Min time for diffusion (0 can be numerically unstable).
+        # Store device of this module
+        self.device = net.device if hasattr(net, "device") else torch.device("cpu")
+
+        # Min/max values for the noise variance beta
+        self.beta_min = beta_min
+        self.beta_max = beta_max
+
+        # Min/max limits for the diffusion time
         self.t_min = t_min
         self.t_max = t_max
 
@@ -152,6 +166,7 @@ def loss(
             input: Input variable i.e. theta.
             condition: Conditioning variable.
             times: SDE time variable in [t_min, t_max]. Uniformly sampled if None.
+                In that case, they are drawn by the times_schedule method.
             control_variate: Whether to use a control variate to reduce the variance
                 of the stochastic loss estimator.
             control_variate_threshold: Threshold for the control variate. If the std
@@ -161,13 +176,12 @@
             MSE between target score and network output, scaled by the weight
             function.
         """
+        # update device if required
+        self.device = input.device if self.device != input.device else self.device
+
         # Sample diffusion times.
         if times is None:
-            times = (
-                torch.rand(input.shape[0], device=input.device)
-                * (self.t_max - self.t_min)
-                + self.t_min
-            )
+            times = self.times_schedule(input.shape[0])
 
         # Sample noise.
         eps = torch.randn_like(input)
@@ -312,6 +326,47 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
         """
         raise NotImplementedError
 
+    def noise_schedule(self, times: Tensor) -> Tensor:
+        """
+        Generate a beta schedule for mean scaling in variance-preserving
+        stochastic differential equations (SDEs).
+
+        This method acts as a fallback in case derived classes do not
+        implement it on their own. It calculates a linear beta schedule defined
+        by the input `times`, which represents the normalized time steps t ∈ [0, 1].
+
+        Args:
+            times (Tensor):
+                SDE times in [0, 1]. This tensor is typically produced by self.times_schedule.
+
+        Returns:
+            Tensor: Generated beta schedule at a given time.
+
+        """
+        return self.beta_min + (self.beta_max - self.beta_min) * times
+
+    def times_schedule(self,
+                       num_samples: int,
+                       t_min: float = None,
+                       t_max: float = None) -> Tensor:
+        """
+        Perform uniform sampling of time variables within the range [t_min, t_max].
+
+        Args:
+            num_samples (int): Number of samples to generate.
+            t_min (float, optional): The minimum time value. Defaults to self.t_min.
+            t_max (float, optional): The maximum time value. Defaults to self.t_max.
+
+        Returns:
+            Tensor: A tensor of sampled time variables scaled and shifted to the range [t_min, t_max].
+
+        TODO: is the tensor on device?
+        """
+        t_min = self.t_min if isinstance(t_min, type(None)) else t_min
+        t_max = self.t_max if isinstance(t_max, type(None)) else t_max
+
+        return torch.rand(num_samples, device=self.device) * (t_max - t_min) + t_min
+
     def _set_weight_fn(self, weight_fn: Union[str, Callable]):
         """Set the weight function.
 
@@ -355,8 +410,6 @@ def __init__(
         t_min: float = 1e-5,
         t_max: float = 1.0,
     ) -> None:
-        self.beta_min = beta_min
-        self.beta_max = beta_max
         super().__init__(
             net,
             input_shape,
@@ -364,6 +417,8 @@
             mean_0=mean_0,
             std_0=std_0,
             weight_fn=weight_fn,
+            beta_min=beta_min,
+            beta_max=beta_max,
             t_min=t_min,
             t_max=t_max,
         )
@@ -399,17 +454,6 @@ def std_fn(self, times: Tensor) -> Tensor:
             std = std.unsqueeze(-1)
         return torch.sqrt(std)
 
-    def _beta_schedule(self, times: Tensor) -> Tensor:
-        """Linear beta schedule for mean scaling in variance preserving SDEs.
-
-        Args:
-            times: SDE time variable in [0,1].
-
-        Returns:
-            Beta schedule at a given time.
-        """
-        return self.beta_min + (self.beta_max - self.beta_min) * times
-
     def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
         """Drift function for variance preserving SDEs.
 
         Args:
             input: Original data, x0.
             times: SDE time variable in [0,1].
 
         Returns:
             Drift function at a given time.
         """
-        phi = -0.5 * self._beta_schedule(times)
+        phi = -0.5 * self.noise_schedule(times)
         while len(phi.shape) < len(input.shape):
             phi = phi.unsqueeze(-1)
         return phi * input
@@ -435,12 +479,14 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
 
         Returns:
             Diffusion function at a given time.
         """
-        g = torch.sqrt(self._beta_schedule(times))
+        g = torch.sqrt(self.noise_schedule(times))
         while len(g.shape) < len(input.shape):
             g = g.unsqueeze(-1)
         return g
 
-
+#TODO: experiment with time schedule with more samples around .5 (check edm paper)
+#TODO: impacts on training and evaluate
+#TODO: check effect in mini sbibm -> converges faster (focus on more important)
 class SubVPScoreEstimator(ConditionalScoreEstimator):
     """Class for score estimators with sub-variance preserving SDEs."""
 
@@ -457,13 +503,13 @@ def __init__(
         t_min: float = 1e-2,
         t_max: float = 1.0,
     ) -> None:
-        self.beta_min = beta_min
-        self.beta_max = beta_max
         super().__init__(
             net,
             input_shape,
             condition_shape,
             weight_fn=weight_fn,
+            beta_min = beta_min,
+            beta_max = beta_max,
             mean_0=mean_0,
             std_0=std_0,
             t_min=t_min,
             t_max=t_max,
@@ -501,18 +547,6 @@ def std_fn(self, times: Tensor) -> Tensor:
             std = std.unsqueeze(-1)
         return std
 
-    def _beta_schedule(self, times: Tensor) -> Tensor:
-        """Linear beta schedule for mean scaling in sub-variance preserving SDEs.
-        (Same as for variance preserving SDEs.)
-
-        Args:
-            times: SDE time variable in [0,1].
-
-        Returns:
-            Beta schedule at a given time.
-        """
-        return self.beta_min + (self.beta_max - self.beta_min) * times
-
     def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
         """Drift function for sub-variance preserving SDEs.
 
@@ -523,7 +557,7 @@
         Returns:
             Drift function at a given time.
         """
-        phi = -0.5 * self._beta_schedule(times)
+        phi = -0.5 * self.noise_schedule(times)
 
         while len(phi.shape) < len(input.shape):
             phi = phi.unsqueeze(-1)
@@ -542,7 +576,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
         """
         g = torch.sqrt(
             torch.abs(
-                self._beta_schedule(times)
+                self.noise_schedule(times)
                 * (
                     1
                     - torch.exp(
@@ -612,17 +646,6 @@ def std_fn(self, times: Tensor) -> Tensor:
             std = std.unsqueeze(-1)
         return std
 
-    def _sigma_schedule(self, times: Tensor) -> Tensor:
-        """Geometric sigma schedule for variance exploding SDEs.
-
-        Args:
-            times: SDE time variable in [0,1].
-
-        Returns:
-            Sigma schedule at a given time.
-        """
-        return self.sigma_min * (self.sigma_max / self.sigma_min) ** times
-
     def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
         """Drift function for variance exploding SDEs.
@@ -645,7 +668,8 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
 
         Returns:
             Diffusion function at a given time.
         """
-        g = self._sigma_schedule(times) * math.sqrt(
+        sigmas = self.sigma_min * (self.sigma_max / self.sigma_min) ** times
+        g = sigmas * math.sqrt(
             (2 * math.log(self.sigma_max / self.sigma_min))
         )

diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 4a526e93c..7b6e09001 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -7,8 +7,9 @@
 
 import pytest
 import torch
-
 from sbi.neural_nets.embedding_nets import CNNEmbedding
+from sbi.neural_nets.estimators.score_estimator import (
+    ConditionalScoreEstimator, VPScoreEstimator)
 from sbi.neural_nets.net_builders import build_score_estimator
@@ -144,3 +145,23 @@ def _build_score_estimator_and_tensors(
     )
     condition = condition
     return score_estimator, inputs, condition
+
+
+def test_times_schedule():
+
+    id_net = torch.nn.Identity()
+    inpt_shape = (4,)
+    cond_shape = (4,)
+
+    with pytest.raises(NotImplementedError):
+        cse = ConditionalScoreEstimator(id_net, inpt_shape, cond_shape)
+
+    vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape)
+    exp = vpse.device
+    times = vpse.times_schedule(10)
+    obs = times.device
+
+    assert exp == obs
+    assert times.shape == torch.Size((10,))
+    assert times.max().item() < vpse.t_max
+    assert times.min().item() >= vpse.t_min

From b7d3be3d86d529c040926d377733632db8c2b86d Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Tue, 18 Mar 2025 16:25:26 +0100
Subject: [PATCH 02/18] added noise schedule test

---
 tests/score_estimator_test.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 7b6e09001..686e61046 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -165,3 +165,21 @@ def test_times_schedule():
     assert times.shape == torch.Size((10,))
     assert times.max().item() < vpse.t_max
     assert times.min().item() >= vpse.t_min
+
+
+def test_noise_schedule():
+
+    id_net = torch.nn.Identity()
+    inpt_shape = (4,)
+    cond_shape = (4,)
+
+    vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape)
+    exp = vpse.device
+    times = vpse.times_schedule(10)
+    noise = vpse.noise_schedule(times)
+    obs = noise.device
+
+    assert exp == obs
+    assert noise.shape == torch.Size((10,))
+    assert noise.max().item() < vpse.beta_max
+    assert noise.min().item() >= vpse.beta_min
From dca1939804c96c25a9b288e3058f48e0337cfff5 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Tue, 18 Mar 2025 18:09:51 +0100
Subject: [PATCH 03/18] implemented beta schedule for variance-preserving
 estimators

- added tests too
- inspired by https://arxiv.org/abs/2206.00364
---
 sbi/neural_nets/estimators/score_estimator.py | 69 +++++++++++++++++--
 tests/score_estimator_test.py                 | 12 +++-
 2 files changed, 73 insertions(+), 8 deletions(-)

diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index a941e19af..efc3691a8 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -2,11 +2,12 @@
 # under the Apache License Version 2.0, see
 
 import math
-from math import pi
+from math import exp, pi
 from typing import Callable, Optional, Union
 
 import torch
 from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
+from scipy import stats
 from torch import Tensor, nn
@@ -337,7 +338,8 @@ def noise_schedule(self, times: Tensor) -> Tensor:
 
         Args:
             times (Tensor):
-                SDE times in [0, 1]. This tensor is typically produced by self.times_schedule.
+                SDE times in [0, 1]. This tensor is typically produced by
+                self.times_schedule.
 
         Returns:
             Tensor: Generated beta schedule at a given time.
@@ -351,6 +353,7 @@ def times_schedule(self,
                        t_max: float = None) -> Tensor:
         """
         Perform uniform sampling of time variables within the range [t_min, t_max].
+        The `times` tensor will be put on the same device as the stored network.
 
         Args:
             num_samples (int): Number of samples to generate.
             t_min (float, optional): The minimum time value. Defaults to self.t_min.
             t_max (float, optional): The maximum time value. Defaults to self.t_max.
 
         Returns:
             Tensor: A tensor of sampled time variables scaled and shifted to the range [t_min, t_max].
 
-        TODO: is the tensor on device?
         """
         t_min = self.t_min if isinstance(t_min, type(None)) else t_min
         t_max = self.t_max if isinstance(t_max, type(None)) else t_max
@@ -394,6 +396,9 @@ def _set_weight_fn(self, weight_fn: Union[str, Callable]):
         else:
             raise ValueError(f"Weight function {weight_fn} not recognized.")
 
+#TODO: experiment with time schedule with more samples around .5 (check edm paper)
+#TODO: impacts on training and evaluate
+#TODO: check effect in mini sbibm -> converges faster (focus on more important)
 class VPScoreEstimator(ConditionalScoreEstimator):
     """Class for score estimators with variance preserving SDEs (i.e., DDPM)."""
 
@@ -409,7 +414,15 @@ def __init__(
         std_0: Union[Tensor, float] = 1.0,
         t_min: float = 1e-5,
         t_max: float = 1.0,
+        pmean: float = 1.2,
+        pstd: float = -1.2
     ) -> None:
+
+        self.pmean, self.pstd = pmean, pstd
+        noise_dist = stats.norm(pmean, pstd**2)
+        self.beta_min = exp(noise_dist.ppf(0.01))
+        self.beta_max = exp(noise_dist.ppf(0.99))
+
         super().__init__(
             net,
             input_shape,
@@ -484,9 +497,53 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
             g = g.unsqueeze(-1)
         return g
 
+    def noise_schedule(self, times: Tensor) -> Tensor:
+        """
+        Generate a beta schedule similar to suggestions in the EDM [1] paper.
+
+        This overrides the linear fallback of the base class: instead of
+        evaluating a deterministic schedule at `times`, it draws log-normal
+        noise levels, i.e. exp(N(pmean, pstd^2)).
+
+        Args:
+            times (Tensor):
+                SDE times in [0, 1]; only their shape and device are used here.
+
+        Returns:
+            Tensor: Generated beta schedule at a given time.
+
+        [1] Karras et al "Elucidating the Design Space of Diffusion-Based
+        Generative Models", https://arxiv.org/abs/2206.00364
+        """
+
+        samples = torch.randn_like(times)*(self.pstd**2) + self.pmean
+        return torch.exp(samples)
+
+    def times_schedule(self,
+                       num_samples: int,
+                       t_min: float = None,
+                       t_max: float = None) -> Tensor:
+        """
+        Perform normal sampling around the middle of the interval [t_min, t_max].
+
+        Args:
+            num_samples (int): Number of samples to generate.
+            t_min (float, optional): The minimum time value. Defaults to self.t_min.
+            t_max (float, optional): The maximum time value. Defaults to self.t_max.
+
+        Returns:
+            Tensor: A tensor of time variables sampled around the middle of the range [t_min, t_max].
+
+        """
+        t_min = self.t_min if isinstance(t_min, type(None)) else t_min
+        t_max = self.t_max if isinstance(t_max, type(None)) else t_max
+        t_mu = t_min + (t_max - t_min)/2.
+        t_std = (t_max - t_min)/8.
+
+        # apply scale and loc to normal distribution
+        return torch.randn(num_samples, device=self.device) * t_std + t_mu
+
 class SubVPScoreEstimator(ConditionalScoreEstimator):
     """Class for score estimators with sub-variance preserving SDEs."""

diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 686e61046..988b94910 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -11,6 +11,7 @@
 from sbi.neural_nets.estimators.score_estimator import (
     ConditionalScoreEstimator, VPScoreEstimator)
 from sbi.neural_nets.net_builders import build_score_estimator
+from scipy import stats
 
 
 @pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"])
@@ -161,10 +162,17 @@ def test_times_schedule():
     times = vpse.times_schedule(10)
     obs = times.device
 
+    delta = vpse.t_max - vpse.t_min
+    t_mu = vpse.t_min + delta/2
+    t_std = delta/8.
+
+    ndist = stats.norm(t_mu, t_std)
+    lo,hi = ndist.ppf(.01), ndist.ppf(.99)
+
     assert exp == obs
     assert times.shape == torch.Size((10,))
-    assert times.max().item() < vpse.t_max
-    assert times.min().item() >= vpse.t_min
+    assert times.max().item() <= hi
+    assert times.min().item() >= lo
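
The log-normal noise distribution that PATCH 03 borrows from the EDM paper can
be checked in isolation — a sketch assuming only torch and scipy (the pmean/pstd
signs follow the later correction in PATCH 08; note that scipy's second
positional argument to stats.norm is a standard deviation):

    import math

    import torch
    from scipy import stats

    pmean, pstd = -1.2, 1.2

    # training noise levels: sigma = exp(n) with n ~ N(pmean, pstd^2)
    sigmas = torch.exp(torch.randn(10_000) * pstd + pmean)

    # the 1%/99% bounds the patch derives via the normal ppf
    noise_dist = stats.norm(pmean, pstd)
    lo, hi = math.exp(noise_dist.ppf(0.01)), math.exp(noise_dist.ppf(0.99))

    # by construction, ~98% of the sampled noise levels fall inside [lo, hi]
    assert ((sigmas > lo) & (sigmas < hi)).float().mean() > 0.95
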
""" t_min = self.t_min if isinstance(t_min, type(None)) else t_min diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py index 988b94910..3533d064e 100644 --- a/tests/score_estimator_test.py +++ b/tests/score_estimator_test.py @@ -7,11 +7,14 @@ import pytest import torch +from scipy import stats + from sbi.neural_nets.embedding_nets import CNNEmbedding from sbi.neural_nets.estimators.score_estimator import ( - ConditionalScoreEstimator, VPScoreEstimator) + ConditionalScoreEstimator, + VPScoreEstimator, +) from sbi.neural_nets.net_builders import build_score_estimator -from scipy import stats @pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"]) @@ -155,7 +158,7 @@ def test_times_schedule(): cond_shape = (4,) with pytest.raises(NotImplementedError): - cse = ConditionalScoreEstimator(id_net, inpt_shape, cond_shape) + ConditionalScoreEstimator(id_net, inpt_shape, cond_shape) vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape) exp = vpse.device From 3e963a9572948729a8dff888e58834bbb143dca6 Mon Sep 17 00:00:00 2001 From: Peter Steinbach Date: Wed, 19 Mar 2025 10:41:09 +0100 Subject: [PATCH 05/18] cloned VPScoreEstimator to yield improved version in addition: - added improved version to benchmarks (for later comparison) - created new class ImprovedVPScoreEstimator --- sbi/neural_nets/estimators/score_estimator.py | 106 +++++++++++++++++- sbi/neural_nets/net_builders/score_nets.py | 18 +-- tests/bm_test.py | 3 +- 3 files changed, 113 insertions(+), 14 deletions(-) diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py index 615a0b073..cac19c1a7 100644 --- a/sbi/neural_nets/estimators/score_estimator.py +++ b/sbi/neural_nets/estimators/score_estimator.py @@ -6,11 +6,10 @@ from typing import Callable, Optional, Union import torch +from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator from scipy import stats from torch import Tensor, nn -from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator - class ConditionalScoreEstimator(ConditionalVectorFieldEstimator): r"""Score matching for score-based generative models (e.g., denoising diffusion). 
@@ -397,7 +396,6 @@ def _set_weight_fn(self, weight_fn: Union[str, Callable]): else: raise ValueError(f"Weight function {weight_fn} not recognized.") - #TODO: experiment with time schedule with more samples around .5 (check edm paper) #TODO: impacts on training and evaluate #TODO: check effect in mini sbibm -> converges faster (focus on more important) @@ -499,6 +497,108 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: g = g.unsqueeze(-1) return g + +#TODO: experiment with time schedule with more samples around .5 (check edm paper) +#TODO: impacts on training and evaluate +#TODO: check effect in mini sbibm -> converges faster (focus on more important) +class ImprovedVPScoreEstimator(ConditionalScoreEstimator): + """Class for score estimators with variance preserving SDEs (i.e., DDPM).""" + + def __init__( + self, + net: nn.Module, + input_shape: torch.Size, + condition_shape: torch.Size, + weight_fn: Union[str, Callable] = "max_likelihood", + beta_min: float = 0.01, + beta_max: float = 10.0, + mean_0: Union[Tensor, float] = 0.0, + std_0: Union[Tensor, float] = 1.0, + t_min: float = 1e-5, + t_max: float = 1.0, + pmean: float = 1.2, + pstd: float = -1.2 + ) -> None: + + self.pmean, self.pstd = pmean, pstd + noise_dist = stats.norm(pmean, pstd**2) + self.beta_min = exp(noise_dist.ppf(0.01)) + self.beta_max = exp(noise_dist.ppf(0.99)) + + super().__init__( + net, + input_shape, + condition_shape, + mean_0=mean_0, + std_0=std_0, + weight_fn=weight_fn, + beta_min=beta_min, + beta_max=beta_max, + t_min=t_min, + t_max=t_max, + ) + + def mean_t_fn(self, times: Tensor) -> Tensor: + """Conditional mean function for variance preserving SDEs. + Args: + times: SDE time variable in [0,1]. + + Returns: + Conditional mean at a given time. + """ + phi = torch.exp( + -0.25 * times**2.0 * (self.beta_max - self.beta_min) + - 0.5 * times * self.beta_min + ) + for _ in range(len(self.input_shape)): + phi = phi.unsqueeze(-1) + return phi + + def std_fn(self, times: Tensor) -> Tensor: + """Standard deviation function for variance preserving SDEs. + Args: + times: SDE time variable in [0,1]. + + Returns: + Standard deviation at a given time. + """ + std = 1.0 - torch.exp( + -0.5 * times**2.0 * (self.beta_max - self.beta_min) - times * self.beta_min + ) + for _ in range(len(self.input_shape)): + std = std.unsqueeze(-1) + return torch.sqrt(std) + + def drift_fn(self, input: Tensor, times: Tensor) -> Tensor: + """Drift function for variance preserving SDEs. + + Args: + input: Original data, x0. + times: SDE time variable in [0,1]. + + Returns: + Drift function at a given time. + """ + phi = -0.5 * self.noise_schedule(times) + while len(phi.shape) < len(input.shape): + phi = phi.unsqueeze(-1) + return phi * input + + def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: + """Diffusion function for variance preserving SDEs. + + Args: + input: Original data, x0. + times: SDE time variable in [0,1]. + + Returns: + Drift function at a given time. + """ + g = torch.sqrt(self.noise_schedule(times)) + while len(g.shape) < len(input.shape): + g = g.unsqueeze(-1) + return g + def noise_schedule(self, times: Tensor) -> Tensor: """ Generate a beta schedule similar to suggestions in the EDM [1] paper. 
diff --git a/sbi/neural_nets/net_builders/score_nets.py b/sbi/neural_nets/net_builders/score_nets.py index 274c0b2c7..ba40ec940 100644 --- a/sbi/neural_nets/net_builders/score_nets.py +++ b/sbi/neural_nets/net_builders/score_nets.py @@ -2,17 +2,14 @@ import torch import torch.nn as nn -from torch import Tensor - from sbi.neural_nets.estimators.score_estimator import ( - ConditionalScoreEstimator, - GaussianFourierTimeEmbedding, - SubVPScoreEstimator, - VEScoreEstimator, - VPScoreEstimator, -) -from sbi.utils.sbiutils import standardizing_net, z_score_parser, z_standardization + ConditionalScoreEstimator, GaussianFourierTimeEmbedding, + ImprovedVPScoreEstimator, SubVPScoreEstimator, VEScoreEstimator, + VPScoreEstimator) +from sbi.utils.sbiutils import (standardizing_net, z_score_parser, + z_standardization) from sbi.utils.user_input_checks import check_data_device +from torch import Tensor class EmbedInputs(nn.Module): @@ -115,6 +112,7 @@ def build_score_estimator( batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring. sde_type: SDE type used, which defines the mean and std functions. One of: - 'vp': Variance preserving. + - 'vp++': Variance preserving. - 'subvp': Sub-variance preserving. - 've': Variance exploding. Defaults to 'vp'. @@ -194,6 +192,8 @@ def build_score_estimator( estimator = VEScoreEstimator elif sde_type == "subvp": estimator = SubVPScoreEstimator + elif sde_type == "vp++": + estimator = ImprovedVPScoreEstimator else: raise ValueError(f"SDE type: {sde_type} not supported.") diff --git a/tests/bm_test.py b/tests/bm_test.py index e38fd9659..c61a6fef6 100644 --- a/tests/bm_test.py +++ b/tests/bm_test.py @@ -4,7 +4,6 @@ import pytest import torch from pytest_harvest import ResultsBag - from sbi.inference import FMPE, NLE, NPE, NPSE, NRE from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.trainers.npe import NPE_C @@ -50,7 +49,7 @@ "npse": [ {"score_estimator": nn, "sde_type": sde} for nn in SCORE_ESTIMATORS - for sde in ["ve", "vp"] + for sde in ["ve", "vp", "vp++"] ], "snpe": [{}], "snle": [{}], From 4b625101e196dc29d07486dfb66416a05516efde Mon Sep 17 00:00:00 2001 From: Peter Steinbach Date: Wed, 19 Mar 2025 11:57:10 +0100 Subject: [PATCH 06/18] more realistic bounds for unit test --- tests/score_estimator_test.py | 38 +++++++++++++++++------------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py index 3533d064e..ccbefce9d 100644 --- a/tests/score_estimator_test.py +++ b/tests/score_estimator_test.py @@ -7,17 +7,14 @@ import pytest import torch -from scipy import stats - from sbi.neural_nets.embedding_nets import CNNEmbedding from sbi.neural_nets.estimators.score_estimator import ( - ConditionalScoreEstimator, - VPScoreEstimator, -) + ConditionalScoreEstimator, ImprovedVPScoreEstimator, VPScoreEstimator) from sbi.neural_nets.net_builders import build_score_estimator +from scipy import stats -@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"]) +@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"]) @pytest.mark.parametrize("input_sample_dim", (1, 2)) @pytest.mark.parametrize("input_event_shape", ((1,), (4,))) @pytest.mark.parametrize("condition_event_shape", ((1,), (7,))) @@ -46,7 +43,7 @@ def test_score_estimator_loss_shapes( @pytest.mark.gpu -@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"]) +@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"]) @pytest.mark.parametrize("device", ["cpu", 
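
With the builder extended, the improved estimator becomes selectable via the
new 'vp++' key — a hedged usage sketch (it assumes the builder's first two
positional arguments are batch_x and batch_y, as its docstring suggests):

    import torch
    from sbi.neural_nets.net_builders import build_score_estimator

    theta = torch.randn(100, 4)  # parameters, playing the role of batch_x
    x = torch.randn(100, 7)      # simulations, playing the role of batch_y

    estimator = build_score_estimator(theta, x, sde_type="vp++")
    loss = estimator.loss(theta, condition=x)  # one loss value per batch element
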
"cuda"]) def test_score_estimator_on_device(sde_type, device): """Test whether DensityEstimators can be moved to the device.""" @@ -68,7 +65,7 @@ def test_score_estimator_on_device(sde_type, device): assert str(loss.device).split(":")[0] == device, "Loss device mismatch." -@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"]) +@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"]) @pytest.mark.parametrize("input_sample_dim", (1, 2)) @pytest.mark.parametrize("input_event_shape", ((1,), (4,))) @pytest.mark.parametrize("condition_event_shape", ((1,), (7,))) @@ -160,13 +157,13 @@ def test_times_schedule(): with pytest.raises(NotImplementedError): ConditionalScoreEstimator(id_net, inpt_shape, cond_shape) - vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape) - exp = vpse.device - times = vpse.times_schedule(10) + ivpse = ImprovedVPScoreEstimator(id_net, inpt_shape, cond_shape) + exp = ivpse.device + times = ivpse.times_schedule(10) obs = times.device - delta = vpse.t_max - vpse.t_min - t_mu = vpse.t_min + delta/2 + delta = ivpse.t_max - ivpse.t_min + t_mu = ivpse.t_min + delta/2. t_std = delta/8. ndist = stats.norm(t_mu, t_std) @@ -177,6 +174,9 @@ def test_times_schedule(): assert times.max().item() <= hi assert times.min().item() >= lo + assert times.max().item() < ivpse.t_max + assert times.min().item() > ivpse.t_min + def test_noise_schedule(): @@ -184,13 +184,13 @@ def test_noise_schedule(): inpt_shape = (4,) cond_shape = (4,) - vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape) - exp = vpse.device - times = vpse.times_schedule(10) - noise = vpse.noise_schedule(times) + ivpse = ImprovedVPScoreEstimator(id_net, inpt_shape, cond_shape) + exp = ivpse.device + times = ivpse.times_schedule(10) + noise = ivpse.noise_schedule(times) obs = noise.device assert exp == obs assert noise.shape == torch.Size((10,)) - assert noise.max().item() < vpse.beta_max - assert noise.min().item() >= vpse.beta_min + assert noise.max().item() < ivpse.beta_max + assert noise.min().item() >= ivpse.beta_min From 274212ee857267aabb5ea70695b8784077cd589a Mon Sep 17 00:00:00 2001 From: Peter Steinbach Date: Wed, 19 Mar 2025 11:57:24 +0100 Subject: [PATCH 07/18] typo and refactoring to understand how the VE estimator is implemented --- sbi/neural_nets/estimators/score_estimator.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py index cac19c1a7..03c39de11 100644 --- a/sbi/neural_nets/estimators/score_estimator.py +++ b/sbi/neural_nets/estimators/score_estimator.py @@ -180,7 +180,7 @@ def loss( # update device if required self.device = input.device if self.device != input.device else self.device - # Sample diffusion times. + # Sample times from the Markov chain if times is None: times = self.times_schedule(input.shape[0]) @@ -619,7 +619,7 @@ def noise_schedule(self, times: Tensor) -> Tensor: Generative Models", https://arxiv.org/abs/2206.00364 """ - samples = torch.randn_like(times)*(self.pstd**2) + self.pmean + samples = torch.randn_like(times)*(self.pstd) + self.pmean return torch.exp(samples) def times_schedule(self, @@ -828,9 +828,10 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: Returns: Diffusion function at a given time. 
""" - sigmas = self.sigma_min * (self.sigma_max / self.sigma_min) ** times + sigma_scale = self.sigma_max / self.sigma_min + sigmas = self.sigma_min * (sigma_scale) ** times g = sigmas * math.sqrt( - (2 * math.log(self.sigma_max / self.sigma_min)) + (2 * math.log(sigma_scale)) ) while len(g.shape) < len(input.shape): From 74065ae935a1914f60198dda01373ed3599294a9 Mon Sep 17 00:00:00 2001 From: Peter Steinbach Date: Wed, 19 Mar 2025 12:06:21 +0100 Subject: [PATCH 08/18] fixed wrong setup of pmean and pstd --- sbi/neural_nets/estimators/score_estimator.py | 4 ++-- tests/score_estimator_test.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py index 03c39de11..ea69a95c4 100644 --- a/sbi/neural_nets/estimators/score_estimator.py +++ b/sbi/neural_nets/estimators/score_estimator.py @@ -516,8 +516,8 @@ def __init__( std_0: Union[Tensor, float] = 1.0, t_min: float = 1e-5, t_max: float = 1.0, - pmean: float = 1.2, - pstd: float = -1.2 + pmean: float = -1.2, + pstd: float = 1.2 ) -> None: self.pmean, self.pstd = pmean, pstd diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py index ccbefce9d..6f55233be 100644 --- a/tests/score_estimator_test.py +++ b/tests/score_estimator_test.py @@ -166,14 +166,14 @@ def test_times_schedule(): t_mu = ivpse.t_min + delta/2. t_std = delta/8. + assert exp == obs + assert times.shape == torch.Size((10,)) + ndist = stats.norm(t_mu, t_std) lo,hi = ndist.ppf(.01), ndist.ppf(.99) - assert exp == obs - assert times.shape == torch.Size((10,)) assert times.max().item() <= hi assert times.min().item() >= lo - assert times.max().item() < ivpse.t_max assert times.min().item() > ivpse.t_min @@ -192,5 +192,9 @@ def test_noise_schedule(): assert exp == obs assert noise.shape == torch.Size((10,)) - assert noise.max().item() < ivpse.beta_max - assert noise.min().item() >= ivpse.beta_min + + ndist = stats.norm(ivpse.pmean, ivpse.pstd) + lo,hi = ndist.ppf(.01), ndist.ppf(.99) + + assert noise.max().item() < hi + assert noise.min().item() > lo From d61e198d36aa4ad97c7c21c5ec9b66203ffa3b39 Mon Sep 17 00:00:00 2001 From: Peter Steinbach Date: Wed, 19 Mar 2025 15:42:58 +0100 Subject: [PATCH 09/18] code reformatting --- sbi/neural_nets/estimators/score_estimator.py | 54 +++++++++---------- sbi/neural_nets/net_builders/score_nets.py | 13 +++-- tests/score_estimator_test.py | 15 +++--- 3 files changed, 41 insertions(+), 41 deletions(-) diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py index ea69a95c4..99803c04f 100644 --- a/sbi/neural_nets/estimators/score_estimator.py +++ b/sbi/neural_nets/estimators/score_estimator.py @@ -47,7 +47,7 @@ def __init__( mean_0: Union[Tensor, float] = 0.0, std_0: Union[Tensor, float] = 1.0, t_min: float = 1e-3, - t_max: float = 1.0 + t_max: float = 1.0, ) -> None: r"""Score estimator class that estimates the conditional score function, i.e., gradient of the density p(xt|x0). @@ -347,10 +347,9 @@ def noise_schedule(self, times: Tensor) -> Tensor: """ return self.beta_min + (self.beta_max - self.beta_min) * times - def times_schedule(self, - num_samples: int, - t_min: float = None, - t_max: float = None) -> Tensor: + def times_schedule( + self, num_samples: int, t_min: float = None, t_max: float = None + ) -> Tensor: """ Perform uniform sampling of time variables within the range [t_min, t_max]. 
From d61e198d36aa4ad97c7c21c5ec9b66203ffa3b39 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Wed, 19 Mar 2025 15:42:58 +0100
Subject: [PATCH 09/18] code reformatting

---
 sbi/neural_nets/estimators/score_estimator.py | 54 +++++++++----------
 sbi/neural_nets/net_builders/score_nets.py    | 13 +++--
 tests/score_estimator_test.py                 | 15 +++---
 3 files changed, 41 insertions(+), 41 deletions(-)

diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index ea69a95c4..99803c04f 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -47,7 +47,7 @@ def __init__(
         mean_0: Union[Tensor, float] = 0.0,
         std_0: Union[Tensor, float] = 1.0,
         t_min: float = 1e-3,
-        t_max: float = 1.0
+        t_max: float = 1.0,
     ) -> None:
         r"""Score estimator class that estimates the conditional score function,
         i.e., gradient of the density p(xt|x0).
@@ -347,10 +347,9 @@ def noise_schedule(self, times: Tensor) -> Tensor:
         """
         return self.beta_min + (self.beta_max - self.beta_min) * times
 
-    def times_schedule(self,
-                       num_samples: int,
-                       t_min: float = None,
-                       t_max: float = None) -> Tensor:
+    def times_schedule(
+        self, num_samples: int, t_min: float = None, t_max: float = None
+    ) -> Tensor:
         """
         Perform uniform sampling of time variables within the range [t_min, t_max].
         The `times` tensor will be put on the same device as the stored network.
@@ -396,9 +395,10 @@ def _set_weight_fn(self, weight_fn: Union[str, Callable]):
         else:
             raise ValueError(f"Weight function {weight_fn} not recognized.")
 
-#TODO: experiment with time schedule with more samples around .5 (check edm paper)
-#TODO: impacts on training and evaluate
-#TODO: check effect in mini sbibm -> converges faster (focus on more important)
+
+# TODO: experiment with time schedule with more samples around .5 (check edm paper)
+# TODO: impacts on training and evaluate
+# TODO: check effect in mini sbibm -> converges faster (focus on more important)
 class VPScoreEstimator(ConditionalScoreEstimator):
     """Class for score estimators with variance preserving SDEs (i.e., DDPM)."""
 
@@ -414,10 +414,9 @@ def __init__(
         std_0: Union[Tensor, float] = 1.0,
         t_min: float = 1e-5,
         t_max: float = 1.0,
-        pmean: float = 1.2,
-        pstd: float = -1.2
+        pmean: float = 1.2,
+        pstd: float = -1.2,
     ) -> None:
         self.pmean, self.pstd = pmean, pstd
         noise_dist = stats.norm(pmean, pstd**2)
         self.beta_min = exp(noise_dist.ppf(0.01))
         self.beta_max = exp(noise_dist.ppf(0.99))
@@ -497,9 +497,9 @@
         return g
 
 
-#TODO: experiment with time schedule with more samples around .5 (check edm paper)
-#TODO: impacts on training and evaluate
-#TODO: check effect in mini sbibm -> converges faster (focus on more important)
+# TODO: experiment with time schedule with more samples around .5 (check edm paper)
+# TODO: impacts on training and evaluate
+# TODO: check effect in mini sbibm -> converges faster (focus on more important)
 class ImprovedVPScoreEstimator(ConditionalScoreEstimator):
     """Class for score estimators with variance preserving SDEs (i.e., DDPM)."""
 
@@ -515,10 +515,9 @@ def __init__(
         std_0: Union[Tensor, float] = 1.0,
         t_min: float = 1e-5,
         t_max: float = 1.0,
-        pmean: float = -1.2,
-        pstd: float = 1.2
+        pmean: float = -1.2,
+        pstd: float = 1.2,
     ) -> None:
         self.pmean, self.pstd = pmean, pstd
         noise_dist = stats.norm(pmean, pstd**2)
         self.beta_min = exp(noise_dist.ppf(0.01))
         self.beta_max = exp(noise_dist.ppf(0.99))
@@ -617,13 +617,12 @@ def noise_schedule(self, times: Tensor) -> Tensor:
             Generative Models", https://arxiv.org/abs/2206.00364
         """
 
-        samples = torch.randn_like(times)*(self.pstd) + self.pmean
+        samples = torch.randn_like(times) * (self.pstd) + self.pmean
         return torch.exp(samples)
 
-    def times_schedule(self,
-                       num_samples: int,
-                       t_min: float = None,
-                       t_max: float = None) -> Tensor:
+    def times_schedule(
+        self, num_samples: int, t_min: float = None, t_max: float = None
+    ) -> Tensor:
         """
         Perform normal sampling around the middle of the interval [t_min, t_max].
@@ -641,12 +638,13 @@
         """
         t_min = self.t_min if isinstance(t_min, type(None)) else t_min
         t_max = self.t_max if isinstance(t_max, type(None)) else t_max
-        t_mu = t_min + (t_max - t_min)/2.
-        t_std = (t_max - t_min)/8.
+        t_mu = t_min + (t_max - t_min) / 2.0
+        t_std = (t_max - t_min) / 8.0
 
         # apply scale and loc to normal distribution
         return torch.randn(num_samples, device=self.device) * t_std + t_mu
 
+
 class SubVPScoreEstimator(ConditionalScoreEstimator):
     """Class for score estimators with sub-variance preserving SDEs."""
 
@@ -668,8 +666,8 @@ def __init__(
             input_shape,
             condition_shape,
             weight_fn=weight_fn,
-            beta_min = beta_min,
-            beta_max = beta_max,
+            beta_min=beta_min,
+            beta_max=beta_max,
             mean_0=mean_0,
             std_0=std_0,
             t_min=t_min,
@@ -830,9 +828,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
         """
         sigma_scale = self.sigma_max / self.sigma_min
         sigmas = self.sigma_min * (sigma_scale) ** times
-        g = sigmas * math.sqrt(
-            (2 * math.log(sigma_scale))
-        )
+        g = sigmas * math.sqrt((2 * math.log(sigma_scale)))
 
         while len(g.shape) < len(input.shape):
             g = g.unsqueeze(-1)

diff --git a/sbi/neural_nets/net_builders/score_nets.py b/sbi/neural_nets/net_builders/score_nets.py
index ba40ec940..378eacfb1 100644
--- a/sbi/neural_nets/net_builders/score_nets.py
+++ b/sbi/neural_nets/net_builders/score_nets.py
@@ -2,17 +2,14 @@
 
 import torch
 import torch.nn as nn
+from torch import Tensor
+
 from sbi.neural_nets.estimators.score_estimator import (
-    ConditionalScoreEstimator, GaussianFourierTimeEmbedding,
-    ImprovedVPScoreEstimator, SubVPScoreEstimator, VEScoreEstimator,
-    VPScoreEstimator)
-from sbi.utils.sbiutils import (standardizing_net, z_score_parser,
-                                z_standardization)
+    ConditionalScoreEstimator,
+    GaussianFourierTimeEmbedding,
+    ImprovedVPScoreEstimator,
+    SubVPScoreEstimator,
+    VEScoreEstimator,
+    VPScoreEstimator,
+)
+from sbi.utils.sbiutils import standardizing_net, z_score_parser, z_standardization
 from sbi.utils.user_input_checks import check_data_device
-from torch import Tensor

diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 6f55233be..987573614 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -9,7 +9,10 @@
 import torch
 from sbi.neural_nets.embedding_nets import CNNEmbedding
 from sbi.neural_nets.estimators.score_estimator import (
-    ConditionalScoreEstimator, ImprovedVPScoreEstimator, VPScoreEstimator)
+    ConditionalScoreEstimator,
+    ImprovedVPScoreEstimator,
+    VPScoreEstimator,
+)
 from sbi.neural_nets.net_builders import build_score_estimator
 from scipy import stats
 
@@ -149,7 +152,6 @@
 
 def test_times_schedule():
-
     id_net = torch.nn.Identity()
     inpt_shape = (4,)
     cond_shape = (4,)
@@ -165,14 +165,14 @@
     obs = times.device
 
     delta = ivpse.t_max - ivpse.t_min
-    t_mu = ivpse.t_min + delta/2.
-    t_std = delta/8.
+    t_mu = ivpse.t_min + delta / 2.0
+    t_std = delta / 8.0
 
     assert exp == obs
     assert times.shape == torch.Size((10,))
 
     ndist = stats.norm(t_mu, t_std)
-    lo,hi = ndist.ppf(.01), ndist.ppf(.99)
+    lo, hi = ndist.ppf(0.01), ndist.ppf(0.99)
 
     assert times.max().item() <= hi
     assert times.min().item() >= lo
@@ -181,7 +181,6 @@
 
 def test_noise_schedule():
-
     id_net = torch.nn.Identity()
     inpt_shape = (4,)
     cond_shape = (4,)
@@ -194,7 +193,7 @@
     assert noise.shape == torch.Size((10,))
 
     ndist = stats.norm(ivpse.pmean, ivpse.pstd)
-    lo,hi = ndist.ppf(.01), ndist.ppf(.99)
+    lo, hi = ndist.ppf(0.01), ndist.ppf(0.99)
 
     assert noise.max().item() < hi
     assert noise.min().item() > lo

From d2c0100bc45e89a87c11a3c0009306cd6d8a9b69 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 16:53:01 +0100
Subject: [PATCH 10/18] use the time schedule for computing the validation
 scores

---
 sbi/inference/trainers/npse/npse.py | 26 ++++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/sbi/inference/trainers/npse/npse.py b/sbi/inference/trainers/npse/npse.py
index c8b647e2e..ee5a236f4 100644
--- a/sbi/inference/trainers/npse/npse.py
+++ b/sbi/inference/trainers/npse/npse.py
@@ -5,17 +5,9 @@
 from typing import Any, Callable, Optional, Union
 
 import torch
-from torch import Tensor, ones
-from torch.distributions import Distribution
-from torch.nn.utils.clip_grad import clip_grad_norm_
-from torch.optim.adam import Adam
-from torch.utils.tensorboard.writer import SummaryWriter
-
 from sbi import utils as utils
 from sbi.inference import NeuralInference
-from sbi.inference.posteriors import (
-    DirectPosterior,
-)
+from sbi.inference.posteriors import DirectPosterior
 from sbi.inference.posteriors.score_posterior import ScorePosterior
 from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
 from sbi.neural_nets.factory import posterior_score_nn
@@ -29,6 +21,11 @@
 )
 from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior
 from sbi.utils.torchutils import assert_all_finite
+from torch import Tensor, ones
+from torch.distributions import Distribution
+from torch.nn.utils.clip_grad import clip_grad_norm_
+from torch.optim.adam import Adam
+from torch.utils.tensorboard.writer import SummaryWriter
 
 
 class NPSE(NeuralInference):
@@ -304,9 +301,14 @@ def default_calibration_kernel(x):
         self._neural_net.to(self._device)
 
         if isinstance(validation_times, int):
-            validation_times = torch.linspace(
-                self._neural_net.t_min, self._neural_net.t_max, validation_times
-            )
+            if hasattr(self._neural_net, "times_schedule"):
+                validation_times = self._neural_net.times_schedule(validation_times)
+            else:
+                validation_times = torch.linspace(
+                    self._neural_net.t_min,
+                    self._neural_net.t_max,
+                    steps=validation_times,
+                )
         assert isinstance(
             validation_times, Tensor
         )  # let pyright know validation_times is a Tensor.
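
The design choice behind PATCH 10: validation losses are now evaluated at times
drawn from the same schedule the estimator trains on, instead of a fixed grid.
A sketch of the difference for the uniform fallback (illustrative values only):

    import torch

    t_min, t_max, steps = 1e-5, 1.0, 5

    grid = torch.linspace(t_min, t_max, steps=steps)       # old: deterministic grid
    sampled = torch.rand(steps) * (t_max - t_min) + t_min  # new: resampled draws

Estimators with a non-uniform times_schedule (such as the improved variant) are
therefore validated at the times where they are actually trained.
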
From 5b506731e020026359efd1d0aef3a3d266896d80 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 16:53:44 +0100
Subject: [PATCH 11/18] propagate name change

---
 sbi/neural_nets/net_builders/score_nets.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/sbi/neural_nets/net_builders/score_nets.py b/sbi/neural_nets/net_builders/score_nets.py
index 378eacfb1..e33499ab9 100644
--- a/sbi/neural_nets/net_builders/score_nets.py
+++ b/sbi/neural_nets/net_builders/score_nets.py
@@ -3,14 +3,11 @@
 import torch
 import torch.nn as nn
 from sbi.neural_nets.estimators.score_estimator import (
-    ConditionalScoreEstimator,
-    GaussianFourierTimeEmbedding,
-    ImprovedVPScoreEstimator,
-    SubVPScoreEstimator,
-    VEScoreEstimator,
-    VPScoreEstimator,
-)
-from sbi.utils.sbiutils import standardizing_net, z_score_parser, z_standardization
+    ConditionalScoreEstimator, GaussianFourierTimeEmbedding,
+    ImprovedScoreEstimator, SubVPScoreEstimator, VEScoreEstimator,
+    VPScoreEstimator)
+from sbi.utils.sbiutils import (standardizing_net, z_score_parser,
+                                z_standardization)
 from sbi.utils.user_input_checks import check_data_device
 from torch import Tensor
 
@@ -196,7 +193,7 @@
     elif sde_type == "subvp":
         estimator = SubVPScoreEstimator
     elif sde_type == "vp++":
-        estimator = ImprovedVPScoreEstimator
+        estimator = ImprovedScoreEstimator
     else:
         raise ValueError(f"SDE type: {sde_type} not supported.")

From 2081ae01f42a96467f8c95fc5641ee17026fa0e1 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 16:53:56 +0100
Subject: [PATCH 12/18] fix unit tests to respect new schedules

---
 tests/score_estimator_test.py | 30 ++++++++----------------------
 1 file changed, 8 insertions(+), 22 deletions(-)

diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 987573614..7a698d66c 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -9,10 +9,7 @@
 import torch
 from sbi.neural_nets.embedding_nets import CNNEmbedding
 from sbi.neural_nets.estimators.score_estimator import (
-    ConditionalScoreEstimator,
-    ImprovedVPScoreEstimator,
-    VPScoreEstimator,
-)
+    ConditionalScoreEstimator, ImprovedScoreEstimator, VPScoreEstimator)
 from sbi.neural_nets.net_builders import build_score_estimator
 from scipy import stats
 
@@ -159,25 +156,19 @@ def test_times_schedule():
     with pytest.raises(NotImplementedError):
         ConditionalScoreEstimator(id_net, inpt_shape, cond_shape)
 
-    ivpse = ImprovedVPScoreEstimator(id_net, inpt_shape, cond_shape)
+    ivpse = ImprovedScoreEstimator(id_net, inpt_shape, cond_shape)
     exp = ivpse.device
     times = ivpse.times_schedule(10)
     obs = times.device
 
-    delta = ivpse.t_max - ivpse.t_min
-    t_mu = ivpse.t_min + delta / 2.0
-    t_std = delta / 8.0
-
     assert exp == obs
     assert times.shape == torch.Size((10,))
 
-    ndist = stats.norm(t_mu, t_std)
-    lo, hi = ndist.ppf(0.01), ndist.ppf(0.99)
+    assert times[0 ,...] != ivpse.t_min
+    assert times[-1,...] != ivpse.t_max
 
-    assert times.max().item() <= hi
-    assert times.min().item() >= lo
-    assert times.max().item() < ivpse.t_max
-    assert times.min().item() > ivpse.t_min
+    assert torch.allclose(times.max(), torch.Tensor([ivpse.beta_max]))
+    assert torch.allclose(times.min(), torch.Tensor([ivpse.beta_min]))
 
 
 def test_noise_schedule():
@@ -176,22 +176,12 @@
     inpt_shape = (4,)
     cond_shape = (4,)
 
-    ivpse = ImprovedVPScoreEstimator(id_net, inpt_shape, cond_shape)
+    ivpse = ImprovedScoreEstimator(id_net, inpt_shape, cond_shape)
     exp = ivpse.device
     times = ivpse.times_schedule(10)
     noise = ivpse.noise_schedule(times)
     obs = noise.device
 
     assert exp == obs
     assert noise.shape == torch.Size((10,))
-
-    ndist = stats.norm(ivpse.pmean, ivpse.pstd)
-    lo, hi = ndist.ppf(0.01), ndist.ppf(0.99)
-
-    assert noise.max().item() < hi
-    assert noise.min().item() > lo
+    assert torch.allclose(times,noise)
Keeps input at similar scales input_enc = (input - mean) / std - # Approximate score becoming exact for t -> t_max, "skip connection" + # Approximate score becoming exact for t -> t_max, "skip connection" (c_skip in edm) score_gaussian = (input - mean) / std**2 # Score prediction by the network @@ -329,13 +330,15 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: def noise_schedule(self, times: Tensor) -> Tensor: """ - Generate a beta schedule for mean scaling in variance-preserving - stochastic differential equations (SDEs). + Generate a beta schedule in stochastic differential equations (SDEs). + This will be used for sampling. This method acts as a fallback in case derivative classes do not implement it on their own. It calculates a linear beta schedule defined by the input `times`, which represent the normalized time steps t ∈ [0, 1]. + We implement a linear noise schedule here. + Args: times (Tensor): SDE times in [0, 1]. This tensor will be regenerated from @@ -351,9 +354,12 @@ def times_schedule( self, num_samples: int, t_min: float = None, t_max: float = None ) -> Tensor: """ + Construction time samples for evaluating the diffusion model. Perform uniform sampling of time variables within the range [t_min, t_max]. The `times` tensor will be put on the same device as the stored network. + We implement a uniformly sampled time stepping here. + Args: num_samples (int): Number of samples to generate. t_min (float, optional): The minimum time value. Defaults to self.t_min. @@ -396,9 +402,6 @@ def _set_weight_fn(self, weight_fn: Union[str, Callable]): raise ValueError(f"Weight function {weight_fn} not recognized.") -# TODO: experiment with time schedule with more samples around .5 (check edm paper) -# TODO: impacts on training and evaluate -# TODO: check effect in mini sbibm -> converges faster (focus on more important) class VPScoreEstimator(ConditionalScoreEstimator): """Class for score estimators with variance preserving SDEs (i.e., DDPM).""" @@ -497,11 +500,12 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: return g -# TODO: experiment with time schedule with more samples around .5 (check edm paper) -# TODO: impacts on training and evaluate -# TODO: check effect in mini sbibm -> converges faster (focus on more important) -class ImprovedVPScoreEstimator(ConditionalScoreEstimator): - """Class for score estimators with variance preserving SDEs (i.e., DDPM).""" +class ImprovedScoreEstimator(ConditionalScoreEstimator): + """Implement EDM-like score matching estimator as in [1] + + [1] Karras et al "Elucidating the Design Space of Diffusion-Based + Generative Models", https://arxiv.org/abs/2206.00364 + """ def __init__( self, @@ -509,19 +513,23 @@ def __init__( input_shape: torch.Size, condition_shape: torch.Size, weight_fn: Union[str, Callable] = "max_likelihood", - beta_min: float = 0.01, - beta_max: float = 10.0, + beta_min: float = 0.002, # sigma_min in the paper + beta_max: float = 80.0, # sigma_max in the paper mean_0: Union[Tensor, float] = 0.0, std_0: Union[Tensor, float] = 1.0, - t_min: float = 1e-5, - t_max: float = 1.0, - pmean: float = -1.2, - pstd: float = 1.2, + t_min: float = 1e-5, # will be ignored due to EDM setup + t_max: float = 1.0, # + pmean: float = -1.2, # mean of noise scheme for training + pstd: float = 1.2, # std of noise scheme for training + sigma_data: float = 0.5, ) -> None: self.pmean, self.pstd = pmean, pstd noise_dist = stats.norm(pmean, pstd**2) - self.beta_min = exp(noise_dist.ppf(0.01)) - self.beta_max = 
exp(noise_dist.ppf(0.99)) + + self.sigma_min = exp(noise_dist.ppf(0.01)) + self.sigma_max = exp(noise_dist.ppf(0.99)) + + self.rho = 7 super().__init__( net, @@ -537,9 +545,9 @@ def __init__( ) def mean_t_fn(self, times: Tensor) -> Tensor: - """Conditional mean function for variance preserving SDEs. + """Conditional mean function for EDM-style DMs. Args: - times: SDE time variable in [0,1]. + times: time variable in [0,1]. Returns: Conditional mean at a given time. @@ -553,9 +561,9 @@ def mean_t_fn(self, times: Tensor) -> Tensor: return phi def std_fn(self, times: Tensor) -> Tensor: - """Standard deviation function for variance preserving SDEs. + """Standard deviation function for EDM style DMs. Args: - times: SDE time variable in [0,1]. + times: time variable in [0,1]. Returns: Standard deviation at a given time. @@ -616,15 +624,13 @@ def noise_schedule(self, times: Tensor) -> Tensor: [1] Karras et al "Elucidating the Design Space of Diffusion-Based Generative Models", https://arxiv.org/abs/2206.00364 """ - - samples = torch.randn_like(times) * (self.pstd) + self.pmean - return torch.exp(samples) + return times def times_schedule( self, num_samples: int, t_min: float = None, t_max: float = None ) -> Tensor: """ - Perform normal sampling around the middle of the interval [t_min, t_max] + Construct time samples as suggested in EDM paper [1]. Args: num_samples (int): Number of samples to generate. @@ -635,14 +641,16 @@ def times_schedule( Tensor: A tensor of sampled time variables scaled and shifted to the range [0,1]. + [1] Karras et al "Elucidating the Design Space of Diffusion-Based + Generative Models", https://arxiv.org/abs/2206.00364 """ - t_min = self.t_min if isinstance(t_min, type(None)) else t_min - t_max = self.t_max if isinstance(t_max, type(None)) else t_max - t_mu = t_min + (t_max - t_min) / 2.0 - t_std = (t_max - t_min) / 8.0 + times = torch.linspace(0.0, 1.0, steps=num_samples) + inv_rho = 1.0 / self.rho - # apply scale and loc to normal distribution - return torch.randn(num_samples, device=self.device) * t_std + t_mu + beta_scale = self.beta_max ** (inv_rho) - self.beta_min ** (inv_rho) + offset = self.beta_min ** (inv_rho) + + return (offset + beta_scale * times) ** (self.rho) class SubVPScoreEstimator(ConditionalScoreEstimator): @@ -836,6 +844,9 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: return g +# TODO: try to add a EDM-like estimator + + class GaussianFourierTimeEmbedding(nn.Module): """Gaussian random features for encoding time steps. 
diff --git a/sbi/neural_nets/net_builders/score_nets.py b/sbi/neural_nets/net_builders/score_nets.py index e33499ab9..e407e9e5a 100644 --- a/sbi/neural_nets/net_builders/score_nets.py +++ b/sbi/neural_nets/net_builders/score_nets.py @@ -2,14 +2,18 @@ import torch import torch.nn as nn +from torch import Tensor + from sbi.neural_nets.estimators.score_estimator import ( - ConditionalScoreEstimator, GaussianFourierTimeEmbedding, - ImprovedScoreEstimator, SubVPScoreEstimator, VEScoreEstimator, - VPScoreEstimator) -from sbi.utils.sbiutils import (standardizing_net, z_score_parser, - z_standardization) + ConditionalScoreEstimator, + GaussianFourierTimeEmbedding, + ImprovedScoreEstimator, + SubVPScoreEstimator, + VEScoreEstimator, + VPScoreEstimator, +) +from sbi.utils.sbiutils import standardizing_net, z_score_parser, z_standardization from sbi.utils.user_input_checks import check_data_device -from torch import Tensor class EmbedInputs(nn.Module): diff --git a/tests/bm_test.py b/tests/bm_test.py index c61a6fef6..26cae3dc2 100644 --- a/tests/bm_test.py +++ b/tests/bm_test.py @@ -4,6 +4,7 @@ import pytest import torch from pytest_harvest import ResultsBag + from sbi.inference import FMPE, NLE, NPE, NPSE, NRE from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.trainers.npe import NPE_C diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py index 7a698d66c..2f5dfe88c 100644 --- a/tests/score_estimator_test.py +++ b/tests/score_estimator_test.py @@ -7,11 +7,13 @@ import pytest import torch + from sbi.neural_nets.embedding_nets import CNNEmbedding from sbi.neural_nets.estimators.score_estimator import ( - ConditionalScoreEstimator, ImprovedScoreEstimator, VPScoreEstimator) + ConditionalScoreEstimator, + ImprovedScoreEstimator, +) from sbi.neural_nets.net_builders import build_score_estimator -from scipy import stats @pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"]) @@ -164,8 +166,8 @@ def test_times_schedule(): assert exp == obs assert times.shape == torch.Size((10,)) - assert times[0 ,...] != ivpse.t_min - assert times[-1,...] != ivpse.t_max + assert times[0, ...] != ivpse.t_min + assert times[-1, ...] 
!= ivpse.t_max
 
     assert torch.allclose(times.max(), torch.Tensor([ivpse.beta_max]))
     assert torch.allclose(times.min(), torch.Tensor([ivpse.beta_min]))
@@ -184,4 +186,4 @@ def test_noise_schedule():
     assert exp == obs
     assert noise.shape == torch.Size((10,))
 
-    assert torch.allclose(times,noise)
+    assert torch.allclose(times, noise)

From 773a28bd2bb7a2ee73b3f05ce0e53b684d22b1c1 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 17:50:51 +0100
Subject: [PATCH 14/18] attempted to implement EDM-like diffusion

- without touching the forward function of ConditionalScoreEstimator
- benchmarks show that this leads to very long training times without any
  performance improvement

---
 sbi/neural_nets/estimators/score_estimator.py | 40 +++++++++----------
 1 file changed, 19 insertions(+), 21 deletions(-)

diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index c712c8a06..8df7dbaed 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -6,11 +6,10 @@
 from typing import Callable, Optional, Union
 
 import torch
+from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
 from scipy import stats
 from torch import Tensor, nn
 
-from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
-
 
 class ConditionalScoreEstimator(ConditionalVectorFieldEstimator):
     r"""Score matching for score-based generative models (e.g., denoising diffusion).
@@ -129,13 +128,14 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
         std = self.approx_marginal_std(time)
 
         # As input to the neural net we want to have something that changes proportionally
-        # to how the scores change
+        # to how the scores change (a la c_noise in edm)
         time_enc = self.std_fn(time)
 
-        # Time dependent z-scoring! Keeps input at similar scales
+        # Time dependent z-scoring! Keeps input at similar scales (c_in in edm)
         input_enc = (input - mean) / std
 
-        # Approximate score becoming exact for t -> t_max, "skip connection" (c_skip in edm)
+        # Approximate score becoming exact for t -> t_max, "skip connection"
+        # (a la c_skip in edm)
         score_gaussian = (input - mean) / std**2
 
         # Score prediction by the network
@@ -145,6 +145,7 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
         # The learnable part will be largely scaled at the beginning of the diffusion
         # and the Gaussian part (where it should end up) will dominate at the end of
         # the diffusion.
+        # (a la c_out in edm)
         scale = self.mean_t_fn(time) / self.std_fn(time)
         output_score = -scale * score_pred - score_gaussian
 
@@ -416,14 +417,8 @@ def __init__(
         mean_0: Union[Tensor, float] = 0.0,
         std_0: Union[Tensor, float] = 1.0,
         t_min: float = 1e-5,
-        t_max: float = 1.0,
-        pmean: float = 1.2,
-        pstd: float = -1.2,
+        t_max: float = 1.0
     ) -> None:
-        self.pmean, self.pstd = pmean, pstd
-        noise_dist = stats.norm(pmean, pstd**2)
-        self.beta_min = exp(noise_dist.ppf(0.01))
-        self.beta_max = exp(noise_dist.ppf(0.99))
 
         super().__init__(
             net,
@@ -515,6 +510,7 @@ def __init__(
         weight_fn: Union[str, Callable] = "max_likelihood",
         beta_min: float = 0.002,  # sigma_min in the paper
         beta_max: float = 80.0,  # sigma_max in the paper
+        beta_data: float = 0.5,  # sigma_data in the paper
         mean_0: Union[Tensor, float] = 0.0,
         std_0: Union[Tensor, float] = 1.0,
         t_min: float = 1e-5,  # will be ignored due to EDM setup
@@ -523,12 +519,15 @@ def __init__(
         pstd: float = 1.2,  # std of noise scheme for training
         sigma_data: float = 0.5,
     ) -> None:
+
+
+        # TODO: store sigma values for training in an extra field
         self.pmean, self.pstd = pmean, pstd
         noise_dist = stats.norm(pmean, pstd**2)
-        self.sigma_min = exp(noise_dist.ppf(0.01))
         self.sigma_max = exp(noise_dist.ppf(0.99))
 
+        self.beta_data = beta_data  # sigma_data from the EDM paper
         self.rho = 7
 
         super().__init__(
@@ -546,34 +545,33 @@ def __init__(
     def mean_t_fn(self, times: Tensor) -> Tensor:
         """Conditional mean function for EDM-style DMs.
+        This is required to model c_in.
 
         Args:
             times: time variable in [0,1].
 
         Returns:
             Conditional mean at a given time.
         """
-        phi = torch.exp(
-            -0.25 * times**2.0 * (self.beta_max - self.beta_min)
-            - 0.5 * times * self.beta_min
-        )
+        noise = self.noise_schedule(times)
+        phi = 1.0 / torch.sqrt(noise**2 + self.beta_data**2)
         for _ in range(len(self.input_shape)):
             phi = phi.unsqueeze(-1)
         return phi
 
     def std_fn(self, times: Tensor) -> Tensor:
         """Standard deviation function for EDM-style DMs.
+        This is akin to c_noise in the network/precond parametrisation.
         Args:
             times: time variable in [0,1].
 
         Returns:
             Standard deviation at a given time.
         """
-        std = 1.0 - torch.exp(
-            -0.5 * times**2.0 * (self.beta_max - self.beta_min) - times * self.beta_min
-        )
+        std = 0.25 * torch.log(times)
         for _ in range(len(self.input_shape)):
             std = std.unsqueeze(-1)
-        return torch.sqrt(std)
+        return std
 
     def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
         """Drift function for variance preserving SDEs.
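For reference, the c_* terms that the comments in this patch allude to are the network preconditioning coefficients from Table 1 of Karras et al. (https://arxiv.org/abs/2206.00364); mean_t_fn above plays the role of c_in and std_fn the role of c_noise. A minimal sketch with sigma_data = 0.5 as in this patch; the function name is illustrative only:

    import torch

    def edm_preconditioning(sigma: torch.Tensor, sigma_data: float = 0.5):
        # Preconditioning coefficients, Table 1 of Karras et al. (2206.00364).
        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)  # skip-connection weight
        c_out = sigma * sigma_data / torch.sqrt(sigma**2 + sigma_data**2)  # output scale
        c_in = 1.0 / torch.sqrt(sigma**2 + sigma_data**2)  # input scale
        c_noise = 0.25 * torch.log(sigma)  # noise-level embedding
        return c_skip, c_out, c_in, c_noise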
From 6386fc8479d9a2f440d6a32e3659e1ba412f1298 Mon Sep 17 00:00:00 2001 From: Peter Steinbach Date: Thu, 20 Mar 2025 18:17:00 +0100 Subject: [PATCH 15/18] removed "improved" denoising network --- tests/bm_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/bm_test.py b/tests/bm_test.py index 26cae3dc2..28296cded 100644 --- a/tests/bm_test.py +++ b/tests/bm_test.py @@ -4,7 +4,6 @@ import pytest import torch from pytest_harvest import ResultsBag - from sbi.inference import FMPE, NLE, NPE, NPSE, NRE from sbi.inference.posteriors.base_posterior import NeuralPosterior from sbi.inference.trainers.npe import NPE_C @@ -50,7 +49,7 @@ "npse": [ {"score_estimator": nn, "sde_type": sde} for nn in SCORE_ESTIMATORS - for sde in ["ve", "vp", "vp++"] + for sde in ["ve", "vp"] ], "snpe": [{}], "snle": [{}], From 2bbd92fbf0787e069dcb2e460ebe342d83475186 Mon Sep 17 00:00:00 2001 From: Peter Steinbach Date: Thu, 20 Mar 2025 18:17:28 +0100 Subject: [PATCH 16/18] consolidated tests --- tests/score_estimator_test.py | 36 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py index 2f5dfe88c..b7afc5255 100644 --- a/tests/score_estimator_test.py +++ b/tests/score_estimator_test.py @@ -7,16 +7,13 @@ import pytest import torch - from sbi.neural_nets.embedding_nets import CNNEmbedding from sbi.neural_nets.estimators.score_estimator import ( - ConditionalScoreEstimator, - ImprovedScoreEstimator, -) + ConditionalScoreEstimator, VPScoreEstimator) from sbi.neural_nets.net_builders import build_score_estimator -@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"]) +@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"]) @pytest.mark.parametrize("input_sample_dim", (1, 2)) @pytest.mark.parametrize("input_event_shape", ((1,), (4,))) @pytest.mark.parametrize("condition_event_shape", ((1,), (7,))) @@ -45,7 +42,7 @@ def test_score_estimator_loss_shapes( @pytest.mark.gpu -@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"]) +@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"]) @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_score_estimator_on_device(sde_type, device): """Test whether DensityEstimators can be moved to the device.""" @@ -67,7 +64,7 @@ def test_score_estimator_on_device(sde_type, device): assert str(loss.device).split(":")[0] == device, "Loss device mismatch." -@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"]) +@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"]) @pytest.mark.parametrize("input_sample_dim", (1, 2)) @pytest.mark.parametrize("input_event_shape", ((1,), (4,))) @pytest.mark.parametrize("condition_event_shape", ((1,), (7,))) @@ -158,19 +155,16 @@ def test_times_schedule(): with pytest.raises(NotImplementedError): ConditionalScoreEstimator(id_net, inpt_shape, cond_shape) - ivpse = ImprovedScoreEstimator(id_net, inpt_shape, cond_shape) - exp = ivpse.device - times = ivpse.times_schedule(10) + vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape) + exp = vpse.device + times = vpse.times_schedule(10) obs = times.device assert exp == obs assert times.shape == torch.Size((10,)) - assert times[0, ...] != ivpse.t_min - assert times[-1, ...] != ivpse.t_max - - assert torch.allclose(times.max(), torch.Tensor([ivpse.beta_max])) - assert torch.allclose(times.min(), torch.Tensor([ivpse.beta_min])) + assert times[0, ...] >= vpse.t_min + assert times[-1, ...] 
<= vpse.t_max
 
 
 def test_noise_schedule():
@@ -178,12 +172,14 @@ def test_noise_schedule():
     inpt_shape = (4,)
     cond_shape = (4,)
 
-    ivpse = ImprovedScoreEstimator(id_net, inpt_shape, cond_shape)
-    exp = ivpse.device
-    times = ivpse.times_schedule(10)
-    noise = ivpse.noise_schedule(times)
+    vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape)
+    exp = vpse.device
+    times = vpse.times_schedule(10)
+    noise = vpse.noise_schedule(times)
     obs = noise.device
 
     assert exp == obs
     assert noise.shape == torch.Size((10,))
-    assert torch.allclose(times, noise)
+
+    assert noise[0, ...] >= vpse.beta_min
+    assert noise[-1, ...] <= vpse.beta_max

From 8f0c65a58c2df694dcdaaf540d37da3207f84b64 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 18:17:44 +0100
Subject: [PATCH 17/18] removed occurrences of vp++

---
 sbi/neural_nets/net_builders/score_nets.py | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/sbi/neural_nets/net_builders/score_nets.py b/sbi/neural_nets/net_builders/score_nets.py
index e407e9e5a..5780163c7 100644
--- a/sbi/neural_nets/net_builders/score_nets.py
+++ b/sbi/neural_nets/net_builders/score_nets.py
@@ -2,18 +2,13 @@
 import torch
 import torch.nn as nn
-from torch import Tensor
-
 from sbi.neural_nets.estimators.score_estimator import (
-    ConditionalScoreEstimator,
-    GaussianFourierTimeEmbedding,
-    ImprovedScoreEstimator,
-    SubVPScoreEstimator,
-    VEScoreEstimator,
-    VPScoreEstimator,
-)
-from sbi.utils.sbiutils import standardizing_net, z_score_parser, z_standardization
+    ConditionalScoreEstimator, GaussianFourierTimeEmbedding,
+    SubVPScoreEstimator, VEScoreEstimator, VPScoreEstimator)
+from sbi.utils.sbiutils import (standardizing_net, z_score_parser,
+                                z_standardization)
 from sbi.utils.user_input_checks import check_data_device
+from torch import Tensor
 
 
 class EmbedInputs(nn.Module):
@@ -116,7 +111,6 @@ def build_score_estimator(
         batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring.
         sde_type: SDE type used, which defines the mean and std functions. One of:
             - 'vp': Variance preserving.
-            - 'vp++': Variance preserving.
             - 'subvp': Sub-variance preserving.
             - 've': Variance exploding.
             Defaults to 'vp'.
@@ -196,8 +190,6 @@ def build_score_estimator(
         estimator = VEScoreEstimator
     elif sde_type == "subvp":
         estimator = SubVPScoreEstimator
-    elif sde_type == "vp++":
-        estimator = ImprovedScoreEstimator
     else:
         raise ValueError(f"SDE type: {sde_type} not supported.")

From 833c17b595c14176f77946c4142409e3c8141fca Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 18:19:20 +0100
Subject: [PATCH 18/18] removed all mentions of edm

---
 sbi/neural_nets/estimators/score_estimator.py | 176 +-----
 1 file changed, 10 insertions(+), 166 deletions(-)

diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index 8df7dbaed..acacc7d67 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -128,14 +128,13 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
         std = self.approx_marginal_std(time)
 
         # As input to the neural net we want to have something that changes proportionally
-        # to how the scores change (a la c_noise in edm)
+        # to how the scores change
         time_enc = self.std_fn(time)
 
-        # Time dependent z-scoring! Keeps input at similar scales (c_in in edm)
+        # Time dependent z-scoring! Keeps input at similar scales
         input_enc = (input - mean) / std
 
         # Approximate score becoming exact for t -> t_max, "skip connection"
-        # (a la c_skip in edm)
         score_gaussian = (input - mean) / std**2
 
         # Score prediction by the network
@@ -145,7 +144,6 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
         # The learnable part will be largely scaled at the beginning of the diffusion
         # and the Gaussian part (where it should end up) will dominate at the end of
         # the diffusion.
-        # (a la c_out in edm)
         scale = self.mean_t_fn(time) / self.std_fn(time)
         output_score = -scale * score_pred - score_gaussian
 
@@ -182,7 +180,7 @@ def loss(
         # update device if required
         self.device = input.device if self.device != input.device else self.device
 
-        # Sample times from the Markov chain
+        # Sample times from the Markov chain, one per batch entry
         if times is None:
             times = self.times_schedule(input.shape[0])
 
@@ -355,7 +353,7 @@ def times_schedule(
         self, num_samples: int, t_min: float = None, t_max: float = None
     ) -> Tensor:
         """
-        Construction time samples for evaluating the diffusion model.
+        Time samples for evaluating the diffusion model.
 
         Perform uniform sampling of time variables within the range [t_min, t_max].
         The `times` tensor will be put on the same device as the stored network.
@@ -374,7 +372,12 @@ def times_schedule(
         t_min = self.t_min if isinstance(t_min, type(None)) else t_min
         t_max = self.t_max if isinstance(t_max, type(None)) else t_max
 
-        return torch.rand(num_samples, device=self.device) * (t_max - t_min) + t_min
+        times = torch.rand(num_samples, device=self.device) * (t_max - t_min) + t_min
+
+        # t_min and t_max need to be part of the sequence
+        times[0, ...] = t_min
+        times[-1, ...] = t_max
+        return times.sort().values
 
     def _set_weight_fn(self, weight_fn: Union[str, Callable]):
         """Set the weight function.
@@ -495,162 +498,6 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
         return g
 
 
-class ImprovedScoreEstimator(ConditionalScoreEstimator):
-    """Implement EDM-like score matching estimator as in [1]
-
-    [1] Karras et al "Elucidating the Design Space of Diffusion-Based
-        Generative Models", https://arxiv.org/abs/2206.00364
-    """
-
-    def __init__(
-        self,
-        net: nn.Module,
-        input_shape: torch.Size,
-        condition_shape: torch.Size,
-        weight_fn: Union[str, Callable] = "max_likelihood",
-        beta_min: float = 0.002,  # sigma_min in the paper
-        beta_max: float = 80.0,  # sigma_max in the paper
-        beta_data: float = 0.5,  # sigma_data in the paper
-        mean_0: Union[Tensor, float] = 0.0,
-        std_0: Union[Tensor, float] = 1.0,
-        t_min: float = 1e-5,  # will be ignored due to EDM setup
-        t_max: float = 1.0,  #
-        pmean: float = -1.2,  # mean of noise scheme for training
-        pstd: float = 1.2,  # std of noise scheme for training
-        sigma_data: float = 0.5,
-    ) -> None:
-
-
-        # TODO: store sigma values for training in an extra field
-        self.pmean, self.pstd = pmean, pstd
-        noise_dist = stats.norm(pmean, pstd**2)
-        self.sigma_min = exp(noise_dist.ppf(0.01))
-        self.sigma_max = exp(noise_dist.ppf(0.99))
-
-        self.beta_data = beta_data  # sigma_data from the EDM paper
-        self.rho = 7
-
-        super().__init__(
-            net,
-            input_shape,
-            condition_shape,
-            mean_0=mean_0,
-            std_0=std_0,
-            weight_fn=weight_fn,
-            beta_min=beta_min,
-            beta_max=beta_max,
-            t_min=t_min,
-            t_max=t_max,
-        )
-
-    def mean_t_fn(self, times: Tensor) -> Tensor:
-        """Conditional mean function for EDM-style DMs.
-        This is required to model c_in.
-
-        Args:
-            times: time variable in [0,1].
-
-        Returns:
-            Conditional mean at a given time.
-        """
-        noise = self.noise_schedule(times)
-        phi = 1.0 / torch.sqrt(noise**2 + self.beta_data**2)
-        for _ in range(len(self.input_shape)):
-            phi = phi.unsqueeze(-1)
-        return phi
-
-    def std_fn(self, times: Tensor) -> Tensor:
-        """Standard deviation function for EDM-style DMs.
-        This is akin to c_noise in the network/precond parametrisation.
-        Args:
-            times: time variable in [0,1].
-
-        Returns:
-            Standard deviation at a given time.
-        """
-        std = 0.25 * torch.log(times)
-        for _ in range(len(self.input_shape)):
-            std = std.unsqueeze(-1)
-        return std
-
-    def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
-        """Drift function for variance preserving SDEs.
-
-        Args:
-            input: Original data, x0.
-            times: SDE time variable in [0,1].
-
-        Returns:
-            Drift function at a given time.
-        """
-        phi = -0.5 * self.noise_schedule(times)
-        while len(phi.shape) < len(input.shape):
-            phi = phi.unsqueeze(-1)
-        return phi * input
-
-    def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
-        """Diffusion function for variance preserving SDEs.
-
-        Args:
-            input: Original data, x0.
-            times: SDE time variable in [0,1].
-
-        Returns:
-            Drift function at a given time.
-        """
-        g = torch.sqrt(self.noise_schedule(times))
-        while len(g.shape) < len(input.shape):
-            g = g.unsqueeze(-1)
-        return g
-
-    def noise_schedule(self, times: Tensor) -> Tensor:
-        """
-        Generate a beta schedule similar to suggestions in the EDM [1] paper.
-
-        This method acts as a fallback in case derivative classes do not
-        implement it on their own. It calculates a linear beta schedule defined
-        by the input `times`, which represent the normalized time steps t ∈ [0, 1].
-
-        Args:
-            times (Tensor):
-                SDE times in [0, 1]. This tensor will be regenerated from
-                self.times_schedule
-
-        Returns:
-            Tensor: Generated beta schedule at a given time.
-
-        [1] Karras et al "Elucidating the Design Space of Diffusion-Based
-            Generative Models", https://arxiv.org/abs/2206.00364
-        """
-        return times
-
-    def times_schedule(
-        self, num_samples: int, t_min: float = None, t_max: float = None
-    ) -> Tensor:
-        """
-        Construct time samples as suggested in the EDM paper [1].
-
-        Args:
-            num_samples (int): Number of samples to generate.
-            t_min (float, optional): The minimum time value. Defaults to self.t_min.
-            t_max (float, optional): The maximum time value. Defaults to self.t_max.
-
-        Returns:
-            Tensor: A tensor of sampled time variables scaled and shifted to
-                the range [0,1].
-
-        [1] Karras et al "Elucidating the Design Space of Diffusion-Based
-            Generative Models", https://arxiv.org/abs/2206.00364
-        """
-        times = torch.linspace(0.0, 1.0, steps=num_samples)
-        inv_rho = 1.0 / self.rho
-
-        beta_scale = self.beta_max ** (inv_rho) - self.beta_min ** (inv_rho)
-        offset = self.beta_min ** (inv_rho)
-
-        return (offset + beta_scale * times) ** (self.rho)
-
-
 class SubVPScoreEstimator(ConditionalScoreEstimator):
     """Class for score estimators with sub-variance preserving SDEs."""
 
@@ -842,9 +689,6 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
         return g
 
 
-# TODO: try to add an EDM-like estimator
-
-
 class GaussianFourierTimeEmbedding(nn.Module):
     """Gaussian random features for encoding time steps.
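The time sampling that survives this series reduces to uniform draws with pinned endpoints, sorted ascending. A standalone sketch of that behaviour (t_min, t_max, and num_samples are placeholder values):

    import torch

    # Uniform times in [t_min, t_max]; pin the endpoints so that t_min and
    # t_max are always part of the sequence, then sort (cf. times_schedule).
    t_min, t_max, num_samples = 1e-3, 1.0, 10
    times = torch.rand(num_samples) * (t_max - t_min) + t_min
    times[0], times[-1] = t_min, t_max
    times = times.sort().values

    assert times[0] == t_min and times[-1] == t_max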