import unittest
import math
import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from numbers import Number
from irt.distributions import Normal, Gamma


class TestNormalDistribution(unittest.TestCase):
    def setUp(self):
        self.loc = torch.tensor([0.0, 1.0]).requires_grad_(True)
        self.scale = torch.tensor([1.0, 2.0]).requires_grad_(True)
        self.normal = Normal(self.loc, self.scale)

    def test_init(self):
        normal = Normal(0.0, 1.0)
        self.assertEqual(normal.loc, 0.0)
        self.assertEqual(normal.scale, 1.0)
        self.assertEqual(normal.batch_shape, torch.Size())

        normal = Normal(torch.tensor([0.0, 1.0]), torch.tensor([1.0, 2.0]))
        self.assertTrue(torch.equal(normal.loc, torch.tensor([0.0, 1.0])))
        self.assertTrue(torch.equal(normal.scale, torch.tensor([1.0, 2.0])))
        self.assertEqual(normal.batch_shape, torch.Size([2]))

    def test_properties(self):
        self.assertTrue(torch.equal(self.normal.mean, self.loc))
        self.assertTrue(torch.equal(self.normal.mode, self.loc))
        self.assertTrue(torch.equal(self.normal.stddev, self.scale))
        self.assertTrue(torch.equal(self.normal.variance, self.scale ** 2))

    def test_entropy(self):
        entropy = self.normal.entropy()
        expected_entropy = 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)
        self.assertTrue(torch.allclose(entropy, expected_entropy))

    def test_cdf(self):
        value = torch.tensor([0.0, 2.0])
        cdf = self.normal.cdf(value)
        expected_cdf = 0.5 * (1 + torch.erf((value - self.loc) / (self.scale * math.sqrt(2))))
        self.assertTrue(torch.allclose(cdf, expected_cdf))

    def test_expand(self):
        expanded_normal = self.normal.expand(torch.Size([3, 2]))
        self.assertEqual(expanded_normal.batch_shape, torch.Size([3, 2]))
        self.assertTrue(torch.equal(expanded_normal.loc, self.loc.expand([3, 2])))
        self.assertTrue(torch.equal(expanded_normal.scale, self.scale.expand([3, 2])))

    def test_icdf(self):
        value = torch.tensor([0.2, 0.8])
        icdf = self.normal.icdf(value)
        expected_icdf = self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)
        self.assertTrue(torch.allclose(icdf, expected_icdf))

    def test_log_prob(self):
        value = torch.tensor([0.0, 2.0])
        log_prob = self.normal.log_prob(value)
        var = self.scale ** 2
        log_scale = self.scale.log()
        expected_log_prob = -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))
        self.assertTrue(torch.allclose(log_prob, expected_log_prob))
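
    # Hedged sketch (not part of the original commit): assuming irt's Normal uses the
    # same (loc, scale) parameterization as torch.distributions.Normal, its log_prob
    # and cdf should agree with the reference implementation.
    def test_log_prob_matches_torch_reference(self):
        value = torch.tensor([0.0, 2.0])
        reference = torch.distributions.Normal(self.loc, self.scale)
        self.assertTrue(torch.allclose(self.normal.log_prob(value), reference.log_prob(value)))
        self.assertTrue(torch.allclose(self.normal.cdf(value), reference.cdf(value)))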

    def test_sample(self):
        samples = self.normal.sample(sample_shape=torch.Size([100]))
        self.assertEqual(samples.shape, torch.Size([100, 2]))  # Check shape
        empirical_mean = samples.mean(dim=0)
        # With 100 draws, the sample mean should fall within one scale of the true mean.
        self.assertTrue((empirical_mean < self.normal.mean + self.normal.scale).all())
        self.assertTrue((self.normal.mean - self.normal.scale < empirical_mean).all())

    def test_rsample(self):
        samples = self.normal.rsample(sample_shape=torch.Size([10]))
        self.assertEqual(samples.shape, torch.Size([10, 2]))  # Check shape
        self.assertTrue(samples.requires_grad)  # Check gradient tracking
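
    # Hedged sketch (assumes rsample is reparameterized as loc + eps * scale, so that
    # gradients flow back to the leaf parameters created in setUp).
    def test_rsample_grad_flow(self):
        samples = self.normal.rsample(sample_shape=torch.Size([10]))
        samples.sum().backward()
        self.assertIsNotNone(self.loc.grad)
        self.assertIsNotNone(self.scale.grad)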


class TestGammaDistribution(unittest.TestCase):
    def setUp(self):
        self.concentration = torch.tensor([1.0, 2.0]).requires_grad_(True)
        self.rate = torch.tensor([1.0, 0.5]).requires_grad_(True)
        self.gamma = Gamma(self.concentration, self.rate)

    def test_init(self):
        gamma = Gamma(1.0, 1.0)
        self.assertEqual(gamma.concentration, 1.0)
        self.assertEqual(gamma.rate, 1.0)
        self.assertEqual(gamma.batch_shape, torch.Size())

        gamma = Gamma(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 0.5]))
        self.assertTrue(torch.equal(gamma.concentration, torch.tensor([1.0, 2.0])))
        self.assertTrue(torch.equal(gamma.rate, torch.tensor([1.0, 0.5])))
        self.assertEqual(gamma.batch_shape, torch.Size([2]))

    def test_properties(self):
        self.assertTrue(torch.allclose(self.gamma.mean, self.concentration / self.rate))
        self.assertTrue(torch.allclose(self.gamma.mode, ((self.concentration - 1) / self.rate).clamp(min=0)))
        self.assertTrue(torch.allclose(self.gamma.variance, self.concentration / self.rate.pow(2)))

    def test_expand(self):
        expanded_gamma = self.gamma.expand(torch.Size([3, 2]))
        self.assertEqual(expanded_gamma.batch_shape, torch.Size([3, 2]))
        self.assertTrue(torch.equal(expanded_gamma.concentration, self.concentration.expand([3, 2])))
        self.assertTrue(torch.equal(expanded_gamma.rate, self.rate.expand([3, 2])))

    def test_rsample(self):
        samples = self.gamma.rsample(sample_shape=torch.Size([10]))
        self.assertEqual(samples.shape, torch.Size([10, 2]))  # Check shape
        self.assertTrue(samples.requires_grad)  # Check gradient tracking
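
    # Hedged sketch (assumes the Gamma rsample implementation propagates gradients to
    # concentration and rate, e.g. via implicit reparameterization as in torch).
    def test_rsample_grad_flow(self):
        samples = self.gamma.rsample(sample_shape=torch.Size([10]))
        samples.sum().backward()
        self.assertIsNotNone(self.concentration.grad)
        self.assertIsNotNone(self.rate.grad)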

    def test_log_prob(self):
        value = torch.tensor([1.0, 2.0])
        log_prob = self.gamma.log_prob(value)
        expected_log_prob = (
            torch.xlogy(self.concentration, self.rate)
            + torch.xlogy(self.concentration - 1, value)
            - self.rate * value
            - torch.lgamma(self.concentration)
        )
        self.assertTrue(torch.allclose(log_prob, expected_log_prob))
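
    # Hedged sketch (not part of the original commit): assuming irt's Gamma uses the same
    # (concentration, rate) parameterization as torch.distributions.Gamma, its log_prob
    # and entropy should agree with the reference implementation.
    def test_matches_torch_reference(self):
        value = torch.tensor([1.0, 2.0])
        reference = torch.distributions.Gamma(self.concentration, self.rate)
        self.assertTrue(torch.allclose(self.gamma.log_prob(value), reference.log_prob(value)))
        self.assertTrue(torch.allclose(self.gamma.entropy(), reference.entropy()))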

    def test_entropy(self):
        entropy = self.gamma.entropy()
        expected_entropy = (
            self.concentration
            - torch.log(self.rate)
            + torch.lgamma(self.concentration)
            + (1.0 - self.concentration) * torch.digamma(self.concentration)
        )
        self.assertTrue(torch.allclose(entropy, expected_entropy))

    def test_natural_params(self):
        natural_params = self.gamma._natural_params
        expected_natural_params = (self.concentration - 1, -self.rate)
        self.assertTrue(torch.equal(natural_params[0], expected_natural_params[0]))
        self.assertTrue(torch.equal(natural_params[1], expected_natural_params[1]))

    def test_log_normalizer(self):
        x, y = self.gamma._natural_params
        log_normalizer = self.gamma._log_normalizer(x, y)
        expected_log_normalizer = torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())
        self.assertTrue(torch.allclose(log_normalizer, expected_log_normalizer))

    def test_cdf(self):
        value = torch.tensor([1.0, 2.0])
        cdf = self.gamma.cdf(value)
        expected_cdf = torch.special.gammainc(self.concentration, self.rate * value)
        self.assertTrue(torch.allclose(cdf, expected_cdf))

    def test_invalid_inputs(self):
        with self.assertRaises(ValueError):
            Gamma(torch.tensor([-1.0, 1.0]), self.rate)  # Negative concentration
        with self.assertRaises(ValueError):
            Gamma(self.concentration, torch.tensor([-1.0, 1.0]))  # Negative rate
        with self.assertRaises(ValueError):
            self.gamma.log_prob(torch.tensor([-1.0, 1.0]))  # Negative value


if __name__ == "__main__":
    unittest.main()