|
1 | 1 | import torch
|
| 2 | +from torch.distributions import Distribution |
2 | 3 |
|
3 |
| -class torch.distributions.Distribution: |
4 |
| - ''' |
5 |
| - The abstract base class for probability distributions, which we inherit from. These methods are implied |
6 |
| - to be implemented for each subclass. |
7 |
| - ''' |
8 |
| - def __init__(batch_shape=torch.Size([]), event_shape=torch.Size([])): |
9 |
| - ''' |
10 |
| - Basic constructer of distribution. |
11 |
| - ''' |
12 |
| - |
13 |
| - @property |
14 |
| - def arg_constraints(): |
15 |
| - ''' |
16 |
| - Returns a dictionary from argument names to Constraint objects that should |
17 |
| - be satisfied by each argument of this distribution. Args that are not tensors need not appear |
18 |
| - in this dict. |
19 |
| - ''' |
20 |
| - |
21 |
| - def cdf(value): |
22 |
| - ''' |
23 |
| - Returns the cumulative density/mass function evaluated at value. |
24 |
| - ''' |
25 |
| - |
26 |
| - def entropy(): |
27 |
| - ''' |
28 |
| - Returns entropy of distribution, batched over batch_shape. |
29 |
| - ''' |
| 4 | +class Normal(Distribution): |
| 5 | + def __init__(self, loc, scale): |
| 6 | + self.loc = loc |
| 7 | + self.scale = scale |
30 | 8 |
|
31 |
| - def enumerate_support(expand=True): |
32 |
| - ''' |
33 |
| - Returns tensor containing all values supported by a discrete distribution. The result will |
34 |
| - enumerate over dimension 0, so the shape of the result will be (cardinality,) + batch_shape |
35 |
| - + event_shape (where event_shape = () for univariate distributions). |
36 |
| - ''' |
37 |
| - |
38 |
| - @property |
39 |
| - def mean(expand=True): |
40 |
| - ''' |
41 |
| - Returns mean of the distributio. |
42 |
| - ''' |
| 9 | + def transform(self, z): |
| 10 | + return (z - self.loc) / self.scale |
| 11 | + |
| 12 | + def d_transform_d_z(self): |
| 13 | + return 1 / self.scale |
43 | 14 |
|
44 |
| - @property |
45 |
| - def mode(expand=True): |
46 |
| - ''' |
47 |
| - Returns mean of the distributio. |
48 |
| - ''' |
49 |
| - def perplexity(): |
50 |
| - ''' |
51 |
| - Returns perplexity of distribution, batched over batch_shape. |
52 |
| - ''' |
53 |
| - |
54 |
| - def rsample(sample_shape=torch.Size([])): |
55 |
| - ''' |
56 |
| - Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution |
57 |
| - parameters are batched. |
58 |
| - ''' |
| 15 | + def sample(self): |
| 16 | + return torch.normal(self.loc, self.scale).detach() |
59 | 17 |
|
60 |
| - def sample(sample_shape=torch.Size([])): |
61 |
| - ''' |
62 |
| - Generates a sample_shape shaped sample or sample_shape shaped batch of reparameterized samples |
63 |
| - if the distribution parameters are batched. |
64 |
| - ''' |
| 18 | + def rsample(self): |
| 19 | + x = self.sample() |
65 | 20 |
|
66 |
| -class torch.distributions.implicit.Normal(Distribution): |
67 |
| - ''' |
68 |
| - A Gaussian distribution class with backpropagation capability for the rsample function through IRT. |
69 |
| - ''' |
70 |
| - def __init__(mean_matrix, covariance_matrix=None): |
71 |
| - pass |
| 21 | + transform = self.transform(x) |
72 | 22 |
|
73 |
| -class torch.distributions.implicit.Dirichlet(Distribution): |
74 |
| - ''' |
75 |
| - A Dirichlet distribution class with backpropagation capability for the rsample function through IRT. |
76 |
| - ''' |
77 |
| - def __init__(concentration, validate_args=None): |
78 |
| - pass |
79 |
| - |
80 |
| -class torch.distributions.implicit.Mixture(Distribution): |
81 |
| - ''' |
82 |
| - A Mixture of distributions class with backpropagation capability for the rsample function through IRT. |
83 |
| - ''' |
84 |
| - def __init__(distributions : List[Distribution]): |
85 |
| - pass |
| 23 | + surrogate_x = - transform / self.d_transform_d_z().detach() |
86 | 24 |
|
87 |
| -class torch.distributions.implicit.Student(Distribution): |
88 |
| - ''' |
89 |
| - A Student's distribution class with backpropagation capability for the rsample function through IRT. |
90 |
| - ''' |
91 |
| - def __init__(): |
92 |
| - pass |
93 |
| - |
94 |
| -class torch.distributions.implicit.Factorized(Distribution): |
95 |
| - ''' |
96 |
| - A class for an arbitrary factorized distribution with backpropagation capability for the rsample |
97 |
| - function through IRT. |
98 |
| - ''' |
| 25 | + # Replace gradients of x with gradients of surrogate_x, but keep the value. |
| 26 | + return x + (surrogate_x - surrogate_x.detach()) |
0 commit comments