Skip to content

Commit 06dd3f8

Browse files
Create distributions.py
1 parent c662aca commit 06dd3f8

File tree

1 file changed

+96
-0
lines changed

1 file changed

+96
-0
lines changed

src/irt/distributions.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
class torch.distributions.Distribution:
2+
'''
3+
The abstract base class for probability distributions, which we inherit from. These methods are implied
4+
to be implemented for each subclass.
5+
'''
6+
def __init__(batch_shape=torch.Size([]), event_shape=torch.Size([])):
7+
'''
8+
Basic constructer of distribution.
9+
'''
10+
11+
@property
12+
def arg_constraints():
13+
'''
14+
Returns a dictionary from argument names to Constraint objects that should
15+
be satisfied by each argument of this distribution. Args that are not tensors need not appear
16+
in this dict.
17+
'''
18+
19+
def cdf(value):
20+
'''
21+
Returns the cumulative density/mass function evaluated at value.
22+
'''
23+
24+
def entropy():
25+
'''
26+
Returns entropy of distribution, batched over batch_shape.
27+
'''
28+
29+
def enumerate_support(expand=True):
30+
'''
31+
Returns tensor containing all values supported by a discrete distribution. The result will
32+
enumerate over dimension 0, so the shape of the result will be (cardinality,) + batch_shape
33+
+ event_shape (where event_shape = () for univariate distributions).
34+
'''
35+
36+
@property
37+
def mean(expand=True):
38+
'''
39+
Returns mean of the distributio.
40+
'''
41+
42+
@property
43+
def mode(expand=True):
44+
'''
45+
Returns mean of the distributio.
46+
'''
47+
def perplexity():
48+
'''
49+
Returns perplexity of distribution, batched over batch_shape.
50+
'''
51+
52+
def rsample(sample_shape=torch.Size([])):
53+
'''
54+
Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution
55+
parameters are batched.
56+
'''
57+
58+
def sample(sample_shape=torch.Size([])):
59+
'''
60+
Generates a sample_shape shaped sample or sample_shape shaped batch of reparameterized samples
61+
if the distribution parameters are batched.
62+
'''
63+
64+
class torch.distributions.implicit.Normal(Distribution):
65+
'''
66+
A Gaussian distribution class with backpropagation capability for the rsample function through IRT.
67+
'''
68+
def __init__(mean_matrix, covariance_matrix=None):
69+
pass
70+
71+
class torch.distributions.implicit.Dirichlet(Distribution):
72+
'''
73+
A Dirichlet distribution class with backpropagation capability for the rsample function through IRT.
74+
'''
75+
def __init__(concentration, validate_args=None):
76+
pass
77+
78+
class torch.distributions.implicit.Mixture(Distribution):
79+
'''
80+
A Mixture of distributions class with backpropagation capability for the rsample function through IRT.
81+
'''
82+
def __init__(distributions : List[Distribution]):
83+
pass
84+
85+
class torch.distributions.implicit.Student(Distribution):
86+
'''
87+
A Student's distribution class with backpropagation capability for the rsample function through IRT.
88+
'''
89+
def __init__():
90+
pass
91+
92+
class torch.distributions.implicit.Factorized(Distribution):
93+
'''
94+
A class for an arbitrary factorized distribution with backpropagation capability for the rsample
95+
function through IRT.
96+
'''

0 commit comments

Comments
 (0)