From abc67424d212e0a0d9ce6c47ccce2325346d1744 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Tue, 18 Mar 2025 16:21:13 +0100
Subject: [PATCH 01/18] refactoring noise_schedule and time schedule into base
class
- created noise_schedule method to be overridden by derived classes
- created times_schedule method to be overridden by derived classes
- created a test for times_schedule
- improved docstrings
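As a rough sketch of the fallback behaviour these two base-class methods
provide, reduced to standalone functions (names mirror the patch; the free
functions and example values are illustrative only):

    import torch
    from torch import Tensor

    def linear_noise_schedule(times: Tensor, beta_min: float = 0.01,
                              beta_max: float = 10.0) -> Tensor:
        # Linear interpolation between beta_min and beta_max.
        return beta_min + (beta_max - beta_min) * times

    def uniform_times_schedule(num_samples: int, t_min: float = 1e-3,
                               t_max: float = 1.0,
                               device: torch.device = torch.device("cpu")) -> Tensor:
        # Uniform sampling of diffusion times within [t_min, t_max].
        return torch.rand(num_samples, device=device) * (t_max - t_min) + t_min

    times = uniform_times_schedule(10)
    betas = linear_noise_schedule(times)
    assert betas.min().item() >= 0.01 and betas.max().item() <= 10.0

Derived estimators can override both methods to swap in other schedules
without touching the loss.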
---
sbi/neural_nets/estimators/score_estimator.py | 130 +++++++++++-------
tests/score_estimator_test.py | 23 +++-
2 files changed, 99 insertions(+), 54 deletions(-)
diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index 132d44d1d..a941e19af 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -6,9 +6,8 @@
from typing import Callable, Optional, Union
import torch
-from torch import Tensor, nn
-
from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
+from torch import Tensor, nn
class ConditionalScoreEstimator(ConditionalVectorFieldEstimator):
@@ -42,10 +41,12 @@ def __init__(
input_shape: torch.Size,
condition_shape: torch.Size,
weight_fn: Union[str, Callable] = "max_likelihood",
+ beta_min: float = 0.01,
+ beta_max: float = 10.0,
mean_0: Union[Tensor, float] = 0.0,
std_0: Union[Tensor, float] = 1.0,
t_min: float = 1e-3,
- t_max: float = 1.0,
+ t_max: float = 1.0
) -> None:
r"""Score estimator class that estimates the conditional score function, i.e.,
gradient of the log density p(xt|x0).
@@ -59,6 +60,12 @@ def __init__(
- "identity": constant weights (1.),
- "max_likelihood": weights proportional to the diffusion function, or
- a custom function that returns a Callable.
+ beta_min: starting/minimal value of beta, the noise variance schedule
+ beta_max: ending/maximal value of beta, the noise variance schedule
+ mean_0: mean of the noise at step 0
+ std_0: standard deviation of the noise at step 0
+ t_min: smallest diffusion time
+ t_max: largest diffusion time
"""
super().__init__(net, input_shape, condition_shape)
@@ -66,7 +73,14 @@ def __init__(
# Set lambdas (variance weights) function.
self._set_weight_fn(weight_fn)
- # Min time for diffusion (0 can be numerically unstable).
+ # Store device of this module
+ self.device = net.device if hasattr(net, "device") else torch.device("cpu")
+
+ # Min/max values for noise variance beta
+ self.beta_min = beta_min
+ self.beta_max = beta_max
+
+ # Min/max limits for the diffusion time (t=0 can be numerically unstable)
self.t_min = t_min
self.t_max = t_max
@@ -152,6 +166,7 @@ def loss(
input: Input variable i.e. theta.
condition: Conditioning variable.
times: SDE time variable in [t_min, t_max]. Uniformly sampled if None.
+ If None, sampling is delegated to the times_schedule method.
control_variate: Whether to use a control variate to reduce the variance of
the stochastic loss estimator.
control_variate_threshold: Threshold for the control variate. If the std
@@ -161,13 +176,12 @@ def loss(
MSE between target score and network output, scaled by the weight function.
"""
+ # update device if required
+ self.device = input.device
+
# Sample diffusion times.
if times is None:
- times = (
- torch.rand(input.shape[0], device=input.device)
- * (self.t_max - self.t_min)
- + self.t_min
- )
+ times = self.times_schedule(input.shape[0])
# Sample noise.
eps = torch.randn_like(input)
@@ -312,6 +326,47 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""
raise NotImplementedError
+ def noise_schedule(self, times: Tensor) -> Tensor:
+ """
+ Generate a beta schedule for mean scaling in variance-preserving
+ stochastic differential equations (SDEs).
+
+ This method acts as a fallback in case derived classes do not
+ implement it themselves. It computes a linear beta schedule defined
+ by the input `times`, which represents the normalized time steps t ∈ [0, 1].
+
+ Args:
+ times (Tensor):
+ SDE times in [0, 1]. This tensor is typically produced by self.times_schedule.
+
+ Returns:
+ Tensor: Generated beta schedule at a given time.
+
+ """
+ return self.beta_min + (self.beta_max - self.beta_min) * times
+
+ def times_schedule(self,
+ num_samples: int,
+ t_min: float = None,
+ t_max: float = None) -> Tensor:
+ """
+ Perform uniform sampling of time variables within the range [t_min, t_max].
+
+ Args:
+ num_samples (int): Number of samples to generate.
+ t_min (float, optional): The minimum time value. Defaults to self.t_min.
+ t_max (float, optional): The maximum time value. Defaults to self.t_max.
+
+ Returns:
+ Tensor: A tensor of time samples drawn uniformly from [t_min, t_max].
+
+ TODO: is the tensor on device?
+ """
+ t_min = self.t_min if isinstance(t_min, type(None)) else t_min
+ t_max = self.t_max if isinstance(t_max, type(None)) else t_max
+
+ return torch.rand(num_samples, device=self.device) * (t_max - t_min) + t_min
+
def _set_weight_fn(self, weight_fn: Union[str, Callable]):
"""Set the weight function.
@@ -355,8 +410,6 @@ def __init__(
t_min: float = 1e-5,
t_max: float = 1.0,
) -> None:
- self.beta_min = beta_min
- self.beta_max = beta_max
super().__init__(
net,
input_shape,
@@ -364,6 +417,8 @@ def __init__(
mean_0=mean_0,
std_0=std_0,
weight_fn=weight_fn,
+ beta_min=beta_min,
+ beta_max=beta_max,
t_min=t_min,
t_max=t_max,
)
@@ -399,17 +454,6 @@ def std_fn(self, times: Tensor) -> Tensor:
std = std.unsqueeze(-1)
return torch.sqrt(std)
- def _beta_schedule(self, times: Tensor) -> Tensor:
- """Linear beta schedule for mean scaling in variance preserving SDEs.
-
- Args:
- times: SDE time variable in [0,1].
-
- Returns:
- Beta schedule at a given time.
- """
- return self.beta_min + (self.beta_max - self.beta_min) * times
-
def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""Drift function for variance preserving SDEs.
@@ -420,7 +464,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
Returns:
Drift function at a given time.
"""
- phi = -0.5 * self._beta_schedule(times)
+ phi = -0.5 * self.noise_schedule(times)
while len(phi.shape) < len(input.shape):
phi = phi.unsqueeze(-1)
return phi * input
@@ -435,12 +479,14 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
Returns:
Drift function at a given time.
"""
- g = torch.sqrt(self._beta_schedule(times))
+ g = torch.sqrt(self.noise_schedule(times))
while len(g.shape) < len(input.shape):
g = g.unsqueeze(-1)
return g
-
+#TODO: experiment with time schedule with more samples around .5 (check edm paper)
+#TODO: impacts on training and evaluate
+#TODO: check effect in mini sbibm -> converges faster (focus on more important)
class SubVPScoreEstimator(ConditionalScoreEstimator):
"""Class for score estimators with sub-variance preserving SDEs."""
@@ -457,13 +503,13 @@ def __init__(
t_min: float = 1e-2,
t_max: float = 1.0,
) -> None:
- self.beta_min = beta_min
- self.beta_max = beta_max
super().__init__(
net,
input_shape,
condition_shape,
weight_fn=weight_fn,
+ beta_min = beta_min,
+ beta_max = beta_max,
mean_0=mean_0,
std_0=std_0,
t_min=t_min,
@@ -501,18 +547,6 @@ def std_fn(self, times: Tensor) -> Tensor:
std = std.unsqueeze(-1)
return std
- def _beta_schedule(self, times: Tensor) -> Tensor:
- """Linear beta schedule for mean scaling in sub-variance preserving SDEs.
- (Same as for variance preserving SDEs.)
-
- Args:
- times: SDE time variable in [0,1].
-
- Returns:
- Beta schedule at a given time.
- """
- return self.beta_min + (self.beta_max - self.beta_min) * times
-
def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""Drift function for sub-variance preserving SDEs.
@@ -523,7 +557,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
Returns:
Drift function at a given time.
"""
- phi = -0.5 * self._beta_schedule(times)
+ phi = -0.5 * self.noise_schedule(times)
while len(phi.shape) < len(input.shape):
phi = phi.unsqueeze(-1)
@@ -542,7 +576,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""
g = torch.sqrt(
torch.abs(
- self._beta_schedule(times)
+ self.noise_schedule(times)
* (
1
- torch.exp(
@@ -612,17 +646,6 @@ def std_fn(self, times: Tensor) -> Tensor:
std = std.unsqueeze(-1)
return std
- def _sigma_schedule(self, times: Tensor) -> Tensor:
- """Geometric sigma schedule for variance exploding SDEs.
-
- Args:
- times: SDE time variable in [0,1].
-
- Returns:
- Sigma schedule at a given time.
- """
- return self.sigma_min * (self.sigma_max / self.sigma_min) ** times
-
def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""Drift function for variance exploding SDEs.
@@ -645,7 +668,8 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
Returns:
Diffusion function at a given time.
"""
- g = self._sigma_schedule(times) * math.sqrt(
+ sigmas = self.sigma_min * (self.sigma_max / self.sigma_min) ** times
+ g = sigmas * math.sqrt(
(2 * math.log(self.sigma_max / self.sigma_min))
)
diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 4a526e93c..7b6e09001 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -7,8 +7,9 @@
import pytest
import torch
-
from sbi.neural_nets.embedding_nets import CNNEmbedding
+from sbi.neural_nets.estimators.score_estimator import (
+ ConditionalScoreEstimator, VPScoreEstimator)
from sbi.neural_nets.net_builders import build_score_estimator
@@ -144,3 +145,23 @@ def _build_score_estimator_and_tensors(
)
condition = condition
return score_estimator, inputs, condition
+
+
+def test_times_schedule():
+
+ id_net = torch.nn.Identity()
+ inpt_shape = (4,)
+ cond_shape = (4,)
+
+ with pytest.raises(NotImplementedError):
+ cse = ConditionalScoreEstimator(id_net, inpt_shape, cond_shape)
+
+ vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape)
+ exp = vpse.device
+ times = vpse.times_schedule(10)
+ obs = times.device
+
+ assert exp == obs
+ assert times.shape == torch.Size((10,))
+ assert times.max().item() < vpse.t_max
+ assert times.min().item() >= vpse.t_min
From b7d3be3d86d529c040926d377733632db8c2b86d Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Tue, 18 Mar 2025 16:25:26 +0100
Subject: [PATCH 02/18] added noise schedule test
---
tests/score_estimator_test.py | 18 ++++++++++++++++++
1 file changed, 18 insertions(+)
diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 7b6e09001..686e61046 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -165,3 +165,21 @@ def test_times_schedule():
assert times.shape == torch.Size((10,))
assert times.max().item() < vpse.t_max
assert times.min().item() >= vpse.t_min
+
+
+def test_noise_schedule():
+
+ id_net = torch.nn.Identity()
+ inpt_shape = (4,)
+ cond_shape = (4,)
+
+ vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape)
+ exp = vpse.device
+ times = vpse.times_schedule(10)
+ noise = vpse.noise_schedule(times)
+ obs = noise.device
+
+ assert exp == obs
+ assert noise.shape == torch.Size((10,))
+ assert noise.max().item() < vpse.beta_max
+ assert noise.min().item() >= vpse.beta_min
From dca1939804c96c25a9b288e3058f48e0337cfff5 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Tue, 18 Mar 2025 18:09:51 +0100
Subject: [PATCH 03/18] implemented beta schedule for variance-preserving
estimators
- added tests too
- inspired by https://arxiv.org/abs/2206.00364
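As a reference point, a minimal sketch of the EDM training noise distribution
this patch aims for: ln(sigma) ~ N(P_mean, P_std^2) with P_mean = -1.2 and
P_std = 1.2 in the paper. Note that scipy.stats.norm takes the standard
deviation, not the variance, as its scale argument, and that this patch
initially passes (pmean, pstd) = (1.2, -1.2); the swapped values are
corrected in patch 08.

    import math
    import torch
    from scipy import stats

    pmean, pstd = -1.2, 1.2                         # EDM paper values
    noise_dist = stats.norm(loc=pmean, scale=pstd)  # scale is a std, not a variance
    sigma_min = math.exp(noise_dist.ppf(0.01))      # ~1% quantile of sigma, ~0.018
    sigma_max = math.exp(noise_dist.ppf(0.99))      # ~99% quantile of sigma, ~4.9

    # Per-sample training noise levels: log-normal in sigma.
    sigmas = torch.exp(torch.randn(10) * pstd + pmean)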
---
sbi/neural_nets/estimators/score_estimator.py | 69 +++++++++++++++++--
tests/score_estimator_test.py | 12 +++-
2 files changed, 73 insertions(+), 8 deletions(-)
diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index a941e19af..efc3691a8 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -2,11 +2,12 @@
# under the Apache License Version 2.0, see
import math
-from math import pi
+from math import exp, pi
from typing import Callable, Optional, Union
import torch
from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
+from scipy import stats
from torch import Tensor, nn
@@ -337,7 +338,8 @@ def noise_schedule(self, times: Tensor) -> Tensor:
Args:
times (Tensor):
- SDE times in [0, 1]. This tensor is typically produced by self.times_schedule.
+ SDE times in [0, 1]. This tensor is typically produced by
+ self.times_schedule.
Returns:
Tensor: Generated beta schedule at a given time.
@@ -351,6 +353,7 @@ def times_schedule(self,
t_max: float = None) -> Tensor:
"""
Perform uniform sampling of time variables within the range [t_min, t_max].
+ The `times` tensor will be put on the same device as the stored network.
Args:
num_samples (int): Number of samples to generate.
@@ -360,7 +363,6 @@ def times_schedule(self,
Returns:
Tensor: A tensor of time samples drawn uniformly from [t_min, t_max].
- TODO: is the tensor on device?
"""
t_min = self.t_min if isinstance(t_min, type(None)) else t_min
t_max = self.t_max if isinstance(t_max, type(None)) else t_max
@@ -394,6 +396,9 @@ def _set_weight_fn(self, weight_fn: Union[str, Callable]):
raise ValueError(f"Weight function {weight_fn} not recognized.")
+#TODO: experiment with time schedule with more samples around .5 (check edm paper)
+#TODO: impacts on training and evaluate
+#TODO: check effect in mini sbibm -> converges faster (focus on more important)
class VPScoreEstimator(ConditionalScoreEstimator):
"""Class for score estimators with variance preserving SDEs (i.e., DDPM)."""
@@ -409,7 +414,15 @@ def __init__(
std_0: Union[Tensor, float] = 1.0,
t_min: float = 1e-5,
t_max: float = 1.0,
+ pmean: float = 1.2,
+ pstd: float = -1.2
) -> None:
+
+ self.pmean, self.pstd = pmean, pstd
+ noise_dist = stats.norm(pmean, pstd**2)
+ self.beta_min = exp(noise_dist.ppf(0.01))
+ self.beta_max = exp(noise_dist.ppf(0.99))
+
super().__init__(
net,
input_shape,
@@ -484,9 +497,53 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
g = g.unsqueeze(-1)
return g
-#TODO: experiment with time schedule with more samples around .5 (check edm paper)
-#TODO: impacts on training and evaluate
-#TODO: check effect in mini sbibm -> converges faster (focus on more important)
+ def noise_schedule(self, times: Tensor) -> Tensor:
+ """
+ Generate a beta schedule similar to suggestions in the EDM [1] paper.
+
+ Unlike the linear fallback in the base class, this method draws
+ log-normal noise levels as suggested by the EDM paper; the input
+ `times` is only used to determine the output shape.
+
+ Args:
+ times (Tensor):
+ SDE times in [0, 1]; only the shape of this tensor is used here.
+
+ Returns:
+ Tensor: Generated beta schedule at a given time.
+
+ [1] Karras et al "Elucidating the Design Space of Diffusion-Based
+ Generative Models", https://arxiv.org/abs/2206.00364
+ """
+
+ samples = torch.randn_like(times)*(self.pstd**2) + self.pmean
+ return torch.exp(samples)
+
+ def times_schedule(self,
+ num_samples: int,
+ t_min: float = None,
+ t_max: float = None) -> Tensor:
+ """
+ Perform normal sampling around the middle of the interval [t_min, t_max].
+
+ Args:
+ num_samples (int): Number of samples to generate.
+ t_min (float, optional): The minimum time value. Defaults to self.t_min.
+ t_max (float, optional): The maximum time value. Defaults to self.t_max.
+
+ Returns:
+ Tensor: time samples drawn from a normal centered on the midpoint of [t_min, t_max].
+
+ """
+ t_min = self.t_min if isinstance(t_min, type(None)) else t_min
+ t_max = self.t_max if isinstance(t_max, type(None)) else t_max
+ t_mu = t_min + (t_max - t_min)/2.
+ t_std = (t_max - t_min)/8.
+
+ # apply scale and loc to normal distribution
+ return torch.randn(num_samples, device=self.device) * t_std + t_mu
+
class SubVPScoreEstimator(ConditionalScoreEstimator):
"""Class for score estimators with sub-variance preserving SDEs."""
diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 686e61046..988b94910 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -11,6 +11,7 @@
from sbi.neural_nets.estimators.score_estimator import (
ConditionalScoreEstimator, VPScoreEstimator)
from sbi.neural_nets.net_builders import build_score_estimator
+from scipy import stats
@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"])
@@ -161,10 +162,17 @@ def test_times_schedule():
times = vpse.times_schedule(10)
obs = times.device
+ delta = vpse.t_max - vpse.t_min
+ t_mu = vpse.t_min + delta/2
+ t_std = delta/8.
+
+ ndist = stats.norm(t_mu, t_std)
+ lo,hi = ndist.ppf(.01), ndist.ppf(.99)
+
assert exp == obs
assert times.shape == torch.Size((10,))
- assert times.max().item() < vpse.t_max
- assert times.min().item() >= vpse.t_min
+ assert times.max().item() <= hi
+ assert times.min().item() >= lo
def test_noise_schedule():
From 6ec57931373070dd507829470cbce08d215b4bb1 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Tue, 18 Mar 2025 18:13:28 +0100
Subject: [PATCH 04/18] code cosmetics triggered by ruff
---
sbi/neural_nets/estimators/score_estimator.py | 9 ++++++---
tests/score_estimator_test.py | 9 ++++++---
2 files changed, 12 insertions(+), 6 deletions(-)
diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index efc3691a8..615a0b073 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -6,10 +6,11 @@
from typing import Callable, Optional, Union
import torch
-from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
from scipy import stats
from torch import Tensor, nn
+from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
+
class ConditionalScoreEstimator(ConditionalVectorFieldEstimator):
r"""Score matching for score-based generative models (e.g., denoising diffusion).
@@ -361,7 +362,8 @@ def times_schedule(self,
t_max (float, optional): The maximum time value. Defaults to self.t_max.
Returns:
- Tensor: A tensor of time samples drawn uniformly from [t_min, t_max].
+ Tensor: A tensor of time samples drawn uniformly from
+ [t_min, t_max].
"""
t_min = self.t_min if isinstance(t_min, type(None)) else t_min
@@ -533,7 +535,8 @@ def times_schedule(self,
t_max (float, optional): The maximum time value. Defaults to self.t_max.
Returns:
- Tensor: time samples drawn from a normal centered on the midpoint of [t_min, t_max].
+ Tensor: time samples drawn from a normal centered on the
+ midpoint of [t_min, t_max].
"""
t_min = self.t_min if isinstance(t_min, type(None)) else t_min
diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 988b94910..3533d064e 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -7,11 +7,14 @@
import pytest
import torch
+from scipy import stats
+
from sbi.neural_nets.embedding_nets import CNNEmbedding
from sbi.neural_nets.estimators.score_estimator import (
- ConditionalScoreEstimator, VPScoreEstimator)
+ ConditionalScoreEstimator,
+ VPScoreEstimator,
+)
from sbi.neural_nets.net_builders import build_score_estimator
-from scipy import stats
@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"])
@@ -155,7 +158,7 @@ def test_times_schedule():
cond_shape = (4,)
with pytest.raises(NotImplementedError):
- cse = ConditionalScoreEstimator(id_net, inpt_shape, cond_shape)
+ ConditionalScoreEstimator(id_net, inpt_shape, cond_shape)
vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape)
exp = vpse.device
From 3e963a9572948729a8dff888e58834bbb143dca6 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Wed, 19 Mar 2025 10:41:09 +0100
Subject: [PATCH 05/18] cloned VPScoreEstimator to yield an improved version
in addition:
- added the improved version to the benchmarks (for later comparison)
- created the new class ImprovedVPScoreEstimator
---
sbi/neural_nets/estimators/score_estimator.py | 106 +++++++++++++++++-
sbi/neural_nets/net_builders/score_nets.py | 18 +--
tests/bm_test.py | 3 +-
3 files changed, 113 insertions(+), 14 deletions(-)
diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index 615a0b073..cac19c1a7 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -6,11 +6,10 @@
from typing import Callable, Optional, Union
import torch
+from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
from scipy import stats
from torch import Tensor, nn
-from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
-
class ConditionalScoreEstimator(ConditionalVectorFieldEstimator):
r"""Score matching for score-based generative models (e.g., denoising diffusion).
@@ -397,7 +396,6 @@ def _set_weight_fn(self, weight_fn: Union[str, Callable]):
else:
raise ValueError(f"Weight function {weight_fn} not recognized.")
-
#TODO: experiment with time schedule with more samples around .5 (check edm paper)
#TODO: impacts on training and evaluate
#TODO: check effect in mini sbibm -> converges faster (focus on more important)
@@ -499,6 +497,108 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
g = g.unsqueeze(-1)
return g
+
+#TODO: experiment with time schedule with more samples around .5 (check edm paper)
+#TODO: impacts on training and evaluate
+#TODO: check effect in mini sbibm -> converges faster (focus on more important)
+class ImprovedVPScoreEstimator(ConditionalScoreEstimator):
+ """Class for score estimators with variance preserving SDEs (i.e., DDPM)."""
+
+ def __init__(
+ self,
+ net: nn.Module,
+ input_shape: torch.Size,
+ condition_shape: torch.Size,
+ weight_fn: Union[str, Callable] = "max_likelihood",
+ beta_min: float = 0.01,
+ beta_max: float = 10.0,
+ mean_0: Union[Tensor, float] = 0.0,
+ std_0: Union[Tensor, float] = 1.0,
+ t_min: float = 1e-5,
+ t_max: float = 1.0,
+ pmean: float = 1.2,
+ pstd: float = -1.2
+ ) -> None:
+
+ self.pmean, self.pstd = pmean, pstd
+ noise_dist = stats.norm(pmean, pstd**2)
+ self.beta_min = exp(noise_dist.ppf(0.01))
+ self.beta_max = exp(noise_dist.ppf(0.99))
+
+ super().__init__(
+ net,
+ input_shape,
+ condition_shape,
+ mean_0=mean_0,
+ std_0=std_0,
+ weight_fn=weight_fn,
+ beta_min=beta_min,
+ beta_max=beta_max,
+ t_min=t_min,
+ t_max=t_max,
+ )
+
+ def mean_t_fn(self, times: Tensor) -> Tensor:
+ """Conditional mean function for variance preserving SDEs.
+ Args:
+ times: SDE time variable in [0,1].
+
+ Returns:
+ Conditional mean at a given time.
+ """
+ phi = torch.exp(
+ -0.25 * times**2.0 * (self.beta_max - self.beta_min)
+ - 0.5 * times * self.beta_min
+ )
+ for _ in range(len(self.input_shape)):
+ phi = phi.unsqueeze(-1)
+ return phi
+
+ def std_fn(self, times: Tensor) -> Tensor:
+ """Standard deviation function for variance preserving SDEs.
+ Args:
+ times: SDE time variable in [0,1].
+
+ Returns:
+ Standard deviation at a given time.
+ """
+ std = 1.0 - torch.exp(
+ -0.5 * times**2.0 * (self.beta_max - self.beta_min) - times * self.beta_min
+ )
+ for _ in range(len(self.input_shape)):
+ std = std.unsqueeze(-1)
+ return torch.sqrt(std)
+
+ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
+ """Drift function for variance preserving SDEs.
+
+ Args:
+ input: Original data, x0.
+ times: SDE time variable in [0,1].
+
+ Returns:
+ Drift function at a given time.
+ """
+ phi = -0.5 * self.noise_schedule(times)
+ while len(phi.shape) < len(input.shape):
+ phi = phi.unsqueeze(-1)
+ return phi * input
+
+ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
+ """Diffusion function for variance preserving SDEs.
+
+ Args:
+ input: Original data, x0.
+ times: SDE time variable in [0,1].
+
+ Returns:
+ Drift function at a given time.
+ """
+ g = torch.sqrt(self.noise_schedule(times))
+ while len(g.shape) < len(input.shape):
+ g = g.unsqueeze(-1)
+ return g
+
def noise_schedule(self, times: Tensor) -> Tensor:
"""
Generate a beta schedule similar to suggestions in the EDM [1] paper.
diff --git a/sbi/neural_nets/net_builders/score_nets.py b/sbi/neural_nets/net_builders/score_nets.py
index 274c0b2c7..ba40ec940 100644
--- a/sbi/neural_nets/net_builders/score_nets.py
+++ b/sbi/neural_nets/net_builders/score_nets.py
@@ -2,17 +2,14 @@
import torch
import torch.nn as nn
-from torch import Tensor
-
from sbi.neural_nets.estimators.score_estimator import (
- ConditionalScoreEstimator,
- GaussianFourierTimeEmbedding,
- SubVPScoreEstimator,
- VEScoreEstimator,
- VPScoreEstimator,
-)
-from sbi.utils.sbiutils import standardizing_net, z_score_parser, z_standardization
+ ConditionalScoreEstimator, GaussianFourierTimeEmbedding,
+ ImprovedVPScoreEstimator, SubVPScoreEstimator, VEScoreEstimator,
+ VPScoreEstimator)
+from sbi.utils.sbiutils import (standardizing_net, z_score_parser,
+ z_standardization)
from sbi.utils.user_input_checks import check_data_device
+from torch import Tensor
class EmbedInputs(nn.Module):
@@ -115,6 +112,7 @@ def build_score_estimator(
batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring.
sde_type: SDE type used, which defines the mean and std functions. One of:
- 'vp': Variance preserving.
+ - 'vp++': Improved variance preserving (EDM-inspired).
- 'subvp': Sub-variance preserving.
- 've': Variance exploding.
Defaults to 'vp'.
@@ -194,6 +192,8 @@ def build_score_estimator(
estimator = VEScoreEstimator
elif sde_type == "subvp":
estimator = SubVPScoreEstimator
+ elif sde_type == "vp++":
+ estimator = ImprovedVPScoreEstimator
else:
raise ValueError(f"SDE type: {sde_type} not supported.")
diff --git a/tests/bm_test.py b/tests/bm_test.py
index e38fd9659..c61a6fef6 100644
--- a/tests/bm_test.py
+++ b/tests/bm_test.py
@@ -4,7 +4,6 @@
import pytest
import torch
from pytest_harvest import ResultsBag
-
from sbi.inference import FMPE, NLE, NPE, NPSE, NRE
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.trainers.npe import NPE_C
@@ -50,7 +49,7 @@
"npse": [
{"score_estimator": nn, "sde_type": sde}
for nn in SCORE_ESTIMATORS
- for sde in ["ve", "vp"]
+ for sde in ["ve", "vp", "vp++"]
],
"snpe": [{}],
"snle": [{}],
From 4b625101e196dc29d07486dfb66416a05516efde Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Wed, 19 Mar 2025 11:57:10 +0100
Subject: [PATCH 06/18] more realistic bounds for unit test
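The idea behind the looser bounds, as a small sketch: for normally sampled
times, hard [t_min, t_max] bounds are replaced by 1%/99% quantiles of the
sampling distribution (values assume the defaults t_min = 1e-5, t_max = 1.0):

    from scipy import stats

    t_min, t_max = 1e-5, 1.0
    t_mu = t_min + (t_max - t_min) / 2.0  # centre of the interval
    t_std = (t_max - t_min) / 8.0         # spread used by times_schedule
    ndist = stats.norm(t_mu, t_std)
    lo, hi = ndist.ppf(0.01), ndist.ppf(0.99)  # roughly 0.21 and 0.79

These bounds remain probabilistic: with 10 draws, each assertion holds with
probability 0.99**10, roughly 0.9, so the test can still fail occasionally.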
---
tests/score_estimator_test.py | 38 +++++++++++++++++------------------
1 file changed, 19 insertions(+), 19 deletions(-)
diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 3533d064e..ccbefce9d 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -7,17 +7,14 @@
import pytest
import torch
-from scipy import stats
-
from sbi.neural_nets.embedding_nets import CNNEmbedding
from sbi.neural_nets.estimators.score_estimator import (
- ConditionalScoreEstimator,
- VPScoreEstimator,
-)
+ ConditionalScoreEstimator, ImprovedVPScoreEstimator, VPScoreEstimator)
from sbi.neural_nets.net_builders import build_score_estimator
+from scipy import stats
-@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"])
+@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"])
@pytest.mark.parametrize("input_sample_dim", (1, 2))
@pytest.mark.parametrize("input_event_shape", ((1,), (4,)))
@pytest.mark.parametrize("condition_event_shape", ((1,), (7,)))
@@ -46,7 +43,7 @@ def test_score_estimator_loss_shapes(
@pytest.mark.gpu
-@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"])
+@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_score_estimator_on_device(sde_type, device):
"""Test whether DensityEstimators can be moved to the device."""
@@ -68,7 +65,7 @@ def test_score_estimator_on_device(sde_type, device):
assert str(loss.device).split(":")[0] == device, "Loss device mismatch."
-@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"])
+@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"])
@pytest.mark.parametrize("input_sample_dim", (1, 2))
@pytest.mark.parametrize("input_event_shape", ((1,), (4,)))
@pytest.mark.parametrize("condition_event_shape", ((1,), (7,)))
@@ -160,13 +157,13 @@ def test_times_schedule():
with pytest.raises(NotImplementedError):
ConditionalScoreEstimator(id_net, inpt_shape, cond_shape)
- vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape)
- exp = vpse.device
- times = vpse.times_schedule(10)
+ ivpse = ImprovedVPScoreEstimator(id_net, inpt_shape, cond_shape)
+ exp = ivpse.device
+ times = ivpse.times_schedule(10)
obs = times.device
- delta = vpse.t_max - vpse.t_min
- t_mu = vpse.t_min + delta/2
+ delta = ivpse.t_max - ivpse.t_min
+ t_mu = ivpse.t_min + delta/2.
t_std = delta/8.
ndist = stats.norm(t_mu, t_std)
@@ -177,6 +174,9 @@ def test_times_schedule():
assert times.max().item() <= hi
assert times.min().item() >= lo
+ assert times.max().item() < ivpse.t_max
+ assert times.min().item() > ivpse.t_min
+
def test_noise_schedule():
@@ -184,13 +184,13 @@ def test_noise_schedule():
inpt_shape = (4,)
cond_shape = (4,)
- vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape)
- exp = vpse.device
- times = vpse.times_schedule(10)
- noise = vpse.noise_schedule(times)
+ ivpse = ImprovedVPScoreEstimator(id_net, inpt_shape, cond_shape)
+ exp = ivpse.device
+ times = ivpse.times_schedule(10)
+ noise = ivpse.noise_schedule(times)
obs = noise.device
assert exp == obs
assert noise.shape == torch.Size((10,))
- assert noise.max().item() < vpse.beta_max
- assert noise.min().item() >= vpse.beta_min
+ assert noise.max().item() < ivpse.beta_max
+ assert noise.min().item() >= ivpse.beta_min
From 274212ee857267aabb5ea70695b8784077cd589a Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Wed, 19 Mar 2025 11:57:24 +0100
Subject: [PATCH 07/18] typo and refactoring
to understand how the VE estimator is implemented
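For orientation, the variance-exploding pieces this refactor walks through,
pulled out as free functions (the sigma_min and sigma_max values are
assumptions, not the library defaults):

    import math
    import torch
    from torch import Tensor

    sigma_min, sigma_max = 0.01, 10.0

    def ve_sigma(times: Tensor) -> Tensor:
        # Geometric schedule: sigma(t) = sigma_min * (sigma_max / sigma_min) ** t.
        return sigma_min * (sigma_max / sigma_min) ** times

    def ve_diffusion(times: Tensor) -> Tensor:
        # g(t) = sigma(t) * sqrt(2 * log(sigma_max / sigma_min)), as in the
        # refactored diffusion_fn.
        return ve_sigma(times) * math.sqrt(2 * math.log(sigma_max / sigma_min))

    print(ve_diffusion(torch.linspace(0.0, 1.0, 5)))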
---
sbi/neural_nets/estimators/score_estimator.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index cac19c1a7..03c39de11 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -180,7 +180,7 @@ def loss(
# update device if required
self.device = input.device
- # Sample diffusion times.
+ # Sample diffusion times via the times schedule.
if times is None:
times = self.times_schedule(input.shape[0])
@@ -619,7 +619,7 @@ def noise_schedule(self, times: Tensor) -> Tensor:
Generative Models", https://arxiv.org/abs/2206.00364
"""
- samples = torch.randn_like(times)*(self.pstd**2) + self.pmean
+ samples = torch.randn_like(times)*(self.pstd) + self.pmean
return torch.exp(samples)
def times_schedule(self,
@@ -828,9 +828,10 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
Returns:
Diffusion function at a given time.
"""
- sigmas = self.sigma_min * (self.sigma_max / self.sigma_min) ** times
+ sigma_scale = self.sigma_max / self.sigma_min
+ sigmas = self.sigma_min * (sigma_scale) ** times
g = sigmas * math.sqrt(
- (2 * math.log(self.sigma_max / self.sigma_min))
+ (2 * math.log(sigma_scale))
)
while len(g.shape) < len(input.shape):
From 74065ae935a1914f60198dda01373ed3599294a9 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Wed, 19 Mar 2025 12:06:21 +0100
Subject: [PATCH 08/18] fixed wrong setup of pmean and pstd
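A quick numeric check of why the signs matter: ln(sigma) should be centred
below zero so that most training noise levels stay small relative to
unit-scale data.

    import math

    # Corrected setup: median sigma = exp(pmean) = exp(-1.2) ~ 0.30.
    print(math.exp(-1.2))  # ~0.301
    # Swapped setup from the earlier patches: median sigma = exp(1.2) ~ 3.32,
    # so most noise levels would swamp the data.
    print(math.exp(1.2))   # ~3.320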
---
sbi/neural_nets/estimators/score_estimator.py | 4 ++--
tests/score_estimator_test.py | 14 +++++++++-----
2 files changed, 11 insertions(+), 7 deletions(-)
diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index 03c39de11..ea69a95c4 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -516,8 +516,8 @@ def __init__(
std_0: Union[Tensor, float] = 1.0,
t_min: float = 1e-5,
t_max: float = 1.0,
- pmean: float = 1.2,
- pstd: float = -1.2
+ pmean: float = -1.2,
+ pstd: float = 1.2
) -> None:
self.pmean, self.pstd = pmean, pstd
diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index ccbefce9d..6f55233be 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -166,14 +166,14 @@ def test_times_schedule():
t_mu = ivpse.t_min + delta/2.
t_std = delta/8.
+ assert exp == obs
+ assert times.shape == torch.Size((10,))
+
ndist = stats.norm(t_mu, t_std)
lo,hi = ndist.ppf(.01), ndist.ppf(.99)
- assert exp == obs
- assert times.shape == torch.Size((10,))
assert times.max().item() <= hi
assert times.min().item() >= lo
-
assert times.max().item() < ivpse.t_max
assert times.min().item() > ivpse.t_min
@@ -192,5 +192,9 @@ def test_noise_schedule():
assert exp == obs
assert noise.shape == torch.Size((10,))
- assert noise.max().item() < ivpse.beta_max
- assert noise.min().item() >= ivpse.beta_min
+
+ ndist = stats.norm(ivpse.pmean, ivpse.pstd)
+ lo,hi = ndist.ppf(.01), ndist.ppf(.99)
+
+ assert noise.max().item() < hi
+ assert noise.min().item() > lo
From d61e198d36aa4ad97c7c21c5ec9b66203ffa3b39 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Wed, 19 Mar 2025 15:42:58 +0100
Subject: [PATCH 09/18] code reformatting
---
sbi/neural_nets/estimators/score_estimator.py | 54 +++++++++----------
sbi/neural_nets/net_builders/score_nets.py | 13 +++--
tests/score_estimator_test.py | 15 +++---
3 files changed, 41 insertions(+), 41 deletions(-)
diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index ea69a95c4..99803c04f 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -47,7 +47,7 @@ def __init__(
mean_0: Union[Tensor, float] = 0.0,
std_0: Union[Tensor, float] = 1.0,
t_min: float = 1e-3,
- t_max: float = 1.0
+ t_max: float = 1.0,
) -> None:
r"""Score estimator class that estimates the conditional score function, i.e.,
gradient of the density p(xt|x0).
@@ -347,10 +347,9 @@ def noise_schedule(self, times: Tensor) -> Tensor:
"""
return self.beta_min + (self.beta_max - self.beta_min) * times
- def times_schedule(self,
- num_samples: int,
- t_min: float = None,
- t_max: float = None) -> Tensor:
+ def times_schedule(
+ self, num_samples: int, t_min: float = None, t_max: float = None
+ ) -> Tensor:
"""
Perform uniform sampling of time variables within the range [t_min, t_max].
The `times` tensor will be put on the same device as the stored network.
@@ -396,9 +395,10 @@ def _set_weight_fn(self, weight_fn: Union[str, Callable]):
else:
raise ValueError(f"Weight function {weight_fn} not recognized.")
-#TODO: experiment with time schedule with more samples around .5 (check edm paper)
-#TODO: impacts on training and evaluate
-#TODO: check effect in mini sbibm -> converges faster (focus on more important)
+
+# TODO: experiment with time schedule with more samples around .5 (check edm paper)
+# TODO: impacts on training and evaluate
+# TODO: check effect in mini sbibm -> converges faster (focus on more important)
class VPScoreEstimator(ConditionalScoreEstimator):
"""Class for score estimators with variance preserving SDEs (i.e., DDPM)."""
@@ -414,10 +414,9 @@ def __init__(
std_0: Union[Tensor, float] = 1.0,
t_min: float = 1e-5,
t_max: float = 1.0,
- pmean: float = 1.2,
- pstd: float = -1.2
+ pmean: float = 1.2,
+ pstd: float = -1.2,
) -> None:
-
self.pmean, self.pstd = pmean, pstd
noise_dist = stats.norm(pmean, pstd**2)
self.beta_min = exp(noise_dist.ppf(0.01))
@@ -498,9 +497,9 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
return g
-#TODO: experiment with time schedule with more samples around .5 (check edm paper)
-#TODO: impacts on training and evaluate
-#TODO: check effect in mini sbibm -> converges faster (focus on more important)
+# TODO: experiment with time schedule with more samples around .5 (check edm paper)
+# TODO: impacts on training and evaluate
+# TODO: check effect in mini sbibm -> converges faster (focus on more important)
class ImprovedVPScoreEstimator(ConditionalScoreEstimator):
"""Class for score estimators with variance preserving SDEs (i.e., DDPM)."""
@@ -516,10 +515,9 @@ def __init__(
std_0: Union[Tensor, float] = 1.0,
t_min: float = 1e-5,
t_max: float = 1.0,
- pmean: float = -1.2,
- pstd: float = 1.2
+ pmean: float = -1.2,
+ pstd: float = 1.2,
) -> None:
-
self.pmean, self.pstd = pmean, pstd
noise_dist = stats.norm(pmean, pstd**2)
self.beta_min = exp(noise_dist.ppf(0.01))
@@ -619,13 +617,12 @@ def noise_schedule(self, times: Tensor) -> Tensor:
Generative Models", https://arxiv.org/abs/2206.00364
"""
- samples = torch.randn_like(times)*(self.pstd) + self.pmean
+ samples = torch.randn_like(times) * (self.pstd) + self.pmean
return torch.exp(samples)
- def times_schedule(self,
- num_samples: int,
- t_min: float = None,
- t_max: float = None) -> Tensor:
+ def times_schedule(
+ self, num_samples: int, t_min: float = None, t_max: float = None
+ ) -> Tensor:
"""
Perform normal sampling around the middle of the interval [t_min, t_max].
@@ -641,12 +638,13 @@ def times_schedule(self,
"""
t_min = self.t_min if isinstance(t_min, type(None)) else t_min
t_max = self.t_max if isinstance(t_max, type(None)) else t_max
- t_mu = t_min + (t_max - t_min)/2.
- t_std = (t_max - t_min)/8.
+ t_mu = t_min + (t_max - t_min) / 2.0
+ t_std = (t_max - t_min) / 8.0
# apply scale and loc to normal distribution
return torch.randn(num_samples, device=self.device) * t_std + t_mu
+
class SubVPScoreEstimator(ConditionalScoreEstimator):
"""Class for score estimators with sub-variance preserving SDEs."""
@@ -668,8 +666,8 @@ def __init__(
input_shape,
condition_shape,
weight_fn=weight_fn,
- beta_min = beta_min,
- beta_max = beta_max,
+ beta_min=beta_min,
+ beta_max=beta_max,
mean_0=mean_0,
std_0=std_0,
t_min=t_min,
@@ -830,9 +828,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""
sigma_scale = self.sigma_max / self.sigma_min
sigmas = self.sigma_min * (sigma_scale) ** times
- g = sigmas * math.sqrt(
- (2 * math.log(sigma_scale))
- )
+ g = sigmas * math.sqrt((2 * math.log(sigma_scale)))
while len(g.shape) < len(input.shape):
g = g.unsqueeze(-1)
diff --git a/sbi/neural_nets/net_builders/score_nets.py b/sbi/neural_nets/net_builders/score_nets.py
index ba40ec940..378eacfb1 100644
--- a/sbi/neural_nets/net_builders/score_nets.py
+++ b/sbi/neural_nets/net_builders/score_nets.py
@@ -3,11 +3,14 @@
import torch
import torch.nn as nn
from sbi.neural_nets.estimators.score_estimator import (
- ConditionalScoreEstimator, GaussianFourierTimeEmbedding,
- ImprovedVPScoreEstimator, SubVPScoreEstimator, VEScoreEstimator,
- VPScoreEstimator)
-from sbi.utils.sbiutils import (standardizing_net, z_score_parser,
- z_standardization)
+ ConditionalScoreEstimator,
+ GaussianFourierTimeEmbedding,
+ ImprovedVPScoreEstimator,
+ SubVPScoreEstimator,
+ VEScoreEstimator,
+ VPScoreEstimator,
+)
+from sbi.utils.sbiutils import standardizing_net, z_score_parser, z_standardization
from sbi.utils.user_input_checks import check_data_device
from torch import Tensor
diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 6f55233be..987573614 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -9,7 +9,10 @@
import torch
from sbi.neural_nets.embedding_nets import CNNEmbedding
from sbi.neural_nets.estimators.score_estimator import (
- ConditionalScoreEstimator, ImprovedVPScoreEstimator, VPScoreEstimator)
+ ConditionalScoreEstimator,
+ ImprovedVPScoreEstimator,
+ VPScoreEstimator,
+)
from sbi.neural_nets.net_builders import build_score_estimator
from scipy import stats
@@ -149,7 +152,6 @@ def _build_score_estimator_and_tensors(
def test_times_schedule():
-
id_net = torch.nn.Identity()
inpt_shape = (4,)
cond_shape = (4,)
@@ -163,14 +165,14 @@ def test_times_schedule():
obs = times.device
delta = ivpse.t_max - ivpse.t_min
- t_mu = ivpse.t_min + delta/2.
- t_std = delta/8.
+ t_mu = ivpse.t_min + delta / 2.0
+ t_std = delta / 8.0
assert exp == obs
assert times.shape == torch.Size((10,))
ndist = stats.norm(t_mu, t_std)
- lo,hi = ndist.ppf(.01), ndist.ppf(.99)
+ lo, hi = ndist.ppf(0.01), ndist.ppf(0.99)
assert times.max().item() <= hi
assert times.min().item() >= lo
@@ -179,7 +181,6 @@ def test_times_schedule():
def test_noise_schedule():
-
id_net = torch.nn.Identity()
inpt_shape = (4,)
cond_shape = (4,)
@@ -194,7 +195,7 @@ def test_noise_schedule():
assert noise.shape == torch.Size((10,))
ndist = stats.norm(ivpse.pmean, ivpse.pstd)
- lo,hi = ndist.ppf(.01), ndist.ppf(.99)
+ lo, hi = ndist.ppf(0.01), ndist.ppf(0.99)
assert noise.max().item() < hi
assert noise.min().item() > lo
From d2c0100bc45e89a87c11a3c0009306cd6d8a9b69 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 16:53:01 +0100
Subject: [PATCH 10/18] use the time schedule for computing the validation
scores
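The dispatch this patch adds, as a standalone sketch (net stands in for
self._neural_net; attribute names follow the patch):

    import torch

    def validation_times_for(net, num: int) -> torch.Tensor:
        # Prefer the estimator's own schedule; fall back to an even grid.
        if hasattr(net, "times_schedule"):
            return net.times_schedule(num)
        return torch.linspace(net.t_min, net.t_max, steps=num)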
---
sbi/inference/trainers/npse/npse.py | 26 ++++++++++++++------------
1 file changed, 14 insertions(+), 12 deletions(-)
diff --git a/sbi/inference/trainers/npse/npse.py b/sbi/inference/trainers/npse/npse.py
index c8b647e2e..ee5a236f4 100644
--- a/sbi/inference/trainers/npse/npse.py
+++ b/sbi/inference/trainers/npse/npse.py
@@ -5,17 +5,9 @@
from typing import Any, Callable, Optional, Union
import torch
-from torch import Tensor, ones
-from torch.distributions import Distribution
-from torch.nn.utils.clip_grad import clip_grad_norm_
-from torch.optim.adam import Adam
-from torch.utils.tensorboard.writer import SummaryWriter
-
from sbi import utils as utils
from sbi.inference import NeuralInference
-from sbi.inference.posteriors import (
- DirectPosterior,
-)
+from sbi.inference.posteriors import DirectPosterior
from sbi.inference.posteriors.score_posterior import ScorePosterior
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from sbi.neural_nets.factory import posterior_score_nn
@@ -29,6 +21,11 @@
)
from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior
from sbi.utils.torchutils import assert_all_finite
+from torch import Tensor, ones
+from torch.distributions import Distribution
+from torch.nn.utils.clip_grad import clip_grad_norm_
+from torch.optim.adam import Adam
+from torch.utils.tensorboard.writer import SummaryWriter
class NPSE(NeuralInference):
@@ -304,9 +301,14 @@ def default_calibration_kernel(x):
self._neural_net.to(self._device)
if isinstance(validation_times, int):
- validation_times = torch.linspace(
- self._neural_net.t_min, self._neural_net.t_max, validation_times
- )
+ if hasattr(self._neural_net, "times_schedule"):
+ validation_times = self._neural_net.times_schedule(validation_times)
+ else:
+ validation_times = torch.linspace(
+ self._neural_net.t_min,
+ self._neural_net.t_max,
+ steps=validation_times,
+ )
assert isinstance(
validation_times, Tensor
) # let pyright know validation_times is a Tensor.
From 5b506731e020026359efd1d0aef3a3d266896d80 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 16:53:44 +0100
Subject: [PATCH 11/18] propagate name change
---
sbi/neural_nets/net_builders/score_nets.py | 15 ++++++---------
1 file changed, 6 insertions(+), 9 deletions(-)
diff --git a/sbi/neural_nets/net_builders/score_nets.py b/sbi/neural_nets/net_builders/score_nets.py
index 378eacfb1..e33499ab9 100644
--- a/sbi/neural_nets/net_builders/score_nets.py
+++ b/sbi/neural_nets/net_builders/score_nets.py
@@ -3,14 +3,11 @@
import torch
import torch.nn as nn
from sbi.neural_nets.estimators.score_estimator import (
- ConditionalScoreEstimator,
- GaussianFourierTimeEmbedding,
- ImprovedVPScoreEstimator,
- SubVPScoreEstimator,
- VEScoreEstimator,
- VPScoreEstimator,
-)
-from sbi.utils.sbiutils import standardizing_net, z_score_parser, z_standardization
+ ConditionalScoreEstimator, GaussianFourierTimeEmbedding,
+ ImprovedScoreEstimator, SubVPScoreEstimator, VEScoreEstimator,
+ VPScoreEstimator)
+from sbi.utils.sbiutils import (standardizing_net, z_score_parser,
+ z_standardization)
from sbi.utils.user_input_checks import check_data_device
from torch import Tensor
@@ -196,7 +193,7 @@ def build_score_estimator(
elif sde_type == "subvp":
estimator = SubVPScoreEstimator
elif sde_type == "vp++":
- estimator = ImprovedVPScoreEstimator
+ estimator = ImprovedScoreEstimator
else:
raise ValueError(f"SDE type: {sde_type} not supported.")
From 2081ae01f42a96467f8c95fc5641ee17026fa0e1 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 16:53:56 +0100
Subject: [PATCH 12/18] fix unit tests to respect new schedules
---
tests/score_estimator_test.py | 30 ++++++++----------------------
1 file changed, 8 insertions(+), 22 deletions(-)
diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 987573614..7a698d66c 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -9,10 +9,7 @@
import torch
from sbi.neural_nets.embedding_nets import CNNEmbedding
from sbi.neural_nets.estimators.score_estimator import (
- ConditionalScoreEstimator,
- ImprovedVPScoreEstimator,
- VPScoreEstimator,
-)
+ ConditionalScoreEstimator, ImprovedScoreEstimator, VPScoreEstimator)
from sbi.neural_nets.net_builders import build_score_estimator
from scipy import stats
@@ -159,25 +156,19 @@ def test_times_schedule():
with pytest.raises(NotImplementedError):
ConditionalScoreEstimator(id_net, inpt_shape, cond_shape)
- ivpse = ImprovedVPScoreEstimator(id_net, inpt_shape, cond_shape)
+ ivpse = ImprovedScoreEstimator(id_net, inpt_shape, cond_shape)
exp = ivpse.device
times = ivpse.times_schedule(10)
obs = times.device
- delta = ivpse.t_max - ivpse.t_min
- t_mu = ivpse.t_min + delta / 2.0
- t_std = delta / 8.0
-
assert exp == obs
assert times.shape == torch.Size((10,))
- ndist = stats.norm(t_mu, t_std)
- lo, hi = ndist.ppf(0.01), ndist.ppf(0.99)
+ assert times[0 ,...] != ivpse.t_min
+ assert times[-1,...] != ivpse.t_max
- assert times.max().item() <= hi
- assert times.min().item() >= lo
- assert times.max().item() < ivpse.t_max
- assert times.min().item() > ivpse.t_min
+ assert torch.allclose(times.max(), torch.Tensor([ivpse.beta_max]))
+ assert torch.allclose(times.min(), torch.Tensor([ivpse.beta_min]))
def test_noise_schedule():
@@ -185,7 +176,7 @@ def test_noise_schedule():
inpt_shape = (4,)
cond_shape = (4,)
- ivpse = ImprovedVPScoreEstimator(id_net, inpt_shape, cond_shape)
+ ivpse = ImprovedScoreEstimator(id_net, inpt_shape, cond_shape)
exp = ivpse.device
times = ivpse.times_schedule(10)
noise = ivpse.noise_schedule(times)
@@ -193,9 +184,4 @@ def test_noise_schedule():
assert exp == obs
assert noise.shape == torch.Size((10,))
-
- ndist = stats.norm(ivpse.pmean, ivpse.pstd)
- lo, hi = ndist.ppf(0.01), ndist.ppf(0.99)
-
- assert noise.max().item() < hi
- assert noise.min().item() > lo
+ assert torch.allclose(times,noise)
From 4f58be6dfb5ddb1d831eb89e5edcaa17d7c18147 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 16:54:52 +0100
Subject: [PATCH 13/18] comply with formatting
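Beyond formatting, this patch also renames ImprovedVPScoreEstimator to
ImprovedScoreEstimator and rewrites its schedules. The rho-spaced noise grid
its new times_schedule builds, as a free-standing sketch (EDM eq. (5), here
ascending from sigma_min to sigma_max as in the patch; defaults follow the
patch's beta_min/beta_max):

    import torch

    def karras_time_grid(num_samples: int, sigma_min: float = 0.002,
                         sigma_max: float = 80.0, rho: float = 7.0) -> torch.Tensor:
        # Evenly spaced in sigma**(1/rho); rho = 7 clusters samples near sigma_min.
        t = torch.linspace(0.0, 1.0, steps=num_samples)
        inv_rho = 1.0 / rho
        return (sigma_min**inv_rho
                + t * (sigma_max**inv_rho - sigma_min**inv_rho)) ** rho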
---
sbi/inference/trainers/npse/npse.py | 11 +--
sbi/neural_nets/estimators/score_estimator.py | 79 +++++++++++--------
sbi/neural_nets/net_builders/score_nets.py | 16 ++--
tests/bm_test.py | 1 +
tests/score_estimator_test.py | 12 +--
5 files changed, 69 insertions(+), 50 deletions(-)
diff --git a/sbi/inference/trainers/npse/npse.py b/sbi/inference/trainers/npse/npse.py
index ee5a236f4..d3c5d05a3 100644
--- a/sbi/inference/trainers/npse/npse.py
+++ b/sbi/inference/trainers/npse/npse.py
@@ -5,6 +5,12 @@
from typing import Any, Callable, Optional, Union
import torch
+from torch import Tensor, ones
+from torch.distributions import Distribution
+from torch.nn.utils.clip_grad import clip_grad_norm_
+from torch.optim.adam import Adam
+from torch.utils.tensorboard.writer import SummaryWriter
+
from sbi import utils as utils
from sbi.inference import NeuralInference
from sbi.inference.posteriors import DirectPosterior
@@ -21,11 +27,6 @@
)
from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior
from sbi.utils.torchutils import assert_all_finite
-from torch import Tensor, ones
-from torch.distributions import Distribution
-from torch.nn.utils.clip_grad import clip_grad_norm_
-from torch.optim.adam import Adam
-from torch.utils.tensorboard.writer import SummaryWriter
class NPSE(NeuralInference):
diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index 99803c04f..c712c8a06 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -6,10 +6,11 @@
from typing import Callable, Optional, Union
import torch
-from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
from scipy import stats
from torch import Tensor, nn
+from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
+
class ConditionalScoreEstimator(ConditionalVectorFieldEstimator):
r"""Score matching for score-based generative models (e.g., denoising diffusion).
@@ -134,7 +135,7 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
# Time dependent z-scoring! Keeps input at similar scales
input_enc = (input - mean) / std
- # Approximate score becoming exact for t -> t_max, "skip connection"
+ # Approximate score becoming exact for t -> t_max, "skip connection" (c_skip in edm)
score_gaussian = (input - mean) / std**2
# Score prediction by the network
@@ -329,13 +330,15 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
def noise_schedule(self, times: Tensor) -> Tensor:
"""
- Generate a beta schedule for mean scaling in variance-preserving
- stochastic differential equations (SDEs).
+ Generate a beta schedule for the stochastic differential equation (SDE).
+ This schedule is used during sampling.
This method acts as a fallback in case derived classes do not
implement it themselves. It computes a linear beta schedule defined
by the input `times`, which represents the normalized time steps t ∈ [0, 1].
+ We implement a linear noise schedule here.
+
Args:
times (Tensor):
SDE times in [0, 1]. This tensor is typically produced by
@@ -351,9 +354,12 @@ def times_schedule(
self, num_samples: int, t_min: float = None, t_max: float = None
) -> Tensor:
"""
+ Construct time samples for evaluating the diffusion model.
Perform uniform sampling of time variables within the range [t_min, t_max].
The `times` tensor will be put on the same device as the stored network.
+ We implement uniform time sampling here.
+
Args:
num_samples (int): Number of samples to generate.
t_min (float, optional): The minimum time value. Defaults to self.t_min.
@@ -396,9 +402,6 @@ def _set_weight_fn(self, weight_fn: Union[str, Callable]):
raise ValueError(f"Weight function {weight_fn} not recognized.")
-# TODO: experiment with time schedule with more samples around .5 (check edm paper)
-# TODO: impacts on training and evaluate
-# TODO: check effect in mini sbibm -> converges faster (focus on more important)
class VPScoreEstimator(ConditionalScoreEstimator):
"""Class for score estimators with variance preserving SDEs (i.e., DDPM)."""
@@ -497,11 +500,12 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
return g
-# TODO: experiment with time schedule with more samples around .5 (check edm paper)
-# TODO: impacts on training and evaluate
-# TODO: check effect in mini sbibm -> converges faster (focus on more important)
-class ImprovedVPScoreEstimator(ConditionalScoreEstimator):
- """Class for score estimators with variance preserving SDEs (i.e., DDPM)."""
+class ImprovedScoreEstimator(ConditionalScoreEstimator):
+ """Implement EDM-like score matching estimator as in [1]
+
+ [1] Karras et al "Elucidating the Design Space of Diffusion-Based
+ Generative Models", https://arxiv.org/abs/2206.00364
+ """
def __init__(
self,
@@ -509,19 +513,23 @@ def __init__(
input_shape: torch.Size,
condition_shape: torch.Size,
weight_fn: Union[str, Callable] = "max_likelihood",
- beta_min: float = 0.01,
- beta_max: float = 10.0,
+ beta_min: float = 0.002, # sigma_min in the paper
+ beta_max: float = 80.0, # sigma_max in the paper
mean_0: Union[Tensor, float] = 0.0,
std_0: Union[Tensor, float] = 1.0,
- t_min: float = 1e-5,
- t_max: float = 1.0,
- pmean: float = -1.2,
- pstd: float = 1.2,
+ t_min: float = 1e-5,  # ignored in the EDM setup
+ t_max: float = 1.0,  # ignored in the EDM setup
+ pmean: float = -1.2, # mean of noise scheme for training
+ pstd: float = 1.2, # std of noise scheme for training
+ sigma_data: float = 0.5,
) -> None:
self.pmean, self.pstd = pmean, pstd
noise_dist = stats.norm(pmean, pstd**2)
- self.beta_min = exp(noise_dist.ppf(0.01))
- self.beta_max = exp(noise_dist.ppf(0.99))
+
+ self.sigma_min = exp(noise_dist.ppf(0.01))
+ self.sigma_max = exp(noise_dist.ppf(0.99))
+
+ self.rho = 7
super().__init__(
net,
@@ -537,9 +545,9 @@ def __init__(
)
def mean_t_fn(self, times: Tensor) -> Tensor:
- """Conditional mean function for variance preserving SDEs.
+ """Conditional mean function for EDM-style DMs.
Args:
- times: SDE time variable in [0,1].
+ times: time variable in [0,1].
Returns:
Conditional mean at a given time.
@@ -553,9 +561,9 @@ def mean_t_fn(self, times: Tensor) -> Tensor:
return phi
def std_fn(self, times: Tensor) -> Tensor:
- """Standard deviation function for variance preserving SDEs.
+ """Standard deviation function for EDM style DMs.
Args:
- times: SDE time variable in [0,1].
+ times: time variable in [0,1].
Returns:
Standard deviation at a given time.
@@ -616,15 +624,13 @@ def noise_schedule(self, times: Tensor) -> Tensor:
[1] Karras et al "Elucidating the Design Space of Diffusion-Based
Generative Models", https://arxiv.org/abs/2206.00364
"""
-
- samples = torch.randn_like(times) * (self.pstd) + self.pmean
- return torch.exp(samples)
+ return times
def times_schedule(
self, num_samples: int, t_min: float = None, t_max: float = None
) -> Tensor:
"""
- Perform normal sampling around the middle of the interval [t_min, t_max].
+ Construct time samples as suggested in the EDM paper [1].
Args:
num_samples (int): Number of samples to generate.
@@ -635,14 +641,16 @@ def times_schedule(
Tensor: time samples drawn from a normal centered on the
midpoint of [t_min, t_max].
+ [1] Karras et al "Elucidating the Design Space of Diffusion-Based
+ Generative Models", https://arxiv.org/abs/2206.00364
"""
- t_min = self.t_min if isinstance(t_min, type(None)) else t_min
- t_max = self.t_max if isinstance(t_max, type(None)) else t_max
- t_mu = t_min + (t_max - t_min) / 2.0
- t_std = (t_max - t_min) / 8.0
+ times = torch.linspace(0.0, 1.0, steps=num_samples)
+ inv_rho = 1.0 / self.rho
- # apply scale and loc to normal distribution
- return torch.randn(num_samples, device=self.device) * t_std + t_mu
+ beta_scale = self.beta_max ** (inv_rho) - self.beta_min ** (inv_rho)
+ offset = self.beta_min ** (inv_rho)
+
+ return (offset + beta_scale * times) ** (self.rho)
class SubVPScoreEstimator(ConditionalScoreEstimator):
@@ -836,6 +844,9 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
return g
+# TODO: try to add an EDM-like estimator
+
+
class GaussianFourierTimeEmbedding(nn.Module):
"""Gaussian random features for encoding time steps.
diff --git a/sbi/neural_nets/net_builders/score_nets.py b/sbi/neural_nets/net_builders/score_nets.py
index e33499ab9..e407e9e5a 100644
--- a/sbi/neural_nets/net_builders/score_nets.py
+++ b/sbi/neural_nets/net_builders/score_nets.py
@@ -2,14 +2,18 @@
import torch
import torch.nn as nn
+from torch import Tensor
+
from sbi.neural_nets.estimators.score_estimator import (
- ConditionalScoreEstimator, GaussianFourierTimeEmbedding,
- ImprovedScoreEstimator, SubVPScoreEstimator, VEScoreEstimator,
- VPScoreEstimator)
-from sbi.utils.sbiutils import (standardizing_net, z_score_parser,
- z_standardization)
+ ConditionalScoreEstimator,
+ GaussianFourierTimeEmbedding,
+ ImprovedScoreEstimator,
+ SubVPScoreEstimator,
+ VEScoreEstimator,
+ VPScoreEstimator,
+)
+from sbi.utils.sbiutils import standardizing_net, z_score_parser, z_standardization
from sbi.utils.user_input_checks import check_data_device
-from torch import Tensor
class EmbedInputs(nn.Module):
diff --git a/tests/bm_test.py b/tests/bm_test.py
index c61a6fef6..26cae3dc2 100644
--- a/tests/bm_test.py
+++ b/tests/bm_test.py
@@ -4,6 +4,7 @@
import pytest
import torch
from pytest_harvest import ResultsBag
+
from sbi.inference import FMPE, NLE, NPE, NPSE, NRE
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.trainers.npe import NPE_C
diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 7a698d66c..2f5dfe88c 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -7,11 +7,13 @@
import pytest
import torch
+
from sbi.neural_nets.embedding_nets import CNNEmbedding
from sbi.neural_nets.estimators.score_estimator import (
- ConditionalScoreEstimator, ImprovedScoreEstimator, VPScoreEstimator)
+ ConditionalScoreEstimator,
+ ImprovedScoreEstimator,
+)
from sbi.neural_nets.net_builders import build_score_estimator
-from scipy import stats
@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"])
@@ -164,8 +166,8 @@ def test_times_schedule():
assert exp == obs
assert times.shape == torch.Size((10,))
- assert times[0 ,...] != ivpse.t_min
- assert times[-1,...] != ivpse.t_max
+ assert times[0, ...] != ivpse.t_min
+ assert times[-1, ...] != ivpse.t_max
assert torch.allclose(times.max(), torch.Tensor([ivpse.beta_max]))
assert torch.allclose(times.min(), torch.Tensor([ivpse.beta_min]))
@@ -184,4 +186,4 @@ def test_noise_schedule():
assert exp == obs
assert noise.shape == torch.Size((10,))
- assert torch.allclose(times,noise)
+ assert torch.allclose(times, noise)
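The schedule added in this patch follows the EDM discretization: interpolate between
beta_min and beta_max in rho-warped space, so the steps cluster near the small-noise
end. A minimal standalone sketch of that formula (not library code; beta_min,
beta_max, and rho mirror the estimator attributes set in this patch):

    import torch

    def karras_times(num_samples: int, beta_min: float = 0.002,
                     beta_max: float = 80.0, rho: float = 7.0) -> torch.Tensor:
        # Interpolate linearly in beta**(1/rho) space, then undo the warp.
        t = torch.linspace(0.0, 1.0, steps=num_samples)
        inv_rho = 1.0 / rho
        return (beta_min**inv_rho
                + (beta_max**inv_rho - beta_min**inv_rho) * t) ** rho

    times = karras_times(10)
    assert torch.allclose(times[0], torch.tensor(0.002))  # endpoint hits beta_min
    assert torch.allclose(times[-1], torch.tensor(80.0))  # endpoint hits beta_max

Note that the patch interpolates upward from beta_min to beta_max, whereas the EDM
paper writes its sigma sequence in descending order; this only reverses the indexing.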
From 773a28bd2bb7a2ee73b3f05ce0e53b684d22b1c1 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 17:50:51 +0100
Subject: [PATCH 14/18] attempted to implement EDM-like diffusion
- without touching the forward function of ConditionalScoreEstimator
- benchmarks show that this leads to very long training times
without any performance improvements
---
sbi/neural_nets/estimators/score_estimator.py | 40 +++++++++----------
1 file changed, 19 insertions(+), 21 deletions(-)
diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index c712c8a06..8df7dbaed 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -6,11 +6,10 @@
from typing import Callable, Optional, Union
import torch
+from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
from scipy import stats
from torch import Tensor, nn
-from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
-
class ConditionalScoreEstimator(ConditionalVectorFieldEstimator):
r"""Score matching for score-based generative models (e.g., denoising diffusion).
@@ -129,13 +128,14 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
std = self.approx_marginal_std(time)
        # As input to the neural net we want to have something that changes proportional
- # to how the scores change
+ # to how the scores change (a la c_noise in edm)
time_enc = self.std_fn(time)
- # Time dependent z-scoring! Keeps input at similar scales
+ # Time dependent z-scoring! Keeps input at similar scales (c_in in edm)
input_enc = (input - mean) / std
- # Approximate score becoming exact for t -> t_max, "skip connection" (c_skip in edm)
+ # Approximate score becoming exact for t -> t_max, "skip connection"
+ # (a la c_skip in edm)
score_gaussian = (input - mean) / std**2
# Score prediction by the network
@@ -145,6 +145,7 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
        # The learnable part will be largely scaled at the beginning of the diffusion
# and the gaussian part (where it should end up) will dominate at the end of
# the diffusion.
+ # (a la c_out in edm)
scale = self.mean_t_fn(time) / self.std_fn(time)
output_score = -scale * score_pred - score_gaussian
@@ -416,14 +417,8 @@ def __init__(
mean_0: Union[Tensor, float] = 0.0,
std_0: Union[Tensor, float] = 1.0,
t_min: float = 1e-5,
- t_max: float = 1.0,
- pmean: float = 1.2,
- pstd: float = -1.2,
+ t_max: float = 1.0
) -> None:
- self.pmean, self.pstd = pmean, pstd
- noise_dist = stats.norm(pmean, pstd**2)
- self.beta_min = exp(noise_dist.ppf(0.01))
- self.beta_max = exp(noise_dist.ppf(0.99))
super().__init__(
net,
@@ -515,6 +510,7 @@ def __init__(
weight_fn: Union[str, Callable] = "max_likelihood",
beta_min: float = 0.002, # sigma_min in the paper
beta_max: float = 80.0, # sigma_max in the paper
+        beta_data: float = 0.5,  # sigma_data in the paper
mean_0: Union[Tensor, float] = 0.0,
std_0: Union[Tensor, float] = 1.0,
t_min: float = 1e-5, # will be ignored due to EDM setup
@@ -523,12 +519,15 @@ def __init__(
pstd: float = 1.2, # std of noise scheme for training
sigma_data: float = 0.5,
) -> None:
+
+
+        # TODO: store sigma values for training in an extra field
self.pmean, self.pstd = pmean, pstd
noise_dist = stats.norm(pmean, pstd**2)
-
self.sigma_min = exp(noise_dist.ppf(0.01))
self.sigma_max = exp(noise_dist.ppf(0.99))
+        self.beta_data = beta_data  # sigma_data from the EDM paper
self.rho = 7
super().__init__(
@@ -546,34 +545,33 @@ def __init__(
def mean_t_fn(self, times: Tensor) -> Tensor:
"""Conditional mean function for EDM-style DMs.
+ This is required to model c_in.
+
Args:
times: time variable in [0,1].
Returns:
Conditional mean at a given time.
"""
- phi = torch.exp(
- -0.25 * times**2.0 * (self.beta_max - self.beta_min)
- - 0.5 * times * self.beta_min
- )
+ noise = self.noise_schedule(times)
+        phi = 1.0 / torch.sqrt(noise**2 + self.beta_data**2)
for _ in range(len(self.input_shape)):
phi = phi.unsqueeze(-1)
return phi
def std_fn(self, times: Tensor) -> Tensor:
"""Standard deviation function for EDM style DMs.
+ This is akin to c_noise in the network/precond parametrisation.
Args:
times: time variable in [0,1].
Returns:
Standard deviation at a given time.
"""
- std = 1.0 - torch.exp(
- -0.5 * times**2.0 * (self.beta_max - self.beta_min) - times * self.beta_min
- )
+        std = 0.25 * torch.log(times)
for _ in range(len(self.input_shape)):
std = std.unsqueeze(-1)
- return torch.sqrt(std)
+ return std
def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""Drift function for variance preserving SDEs.
From 6386fc8479d9a2f440d6a32e3659e1ba412f1298 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 18:17:00 +0100
Subject: [PATCH 15/18] removed "improved" denoising network
---
tests/bm_test.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/tests/bm_test.py b/tests/bm_test.py
index 26cae3dc2..28296cded 100644
--- a/tests/bm_test.py
+++ b/tests/bm_test.py
@@ -4,7 +4,6 @@
import pytest
import torch
from pytest_harvest import ResultsBag
-
from sbi.inference import FMPE, NLE, NPE, NPSE, NRE
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.trainers.npe import NPE_C
@@ -50,7 +49,7 @@
"npse": [
{"score_estimator": nn, "sde_type": sde}
for nn in SCORE_ESTIMATORS
- for sde in ["ve", "vp", "vp++"]
+ for sde in ["ve", "vp"]
],
"snpe": [{}],
"snle": [{}],
From 2bbd92fbf0787e069dcb2e460ebe342d83475186 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 18:17:28 +0100
Subject: [PATCH 16/18] consolidated tests
---
tests/score_estimator_test.py | 36 ++++++++++++++++-------------------
1 file changed, 16 insertions(+), 20 deletions(-)
diff --git a/tests/score_estimator_test.py b/tests/score_estimator_test.py
index 2f5dfe88c..b7afc5255 100644
--- a/tests/score_estimator_test.py
+++ b/tests/score_estimator_test.py
@@ -7,16 +7,13 @@
import pytest
import torch
-
from sbi.neural_nets.embedding_nets import CNNEmbedding
from sbi.neural_nets.estimators.score_estimator import (
- ConditionalScoreEstimator,
- ImprovedScoreEstimator,
-)
+ ConditionalScoreEstimator, VPScoreEstimator)
from sbi.neural_nets.net_builders import build_score_estimator
-@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"])
+@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"])
@pytest.mark.parametrize("input_sample_dim", (1, 2))
@pytest.mark.parametrize("input_event_shape", ((1,), (4,)))
@pytest.mark.parametrize("condition_event_shape", ((1,), (7,)))
@@ -45,7 +42,7 @@ def test_score_estimator_loss_shapes(
@pytest.mark.gpu
-@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"])
+@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"])
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_score_estimator_on_device(sde_type, device):
"""Test whether DensityEstimators can be moved to the device."""
@@ -67,7 +64,7 @@ def test_score_estimator_on_device(sde_type, device):
assert str(loss.device).split(":")[0] == device, "Loss device mismatch."
-@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp", "vp++"])
+@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"])
@pytest.mark.parametrize("input_sample_dim", (1, 2))
@pytest.mark.parametrize("input_event_shape", ((1,), (4,)))
@pytest.mark.parametrize("condition_event_shape", ((1,), (7,)))
@@ -158,19 +155,16 @@ def test_times_schedule():
with pytest.raises(NotImplementedError):
ConditionalScoreEstimator(id_net, inpt_shape, cond_shape)
- ivpse = ImprovedScoreEstimator(id_net, inpt_shape, cond_shape)
- exp = ivpse.device
- times = ivpse.times_schedule(10)
+ vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape)
+ exp = vpse.device
+ times = vpse.times_schedule(10)
obs = times.device
assert exp == obs
assert times.shape == torch.Size((10,))
- assert times[0, ...] != ivpse.t_min
- assert times[-1, ...] != ivpse.t_max
-
- assert torch.allclose(times.max(), torch.Tensor([ivpse.beta_max]))
- assert torch.allclose(times.min(), torch.Tensor([ivpse.beta_min]))
+ assert times[0, ...] >= vpse.t_min
+ assert times[-1, ...] <= vpse.t_max
def test_noise_schedule():
@@ -178,12 +172,14 @@ def test_noise_schedule():
inpt_shape = (4,)
cond_shape = (4,)
- ivpse = ImprovedScoreEstimator(id_net, inpt_shape, cond_shape)
- exp = ivpse.device
- times = ivpse.times_schedule(10)
- noise = ivpse.noise_schedule(times)
+ vpse = VPScoreEstimator(id_net, inpt_shape, cond_shape)
+ exp = vpse.device
+ times = vpse.times_schedule(10)
+ noise = vpse.noise_schedule(times)
obs = noise.device
assert exp == obs
assert noise.shape == torch.Size((10,))
- assert torch.allclose(times, noise)
+
+ assert noise[0, ...] >= vpse.beta_min
+ assert noise[-1, ...] <= vpse.beta_max
From 8f0c65a58c2df694dcdaaf540d37da3207f84b64 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 18:17:44 +0100
Subject: [PATCH 17/18] removed occurrences of vp++
---
sbi/neural_nets/net_builders/score_nets.py | 18 +++++-------------
1 file changed, 5 insertions(+), 13 deletions(-)
diff --git a/sbi/neural_nets/net_builders/score_nets.py b/sbi/neural_nets/net_builders/score_nets.py
index e407e9e5a..5780163c7 100644
--- a/sbi/neural_nets/net_builders/score_nets.py
+++ b/sbi/neural_nets/net_builders/score_nets.py
@@ -2,18 +2,13 @@
import torch
import torch.nn as nn
-from torch import Tensor
-
from sbi.neural_nets.estimators.score_estimator import (
- ConditionalScoreEstimator,
- GaussianFourierTimeEmbedding,
- ImprovedScoreEstimator,
- SubVPScoreEstimator,
- VEScoreEstimator,
- VPScoreEstimator,
-)
-from sbi.utils.sbiutils import standardizing_net, z_score_parser, z_standardization
+ ConditionalScoreEstimator, GaussianFourierTimeEmbedding,
+ SubVPScoreEstimator, VEScoreEstimator, VPScoreEstimator)
+from sbi.utils.sbiutils import (standardizing_net, z_score_parser,
+ z_standardization)
from sbi.utils.user_input_checks import check_data_device
+from torch import Tensor
class EmbedInputs(nn.Module):
@@ -116,7 +111,6 @@ def build_score_estimator(
batch_y: Batch of ys, used to infer dimensionality and (optional) z-scoring.
sde_type: SDE type used, which defines the mean and std functions. One of:
- 'vp': Variance preserving.
- - 'vp++': Variance preserving.
- 'subvp': Sub-variance preserving.
- 've': Variance exploding.
Defaults to 'vp'.
@@ -196,8 +190,6 @@ def build_score_estimator(
estimator = VEScoreEstimator
elif sde_type == "subvp":
estimator = SubVPScoreEstimator
- elif sde_type == "vp++":
- estimator = ImprovedScoreEstimator
else:
raise ValueError(f"SDE type: {sde_type} not supported.")
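With "vp++" gone, the builder dispatch reduces to the three original SDE types. A
hedged usage sketch (the argument names follow the docstring above; the exact
signature of build_score_estimator may differ):

    import torch
    from sbi.neural_nets.net_builders import build_score_estimator

    batch_x = torch.randn(100, 4)  # hypothetical batch of thetas
    batch_y = torch.randn(100, 7)  # hypothetical batch of ys

    estimator = build_score_estimator(batch_x, batch_y, sde_type="vp")
    # sde_type="vp++" now falls through to:
    # ValueError: SDE type: vp++ not supported.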
From 833c17b595c14176f77946c4142409e3c8141fca Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Thu, 20 Mar 2025 18:19:20 +0100
Subject: [PATCH 18/18] removed all mentions of EDM
---
sbi/neural_nets/estimators/score_estimator.py | 176 +-----------------
1 file changed, 10 insertions(+), 166 deletions(-)
diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index 8df7dbaed..acacc7d67 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -128,14 +128,13 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
std = self.approx_marginal_std(time)
        # As input to the neural net we want to have something that changes proportional
- # to how the scores change (a la c_noise in edm)
+ # to how the scores change
time_enc = self.std_fn(time)
- # Time dependent z-scoring! Keeps input at similar scales (c_in in edm)
+ # Time dependent z-scoring! Keeps input at similar scales
input_enc = (input - mean) / std
# Approximate score becoming exact for t -> t_max, "skip connection"
- # (a la c_skip in edm)
score_gaussian = (input - mean) / std**2
# Score prediction by the network
@@ -145,7 +144,6 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
        # The learnable part will be largely scaled at the beginning of the diffusion
# and the gaussian part (where it should end up) will dominate at the end of
# the diffusion.
- # (a la c_out in edm)
scale = self.mean_t_fn(time) / self.std_fn(time)
output_score = -scale * score_pred - score_gaussian
@@ -182,7 +180,7 @@ def loss(
# update device if required
self.device = input.device if self.device != input.device else self.device
- # Sample times from the Markov chain
+    # Sample times from the Markov chain, using the batch dimension
if times is None:
times = self.times_schedule(input.shape[0])
@@ -355,7 +353,7 @@ def times_schedule(
self, num_samples: int, t_min: float = None, t_max: float = None
) -> Tensor:
"""
- Construction time samples for evaluating the diffusion model.
+ Time samples for evaluating the diffusion model.
Perform uniform sampling of time variables within the range [t_min, t_max].
The `times` tensor will be put on the same device as the stored network.
@@ -374,7 +372,12 @@ def times_schedule(
t_min = self.t_min if isinstance(t_min, type(None)) else t_min
t_max = self.t_max if isinstance(t_max, type(None)) else t_max
- return torch.rand(num_samples, device=self.device) * (t_max - t_min) + t_min
+ times = torch.rand(num_samples, device=self.device) * (t_max - t_min) + t_min
+
+ # t_min and t_max need to be part of the sequence
+        times[0, ...] = t_min
+        times[-1, ...] = t_max
+        return torch.sort(times).values  # keeps device, unlike torch.Tensor(sorted(...))
def _set_weight_fn(self, weight_fn: Union[str, Callable]):
"""Set the weight function.
@@ -495,162 +498,6 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
return g
-class ImprovedScoreEstimator(ConditionalScoreEstimator):
- """Implement EDM-like score matching estimator as in [1]
-
- [1] Karras et al "Elucidating the Design Space of Diffusion-Based
- Generative Models", https://arxiv.org/abs/2206.00364
- """
-
- def __init__(
- self,
- net: nn.Module,
- input_shape: torch.Size,
- condition_shape: torch.Size,
- weight_fn: Union[str, Callable] = "max_likelihood",
- beta_min: float = 0.002, # sigma_min in the paper
- beta_max: float = 80.0, # sigma_max in the paper
-        beta_data: float = 0.5,  # sigma_data in the paper
- mean_0: Union[Tensor, float] = 0.0,
- std_0: Union[Tensor, float] = 1.0,
- t_min: float = 1e-5, # will be ignored due to EDM setup
- t_max: float = 1.0, #
- pmean: float = -1.2, # mean of noise scheme for training
- pstd: float = 1.2, # std of noise scheme for training
- sigma_data: float = 0.5,
- ) -> None:
-
-
-        # TODO: store sigma values for training in an extra field
- self.pmean, self.pstd = pmean, pstd
- noise_dist = stats.norm(pmean, pstd**2)
- self.sigma_min = exp(noise_dist.ppf(0.01))
- self.sigma_max = exp(noise_dist.ppf(0.99))
-
-        self.beta_data = beta_data  # sigma_data from the EDM paper
- self.rho = 7
-
- super().__init__(
- net,
- input_shape,
- condition_shape,
- mean_0=mean_0,
- std_0=std_0,
- weight_fn=weight_fn,
- beta_min=beta_min,
- beta_max=beta_max,
- t_min=t_min,
- t_max=t_max,
- )
-
- def mean_t_fn(self, times: Tensor) -> Tensor:
- """Conditional mean function for EDM-style DMs.
- This is required to model c_in.
-
- Args:
- times: time variable in [0,1].
-
- Returns:
- Conditional mean at a given time.
- """
- noise = self.noise_schedule(times)
-        phi = 1.0 / torch.sqrt(noise**2 + self.beta_data**2)
- for _ in range(len(self.input_shape)):
- phi = phi.unsqueeze(-1)
- return phi
-
- def std_fn(self, times: Tensor) -> Tensor:
- """Standard deviation function for EDM style DMs.
- This is akin to c_noise in the network/precond parametrisation.
- Args:
- times: time variable in [0,1].
-
- Returns:
- Standard deviation at a given time.
- """
-        std = 0.25 * torch.log(times)
- for _ in range(len(self.input_shape)):
- std = std.unsqueeze(-1)
- return std
-
- def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
- """Drift function for variance preserving SDEs.
-
- Args:
- input: Original data, x0.
- times: SDE time variable in [0,1].
-
- Returns:
- Drift function at a given time.
- """
- phi = -0.5 * self.noise_schedule(times)
- while len(phi.shape) < len(input.shape):
- phi = phi.unsqueeze(-1)
- return phi * input
-
- def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
- """Diffusion function for variance preserving SDEs.
-
- Args:
- input: Original data, x0.
- times: SDE time variable in [0,1].
-
- Returns:
- Drift function at a given time.
- """
- g = torch.sqrt(self.noise_schedule(times))
- while len(g.shape) < len(input.shape):
- g = g.unsqueeze(-1)
- return g
-
- def noise_schedule(self, times: Tensor) -> Tensor:
- """
- Generate a beta schedule similar to suggestions in the EDM [1] paper.
-
- This method acts as a fallback in case derivative classes do not
- implement it on their own. It calculates a linear beta schedule defined
- by the input `times`, which represent the normalized time steps t ∈ [0, 1].
-
- Args:
- times (Tensor):
- SDE times in [0, 1]. This tensor will be regenerated from
- self.times_schedule
-
- Returns:
- Tensor: Generated beta schedule at a given time.
-
- [1] Karras et al "Elucidating the Design Space of Diffusion-Based
- Generative Models", https://arxiv.org/abs/2206.00364
- """
- return times
-
- def times_schedule(
- self, num_samples: int, t_min: float = None, t_max: float = None
- ) -> Tensor:
- """
- Construct time samples as suggested in EDM paper [1].
-
- Args:
- num_samples (int): Number of samples to generate.
- t_min (float, optional): The minimum time value. Defaults to self.t_min.
- t_max (float, optional): The maximum time value. Defaults to self.t_max.
-
- Returns:
- Tensor: A tensor of sampled time variables scaled and shifted to
- the range [0,1].
-
- [1] Karras et al "Elucidating the Design Space of Diffusion-Based
- Generative Models", https://arxiv.org/abs/2206.00364
- """
- times = torch.linspace(0.0, 1.0, steps=num_samples)
- inv_rho = 1.0 / self.rho
-
- beta_scale = self.beta_max ** (inv_rho) - self.beta_min ** (inv_rho)
- offset = self.beta_min ** (inv_rho)
-
- return (offset + beta_scale * times) ** (self.rho)
-
-
class SubVPScoreEstimator(ConditionalScoreEstimator):
"""Class for score estimators with sub-variance preserving SDEs."""
@@ -842,9 +689,6 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
return g
-# TODO: try to add an EDM-like estimator
-
-
class GaussianFourierTimeEmbedding(nn.Module):
"""Gaussian random features for encoding time steps.