Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions sbi/inference/trainers/npse/npse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@

from sbi import utils as utils
from sbi.inference import NeuralInference
from sbi.inference.posteriors import (
DirectPosterior,
)
from sbi.inference.posteriors import DirectPosterior
from sbi.inference.posteriors.score_posterior import ScorePosterior
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from sbi.neural_nets.factory import posterior_score_nn
Expand Down Expand Up @@ -304,9 +302,14 @@ def default_calibration_kernel(x):
self._neural_net.to(self._device)

if isinstance(validation_times, int):
validation_times = torch.linspace(
self._neural_net.t_min, self._neural_net.t_max, validation_times
)
if hasattr(self._neural_net, "times_schedule"):
validation_times = self._neural_net.times_schedule(validation_times)
else:
validation_times = torch.linspace(
self._neural_net.t_min,
self._neural_net.t_max,
steps=validation_times,
)
assert isinstance(
validation_times, Tensor
) # let pyright know validation_times is a Tensor.
Expand Down
146 changes: 90 additions & 56 deletions sbi/neural_nets/estimators/score_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import math
from math import pi
from math import exp, pi
from typing import Callable, Optional, Union

import torch
from torch import Tensor, nn

from sbi.neural_nets.estimators.base import ConditionalVectorFieldEstimator
from scipy import stats
from torch import Tensor, nn


