Skip to content

Commit c10adef

Browse files
Merge pull request #6 from intsystems/docs
Docs
2 parents b2fcdc3 + 1ebf711 commit c10adef

File tree

5 files changed

+90
-16
lines changed

5 files changed

+90
-16
lines changed

.flake8

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
[flake8]
2+
profile = black
3+
count = True
4+
statistics = True
5+
extend-ignore =
6+
E203,
7+
F403,
8+
C901,
9+
N812,
10+
B010,
11+
ANN101,
12+
max-line-length = 120
13+
max-complexity = 15
14+
docstring-convention = google
15+
exclude =
16+
.ipynb,
17+
.ipynb_checkpoints,
18+
.git,
19+
__pycache__,
20+
venv,
21+
tests,
22+
ignore-names =
23+
X_train,
24+
X_control,
25+
X,
26+
X_val,
27+
X_valid,
28+
X_test,
29+
inline-quotes = double

.github/workflows/linters.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ jobs:
2727
pip install isort flake8 black
2828
2929
- name: Run isort
30-
run: isort .
30+
run: isort --check-only .
3131

3232
- name: Run flake8
3333
run: flake8 .
3434

3535
- name: Run black
36-
run: black . --check
36+
run: black --line-length=120 --check --verbose --diff --color .

.isort.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[isort]
2+
profile=black

doc/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
'sphinx.ext.autosummary', 'sphinx.ext.mathjax',
4040
'sphinx_rtd_theme']
4141

42-
autodoc_mock_imports = ["numpy", "scipy", "sklearn"]
42+
autodoc_mock_imports = ["numpy", "scipy", "sklearn", "torch"]
4343

4444
# Add any paths that contain templates here, relative to this directory.
4545
templates_path = ['_templates']

src/irt/distributions.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,69 @@
11
import torch
22
from torch.distributions import Distribution
33

4+
# Define a custom Normal distribution class that inherits from PyTorch's Distribution class
45
class Normal(Distribution):
    """Normal (Gaussian) distribution with pathwise (reparameterized) sampling.

    ``rsample`` implements the surrogate-gradient trick: a detached sample is
    combined with a surrogate expression so that gradients flow to ``loc`` and
    ``scale`` while the returned value equals the raw sample.
    """

    # Indicates that the distribution supports reparameterized sampling.
    has_rsample = True

    def __init__(self, loc: torch.Tensor, scale: torch.Tensor, generator: torch.Generator = None) -> None:
        """Initialize the distribution with mean ``loc`` and std ``scale``.

        Args:
            loc (Tensor): Mean of the normal distribution (central tendency).
            scale (Tensor): Standard deviation of the distribution (spread).
            generator (torch.Generator, optional): Random number generator for
                reproducible sampling.
        """
        self.loc = loc  # Mean of the distribution
        self.scale = scale  # Standard deviation of the distribution
        self.generator = generator  # Optional RNG for reproducibility
        # BUG FIX: the original called ``super(Distribution).__init__()``, which
        # builds an *unbound* super proxy and initializes the proxy object
        # itself -- ``Distribution.__init__`` never ran on this instance.
        # ``validate_args=False`` because this subclass defines no
        # ``arg_constraints`` (validation would only warn and do nothing).
        super().__init__(validate_args=False)

    def transform(self, z: torch.Tensor) -> torch.Tensor:
        """Standardize ``z`` using the distribution's mean and scale.

        Args:
            z (Tensor): Input tensor to be transformed.

        Returns:
            Tensor: ``(z - loc) / scale``, i.e. ``z`` normalized to mean 0 and
            standard deviation 1.
        """
        return (z - self.loc) / self.scale

    def d_transform_d_z(self) -> torch.Tensor:
        """Derivative of :meth:`transform` with respect to its input ``z``.

        Returns:
            Tensor: The reciprocal of the scale; used for reparameterization.
        """
        return 1 / self.scale

    def sample(self) -> torch.Tensor:
        """Draw a non-differentiable sample from the distribution.

        Returns:
            Tensor: A sample from ``N(loc, scale)``; ``detach()`` prevents
            gradients from being tracked through the draw.
        """
        return torch.normal(self.loc, self.scale, generator=self.generator).detach()

    def rsample(self) -> torch.Tensor:
        """Draw a reparameterized sample suitable for gradient-based optimization.

        A detached sample ``x`` is standardized, and a surrogate expression is
        built whose gradient with respect to ``loc``/``scale`` matches the
        pathwise derivative; the returned value equals ``x`` exactly.

        Returns:
            Tensor: A sample tensor through which gradients can backpropagate.
        """
        x = self.sample()  # Detached draw from the distribution

        transform = self.transform(x)  # Standardized form of the sample
        # Surrogate whose gradient carries the pathwise derivative.
        surrogate_x = -transform / self.d_transform_d_z().detach()
        # Value equals x; gradient comes from surrogate_x.
        return x + (surrogate_x - surrogate_x.detach())

0 commit comments

Comments
 (0)