Skip to content

Commit 9b01ea1

Browse files
Update distributions.py
1 parent 72b4645 commit 9b01ea1

File tree

1 file changed

+18
-90
lines changed

1 file changed

+18
-90
lines changed

src/irt/distributions.py

Lines changed: 18 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,98 +1,26 @@
11
import torch
2+
from torch.distributions import Distribution
23

3-
class torch.distributions.Distribution:
4-
'''
5-
The abstract base class for probability distributions, which we inherit from. These methods are implied
6-
to be implemented for each subclass.
7-
'''
8-
def __init__(batch_shape=torch.Size([]), event_shape=torch.Size([])):
9-
'''
10-
Basic constructer of distribution.
11-
'''
12-
13-
@property
14-
def arg_constraints():
15-
'''
16-
Returns a dictionary from argument names to Constraint objects that should
17-
be satisfied by each argument of this distribution. Args that are not tensors need not appear
18-
in this dict.
19-
'''
20-
21-
def cdf(value):
22-
'''
23-
Returns the cumulative density/mass function evaluated at value.
24-
'''
25-
26-
def entropy():
27-
'''
28-
Returns entropy of distribution, batched over batch_shape.
29-
'''
4+
class Normal(Distribution):
5+
def __init__(self, loc, scale):
6+
self.loc = loc
7+
self.scale = scale
308

31-
def enumerate_support(expand=True):
32-
'''
33-
Returns tensor containing all values supported by a discrete distribution. The result will
34-
enumerate over dimension 0, so the shape of the result will be (cardinality,) + batch_shape
35-
+ event_shape (where event_shape = () for univariate distributions).
36-
'''
37-
38-
@property
39-
def mean(expand=True):
40-
'''
41-
Returns mean of the distributio.
42-
'''
9+
def transform(self, z):
10+
return (z - self.loc) / self.scale
11+
12+
def d_transform_d_z(self):
13+
return 1 / self.scale
4314

44-
@property
45-
def mode(expand=True):
46-
'''
47-
Returns mean of the distributio.
48-
'''
49-
def perplexity():
50-
'''
51-
Returns perplexity of distribution, batched over batch_shape.
52-
'''
53-
54-
def rsample(sample_shape=torch.Size([])):
55-
'''
56-
Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution
57-
parameters are batched.
58-
'''
15+
def sample(self):
16+
return torch.normal(self.loc, self.scale).detach()
5917

60-
def sample(sample_shape=torch.Size([])):
61-
'''
62-
Generates a sample_shape shaped sample or sample_shape shaped batch of reparameterized samples
63-
if the distribution parameters are batched.
64-
'''
18+
def rsample(self):
19+
x = self.sample()
6520

66-
class torch.distributions.implicit.Normal(Distribution):
67-
'''
68-
A Gaussian distribution class with backpropagation capability for the rsample function through IRT.
69-
'''
70-
def __init__(mean_matrix, covariance_matrix=None):
71-
pass
21+
transform = self.transform(x)
7222

73-
class torch.distributions.implicit.Dirichlet(Distribution):
74-
'''
75-
A Dirichlet distribution class with backpropagation capability for the rsample function through IRT.
76-
'''
77-
def __init__(concentration, validate_args=None):
78-
pass
79-
80-
class torch.distributions.implicit.Mixture(Distribution):
81-
'''
82-
A Mixture of distributions class with backpropagation capability for the rsample function through IRT.
83-
'''
84-
def __init__(distributions : List[Distribution]):
85-
pass
23+
surrogate_x = - transform / self.d_transform_d_z().detach()
8624

87-
class torch.distributions.implicit.Student(Distribution):
88-
'''
89-
A Student's distribution class with backpropagation capability for the rsample function through IRT.
90-
'''
91-
def __init__():
92-
pass
93-
94-
class torch.distributions.implicit.Factorized(Distribution):
95-
'''
96-
A class for an arbitrary factorized distribution with backpropagation capability for the rsample
97-
function through IRT.
98-
'''
25+
# Replace gradients of x with gradients of surrogate_x, but keep the value.
26+
return x + (surrogate_x - surrogate_x.detach())

0 commit comments

Comments
 (0)