Skip to content

Commit 2b4f6e8

Browse files
Update distributions.py
1 parent b2fcdc3 commit 2b4f6e8

File tree

1 file changed

+56
-13
lines changed

1 file changed

+56
-13
lines changed

src/irt/distributions.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,69 @@
11
import torch
22
from torch.distributions import Distribution
33

4+
# Define a custom Normal distribution class that inherits from PyTorch's Distribution class
45
class Normal(Distribution):
    """Normal (Gaussian) distribution supporting pathwise (reparameterized) sampling.

    Sampling itself is done without gradient tracking; `rsample` then attaches a
    surrogate term so that gradients flow to `loc` and `scale` through the
    standardizing transform (a pathwise-derivative / implicit-reparameterization
    construction).
    """

    # Indicates that the distribution supports reparameterized sampling
    has_rsample = True

    def __init__(self, loc: torch.Tensor, scale: torch.Tensor, generator: torch.Generator = None) -> None:
        """
        Initializes the Normal distribution with a given mean (loc) and standard deviation (scale).

        Args:
            loc (Tensor): Mean of the normal distribution. Defines the central tendency.
            scale (Tensor): Standard deviation of the normal distribution. Defines the spread.
            generator (torch.Generator, optional): Random number generator for reproducible sampling.
        """
        self.loc = loc  # Mean of the distribution
        self.scale = scale  # Standard deviation of the distribution
        self.generator = generator  # Optional RNG for reproducibility
        # Bug fix: the original code called `super(Distribution).__init__()`, which builds an
        # *unbound* super object and initializes that object — it never runs
        # Distribution.__init__ on this instance, leaving _batch_shape/_event_shape unset.
        super().__init__()

    def transform(self, z: torch.Tensor) -> torch.Tensor:
        """
        Standardizes `z` with this distribution's parameters.

        Args:
            z (Tensor): Input tensor to be transformed.

        Returns:
            Tensor: `(z - loc) / scale`, i.e. `z` expressed in standard-normal units.
        """
        return (z - self.loc) / self.scale

    def d_transform_d_z(self) -> torch.Tensor:
        """
        Computes the derivative of `transform` with respect to its input `z`.

        Returns:
            Tensor: `1 / scale`; used to build the surrogate in `rsample`.
        """
        return 1 / self.scale

    def sample(self) -> torch.Tensor:
        """
        Draws a sample from the distribution without gradient tracking.

        Returns:
            Tensor: A sample from N(loc, scale); `detach()` prevents any gradient flow.
        """
        return torch.normal(self.loc, self.scale, generator=self.generator).detach()

    def rsample(self) -> torch.Tensor:
        """
        Generates a reparameterized sample usable with gradient-based optimization.

        A detached sample `x` is combined with a surrogate term whose value is zero
        but whose gradient equals the pathwise derivative, so backpropagation reaches
        `loc` and `scale` while the forward value stays exactly `x`.

        Returns:
            Tensor: A sample tensor through which gradients can backpropagate.
        """
        x = self.sample()  # Detached draw from the distribution

        transform = self.transform(x)  # Standardized sample; depends on loc/scale
        surrogate_x = -transform / self.d_transform_d_z().detach()  # Zero-value, gradient-carrying surrogate
        # Keep the sampled value but substitute the surrogate's gradients.
        return x + (surrogate_x - surrogate_x.detach())

0 commit comments

Comments
 (0)