1
1
import unittest
2
2
import math
3
3
import torch
4
- from irt .distributions import Normal , Gamma , MixtureSameFamily , Beta , Dirichlet
4
+ import sys
5
+ sys .path .append ('../src' )
6
+ from irt .distributions import Normal , Gamma , MixtureSameFamily , Beta , Dirichlet , StudentT
5
7
from torch .distributions import Categorical , Independent
6
8
7
9
@@ -350,5 +352,93 @@ def test_invalid_inputs(self):
350
352
self .dirichlet .log_prob (torch .tensor ([0.2 , 0.3 , 0.6 ])) #Values don't sum to 1
351
353
352
354
355
+ class TestStudentT (unittest .TestCase ):
356
+
357
+ def setUp (self ):
358
+ self .df = torch .tensor ([3.0 , 5.0 ]).requires_grad_ (True )
359
+ self .loc = torch .tensor ([1.0 , 2.0 ]).requires_grad_ (True )
360
+ self .scale = torch .tensor ([0.5 , 1.0 ]).requires_grad_ (True )
361
+ self .studentt = StudentT (self .df , self .loc , self .scale , validate_args = True )
362
+
363
+ def test_init (self ):
364
+ studentt = StudentT (3.0 , 1.0 , 0.5 )
365
+ self .assertEqual (studentt .df , 3.0 )
366
+ self .assertEqual (studentt .loc , 1.0 )
367
+ self .assertEqual (studentt .scale , 0.5 )
368
+ self .assertEqual (studentt .gamma .concentration , 1.5 ) #Check Gamma initialization
369
+ self .assertEqual (studentt .gamma .rate , 1.5 )
370
+
371
+ def test_properties (self ):
372
+ df = torch .tensor ([.3 , 2.0 ])
373
+ loc = torch .tensor ([1.0 , 2.0 ])
374
+ scale = torch .tensor ([0.5 , 1.0 ])
375
+ studentt = StudentT (df , loc , scale )
376
+ self .assertTrue (torch .equal (studentt .mode , studentt .loc ))
377
+ # Check mean (undefined for df <= 1)
378
+ # print(self.studentt.mean[0])
379
+ self .assertTrue (torch .isnan (studentt .mean [0 ])) #Testing for nan values
380
+ self .assertTrue (torch .allclose (studentt .mean [1 ], studentt .loc [1 ])) #Mean should be defined for df > 1
381
+
382
+ # Check variance (undefined for df <= 1, infinite for 1 < df <= 2)
383
+ self .assertTrue (torch .isnan (studentt .variance [0 ]))
384
+ self .assertTrue (torch .isinf (studentt .variance [1 ])) # Should be inf for 1 < df <=2
385
+ self .assertTrue (torch .allclose (studentt .variance [1 ], (scale [1 ].pow (2 ) * df [1 ] / (df [1 ] - 2 )))) #Should be defined for df > 2
386
+
387
+
388
+ def test_expand (self ):
389
+ expanded_studentt = self .studentt .expand (torch .Size ([2 , 2 ]))
390
+ self .assertEqual (expanded_studentt .batch_shape , torch .Size ([2 , 2 ]))
391
+ self .assertTrue (torch .equal (expanded_studentt .df , self .df .expand ([2 , 2 ])))
392
+ self .assertTrue (torch .equal (expanded_studentt .loc , self .loc .expand ([2 , 2 ])))
393
+ self .assertTrue (torch .equal (expanded_studentt .scale , self .scale .expand ([2 , 2 ])))
394
+
395
+
396
+ def test_log_prob (self ):
397
+ value = torch .tensor ([2.0 , 3.0 ])
398
+ log_prob = self .studentt .log_prob (value )
399
+ y = (value - self .loc ) / self .scale
400
+ Z = (
401
+ self .scale .log ()
402
+ + 0.5 * self .df .log ()
403
+ + 0.5 * math .log (math .pi )
404
+ + torch .lgamma (0.5 * self .df )
405
+ - torch .lgamma (0.5 * (self .df + 1.0 ))
406
+ )
407
+ expected_log_prob = - 0.5 * (self .df + 1.0 ) * torch .log1p (y ** 2.0 / self .df ) - Z
408
+ self .assertTrue (torch .allclose (log_prob , expected_log_prob ))
409
+
410
+
411
+ def test_entropy (self ):
412
+ entropy = self .studentt .entropy ()
413
+ lbeta = (
414
+ torch .lgamma (0.5 * self .df )
415
+ + math .lgamma (0.5 )
416
+ - torch .lgamma (0.5 * (self .df + 1 ))
417
+ )
418
+ expected_entropy = (
419
+ self .scale .log ()
420
+ + 0.5
421
+ * (self .df + 1 )
422
+ * (torch .digamma (0.5 * (self .df + 1 )) - torch .digamma (0.5 * self .df ))
423
+ + 0.5 * self .df .log ()
424
+ + lbeta
425
+ )
426
+ self .assertTrue (torch .allclose (entropy , expected_entropy ))
427
+
428
+
429
+ def test_rsample (self ):
430
+ samples = self .studentt .rsample (sample_shape = torch .Size ([10 ]))
431
+ print (samples .shape )
432
+ # print(self.studentt.rsample(sample_shape=torch.Size([2])))
433
+ self .assertEqual (samples .shape , torch .Size ([10 , 2 ]))
434
+ self .assertTrue (samples .requires_grad ) # Check that gradients are tracked
435
+
436
+ def test_invalid_inputs (self ):
437
+ with self .assertRaises (ValueError ):
438
+ StudentT (torch .tensor ([- 1.0 , 1.0 ]), self .loc , self .scale ) #Negative df
439
+ with self .assertRaises (ValueError ):
440
+ self .studentt .log_prob ([1 , 2 ])
441
+
442
+
353
443
if __name__ == "__main__" :
354
- unittest .main ()
444
+ unittest .main ()
0 commit comments