Skip to content

Commit 8b4c63c

Browse files
committed
test commit
1 parent 217de5c commit 8b4c63c

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import torch
2+
from torch.distributions import Distribution
3+
4+
### alternative information
5+
6+
7+
# Define a custom Normal distribution class that inherits from PyTorch's Distribution class
class Normal(Distribution):
    # Indicates that the distribution supports reparameterized sampling
    has_rsample = True

    def __init__(self, loc: torch.Tensor, scale: torch.Tensor, generator: torch.Generator = None) -> None:
        """
        Initializes the Normal distribution with a given mean (loc) and standard deviation (scale).

        Args:
            loc (Tensor): Mean of the normal distribution. This defines the central tendency of the distribution.
            scale (Tensor): Standard deviation of the normal distribution. This defines the spread or width of the distribution.
            generator (torch.Generator, optional): A random number generator for reproducible sampling.
        """
        self.loc = loc  # Mean of the distribution
        self.scale = scale  # Standard deviation of the distribution
        self.generator = generator  # Optional random number generator for reproducibility
        # BUG FIX: the original called `super(Distribution).__init__()`, which creates an
        # *unbound* super object and never runs Distribution.__init__, leaving
        # _batch_shape/_event_shape unset (e.g. `self.batch_shape` raised AttributeError).
        # validate_args=False because this class does not declare arg_constraints.
        super().__init__(validate_args=False)

    def transform(self, z: torch.Tensor) -> torch.Tensor:
        """
        Transforms the input tensor `z` to the standard normal form using the distribution's mean and scale.

        Args:
            z (Tensor): Input tensor to be transformed.

        Returns:
            Tensor: The standardized tensor, i.e. (z - loc) / scale.
        """
        return (z - self.loc) / self.scale

    def d_transform_d_z(self) -> torch.Tensor:
        """
        Computes the derivative of the transform function with respect to the input tensor `z`.

        Returns:
            Tensor: The derivative, which is the reciprocal of the scale. This is used for reparameterization.
        """
        return 1 / self.scale

    def sample(self) -> torch.Tensor:
        """
        Generates a sample from the Normal distribution using PyTorch's `torch.normal` function.

        Returns:
            Tensor: A sample from the distribution. `detach()` prevents gradients from being
            tracked through the raw sampling operation.
        """
        return torch.normal(self.loc, self.scale, generator=self.generator).detach()

    def rsample(self) -> torch.Tensor:
        """
        Generates a reparameterized sample from the Normal distribution, which is useful for gradient-based optimization.

        A sample `x` is drawn without gradients, then a surrogate is constructed so that the
        returned tensor equals `x` in value while its gradients w.r.t. `loc` and `scale` match
        the reparameterization (pathwise) derivative: d(x)/d(loc) = 1, d(x)/d(scale) = (x - loc) / scale.

        Returns:
            Tensor: A reparameterized sample tensor, which allows gradient backpropagation.
        """
        x = self.sample()  # Non-differentiable draw from the distribution

        transform = self.transform(x)  # Standardize the sample (carries grads w.r.t. loc/scale)
        # Surrogate whose gradient w.r.t. the parameters equals the pathwise derivative;
        # the derivative of the transform is detached so it acts as a constant factor.
        surrogate_x = -transform / self.d_transform_d_z().detach()
        # Value equals x exactly; gradient flows only through the surrogate term.
        return x + (surrogate_x - surrogate_x.detach())

code/irt/distributions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import torch
22
from torch.distributions import Distribution
33

4+
### alternative information
5+
6+
47
# Define a custom Normal distribution class that inherits from PyTorch's Distribution class
58
class Normal(Distribution):
69
# Indicates that the distribution supports reparameterized sampling

0 commit comments

Comments
 (0)