class ConditionalScoreEstimator(ConditionalVectorFieldEstimator):
Expand Down Expand Up @@ -42,6 +42,8 @@ def __init__(
input_shape: torch.Size,
condition_shape: torch.Size,
weight_fn: Union[str, Callable] = "max_likelihood",
beta_min: float = 0.01,
beta_max: float = 10.0,
mean_0: Union[Tensor, float] = 0.0,
std_0: Union[Tensor, float] = 1.0,
t_min: float = 1e-3,
Expand All @@ -59,14 +61,27 @@ def __init__(
- "identity": constant weights (1.),
- "max_likelihood": weights proportional to the diffusion function, or
- a custom function that returns a Callable.
beta_min: starting/minimal value for beta, the variance of the noise
beta_max: ending/maximal value for beta, the variance of the noise
mean_0: expected value of 0th step noise
std_0: standard deviation (sigma) of the 0th step noise
t_min: smallest time step
t_max: largest time step

"""
super().__init__(net, input_shape, condition_shape)

# Set lambdas (variance weights) function.
self._set_weight_fn(weight_fn)

# Min time for diffusion (0 can be numerically unstable).
# Store device of this module
self.device = net.device if hasattr(net, "device") else torch.device("cpu")

# Min/max values for noise variance beta
self.beta_min = beta_min
self.beta_max = beta_max

# Min/max values for limits to time
self.t_min = t_min
self.t_max = t_max

Expand Down Expand Up @@ -152,6 +167,7 @@ def loss(
input: Input variable i.e. theta.
condition: Conditioning variable.
times: SDE time variable in [t_min, t_max]. Uniformly sampled if None.
If None, the times are drawn by calling the `times_schedule` method.
control_variate: Whether to use a control variate to reduce the variance of
the stochastic loss estimator.
control_variate_threshold: Threshold for the control variate. If the std
Expand All @@ -161,13 +177,12 @@ def loss(
MSE between target score and network output, scaled by the weight function.

"""
# Sample diffusion times.
# update device if required
self.device = input.device if self.device != input.device else self.device

# Sample times from the Markov chain, use batch dimension
if times is None:
times = (
torch.rand(input.shape[0], device=input.device)
* (self.t_max - self.t_min)
+ self.t_min
)
times = self.times_schedule(input.shape[0])

# Sample noise.
eps = torch.randn_like(input)
Expand Down Expand Up @@ -312,6 +327,58 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""
raise NotImplementedError

def noise_schedule(self, times: Tensor) -> Tensor:
    """Evaluate the linear beta (noise) schedule of the SDE at `times`.

    This is the default schedule; subclasses may override it with their own.
    The value moves linearly from `self.beta_min` (at time 0) to
    `self.beta_max` (at time 1).

    Args:
        times: SDE times, expected to lie in [0, 1].

    Returns:
        Beta schedule evaluated at each entry of `times`.
    """
    beta_span = self.beta_max - self.beta_min
    return self.beta_min + beta_span * times

def times_schedule(
    self,
    num_samples: int,
    t_min: Optional[float] = None,
    t_max: Optional[float] = None,
) -> Tensor:
    """Draw sorted diffusion times in the range [t_min, t_max].

    Times are sampled uniformly at random, the first and last samples are
    pinned to `t_min` and `t_max` so that both limits are part of the
    sequence, and the result is returned in ascending order. The tensor is
    created on `self.device` and stays there.

    Args:
        num_samples: Number of time points to generate.
        t_min: Smallest time value. Defaults to `self.t_min`.
        t_max: Largest time value. Defaults to `self.t_max`.

    Returns:
        1D tensor of shape `(num_samples,)` with ascending times in
        [t_min, t_max]. With a single sample only `t_max` is returned; with
        zero samples an empty tensor is returned.
    """
    if t_min is None:
        t_min = self.t_min
    if t_max is None:
        t_max = self.t_max

    span = t_max - t_min
    times = torch.rand(num_samples, device=self.device) * span + t_min

    # Nothing to pin or sort; indexing an empty tensor would raise.
    if num_samples == 0:
        return times

    # t_min and t_max need to be part of the sequence.
    times[0] = t_min
    times[-1] = t_max

    # Sort with torch so the result keeps the device and dtype of `times`
    # (rebuilding a tensor from a Python list would not).
    return torch.sort(times).values

def _set_weight_fn(self, weight_fn: Union[str, Callable]):
"""Set the weight function.

Expand Down Expand Up @@ -353,17 +420,18 @@ def __init__(
mean_0: Union[Tensor, float] = 0.0,
std_0: Union[Tensor, float] = 1.0,
t_min: float = 1e-5,
t_max: float = 1.0,
t_max: float = 1.0
) -> None:
self.beta_min = beta_min
self.beta_max = beta_max

super().__init__(
net,
input_shape,
condition_shape,
mean_0=mean_0,
std_0=std_0,
weight_fn=weight_fn,
beta_min=beta_min,
beta_max=beta_max,
t_min=t_min,
t_max=t_max,
)
Expand Down Expand Up @@ -399,17 +467,6 @@ def std_fn(self, times: Tensor) -> Tensor:
std = std.unsqueeze(-1)
return torch.sqrt(std)

def _beta_schedule(self, times: Tensor) -> Tensor:
"""Linear beta schedule for mean scaling in variance preserving SDEs.

Args:
times: SDE time variable in [0,1].

Returns:
Beta schedule at a given time.
"""
return self.beta_min + (self.beta_max - self.beta_min) * times

def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""Drift function for variance preserving SDEs.

Expand All @@ -420,7 +477,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
Returns:
Drift function at a given time.
"""
phi = -0.5 * self._beta_schedule(times)
phi = -0.5 * self.noise_schedule(times)
while len(phi.shape) < len(input.shape):
phi = phi.unsqueeze(-1)
return phi * input
Expand All @@ -435,7 +492,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
Returns:
Drift function at a given time.
"""
g = torch.sqrt(self._beta_schedule(times))
g = torch.sqrt(self.noise_schedule(times))
while len(g.shape) < len(input.shape):
g = g.unsqueeze(-1)
return g
Expand All @@ -457,13 +514,13 @@ def __init__(
t_min: float = 1e-2,
t_max: float = 1.0,
) -> None:
self.beta_min = beta_min
self.beta_max = beta_max
super().__init__(
net,
input_shape,
condition_shape,
weight_fn=weight_fn,
beta_min=beta_min,
beta_max=beta_max,
mean_0=mean_0,
std_0=std_0,
t_min=t_min,
Expand Down Expand Up @@ -501,18 +558,6 @@ def std_fn(self, times: Tensor) -> Tensor:
std = std.unsqueeze(-1)
return std

def _beta_schedule(self, times: Tensor) -> Tensor:
"""Linear beta schedule for mean scaling in sub-variance preserving SDEs.
(Same as for variance preserving SDEs.)

Args:
times: SDE time variable in [0,1].

Returns:
Beta schedule at a given time.
"""
return self.beta_min + (self.beta_max - self.beta_min) * times

def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""Drift function for sub-variance preserving SDEs.

Expand All @@ -523,7 +568,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
Returns:
Drift function at a given time.
"""
phi = -0.5 * self._beta_schedule(times)
phi = -0.5 * self.noise_schedule(times)

while len(phi.shape) < len(input.shape):
phi = phi.unsqueeze(-1)
Expand All @@ -542,7 +587,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""
g = torch.sqrt(
torch.abs(
self._beta_schedule(times)
self.noise_schedule(times)
* (
1
- torch.exp(
Expand Down Expand Up @@ -612,17 +657,6 @@ def std_fn(self, times: Tensor) -> Tensor:
std = std.unsqueeze(-1)
return std

def _sigma_schedule(self, times: Tensor) -> Tensor:
"""Geometric sigma schedule for variance exploding SDEs.

Args:
times: SDE time variable in [0,1].

Returns:
Sigma schedule at a given time.
"""
return self.sigma_min * (self.sigma_max / self.sigma_min) ** times

def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""Drift function for variance exploding SDEs.

Expand All @@ -645,9 +679,9 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
Returns:
Diffusion function at a given time.
"""
g = self._sigma_schedule(times) * math.sqrt(
(2 * math.log(self.sigma_max / self.sigma_min))
)
sigma_scale = self.sigma_max / self.sigma_min
sigmas = self.sigma_min * (sigma_scale) ** times
g = sigmas * math.sqrt((2 * math.log(sigma_scale)))

while len(g.shape) < len(input.shape):
g = g.unsqueeze(-1)
Expand Down
14 changes: 5 additions & 9 deletions sbi/neural_nets/net_builders/score_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,13 @@

import torch
import torch.nn as nn
from torch import Tensor

from sbi.neural_nets.estimators.score_estimator import (
ConditionalScoreEstimator,
GaussianFourierTimeEmbedding,
SubVPScoreEstimator,
VEScoreEstimator,
VPScoreEstimator,
)
from sbi.utils.sbiutils import standardizing_net, z_score_parser, z_standardization
ConditionalScoreEstimator, GaussianFourierTimeEmbedding,
SubVPScoreEstimator, VEScoreEstimator, VPScoreEstimator)
from sbi.utils.sbiutils import (standardizing_net, z_score_parser,
z_standardization)
from sbi.utils.user_input_checks import check_data_device
from torch import Tensor


class EmbedInputs(nn.Module):
Expand Down
1 change: 0 additions & 1 deletion tests/bm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pytest
import torch
from pytest_harvest import ResultsBag

from sbi.inference import FMPE, NLE, NPE, NPSE, NRE
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.trainers.npe import NPE_C
Expand Down
41 changes: 40 additions & 1 deletion tests/score_estimator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

import pytest
import torch

from sbi.neural_nets.embedding_nets import CNNEmbedding
from sbi.neural_nets.estimators.score_estimator import (
ConditionalScoreEstimator, VPScoreEstimator)
from sbi.neural_nets.net_builders import build_score_estimator


Expand Down Expand Up @@ -144,3 +145,41 @@ def _build_score_estimator_and_tensors(
)
condition = condition
return score_estimator, inputs, condition


def test_times_schedule():
    """Check device, shape and bounds of `times_schedule` on a VP estimator."""
    identity_net = torch.nn.Identity()
    input_shape = (4,)
    condition_shape = (4,)

    # The test expects constructing the base estimator directly to raise.
    with pytest.raises(NotImplementedError):
        ConditionalScoreEstimator(identity_net, input_shape, condition_shape)

    estimator = VPScoreEstimator(identity_net, input_shape, condition_shape)
    expected_device = estimator.device
    times = estimator.times_schedule(10)
    observed_device = times.device

    # Sampled times must live on the estimator's device.
    assert expected_device == observed_device
    assert times.shape == torch.Size((10,))

    # First and last entries must respect the configured time limits.
    assert times[0, ...] >= estimator.t_min
    assert times[-1, ...] <= estimator.t_max


def test_noise_schedule():
    """Check device, shape and bounds of `noise_schedule` on a VP estimator."""
    identity_net = torch.nn.Identity()
    input_shape = (4,)
    condition_shape = (4,)

    estimator = VPScoreEstimator(identity_net, input_shape, condition_shape)
    expected_device = estimator.device
    times = estimator.times_schedule(10)
    noise = estimator.noise_schedule(times)
    observed_device = noise.device

    # The schedule must stay on the estimator's device and keep one value
    # per time point.
    assert expected_device == observed_device
    assert noise.shape == torch.Size((10,))

    # First and last entries must respect the configured beta limits.
    assert noise[0, ...] >= estimator.beta_min
    assert noise[-1, ...] <= estimator.beta_max
Loading