Skip to content

Commit e11739d

Browse files
committed
added tests for student + added studentt + fix
1 parent aed0e88 commit e11739d

File tree

2 files changed

+101
-6
lines changed

2 files changed

+101
-6
lines changed

code/run_unittest.py

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import unittest
22
import math
33
import torch
4-
from irt.distributions import Normal, Gamma, MixtureSameFamily, Beta, Dirichlet
4+
import sys
5+
sys.path.append('../src')
6+
from irt.distributions import Normal, Gamma, MixtureSameFamily, Beta, Dirichlet, StudentT
57
from torch.distributions import Categorical, Independent
68

79

@@ -350,5 +352,93 @@ def test_invalid_inputs(self):
350352
self.dirichlet.log_prob(torch.tensor([0.2, 0.3, 0.6])) #Values don't sum to 1
351353

352354

355+
class TestStudentT(unittest.TestCase):
356+
357+
def setUp(self):
358+
self.df = torch.tensor([3.0, 5.0]).requires_grad_(True)
359+
self.loc = torch.tensor([1.0, 2.0]).requires_grad_(True)
360+
self.scale = torch.tensor([0.5, 1.0]).requires_grad_(True)
361+
self.studentt = StudentT(self.df, self.loc, self.scale, validate_args=True)
362+
363+
def test_init(self):
364+
studentt = StudentT(3.0, 1.0, 0.5)
365+
self.assertEqual(studentt.df, 3.0)
366+
self.assertEqual(studentt.loc, 1.0)
367+
self.assertEqual(studentt.scale, 0.5)
368+
self.assertEqual(studentt.gamma.concentration, 1.5) #Check Gamma initialization
369+
self.assertEqual(studentt.gamma.rate, 1.5)
370+
371+
def test_properties(self):
372+
df = torch.tensor([.3, 2.0])
373+
loc = torch.tensor([1.0, 2.0])
374+
scale = torch.tensor([0.5, 1.0])
375+
studentt = StudentT(df, loc, scale)
376+
self.assertTrue(torch.equal(studentt.mode, studentt.loc))
377+
# Check mean (undefined for df <= 1)
378+
# print(self.studentt.mean[0])
379+
self.assertTrue(torch.isnan(studentt.mean[0])) #Testing for nan values
380+
self.assertTrue(torch.allclose(studentt.mean[1], studentt.loc[1])) #Mean should be defined for df > 1
381+
382+
# Check variance (undefined for df <= 1, infinite for 1 < df <= 2)
383+
self.assertTrue(torch.isnan(studentt.variance[0]))
384+
self.assertTrue(torch.isinf(studentt.variance[1])) # Should be inf for 1 < df <=2
385+
self.assertTrue(torch.allclose(studentt.variance[1], (scale[1].pow(2) * df[1] / (df[1] - 2)))) #Should be defined for df > 2
386+
387+
388+
def test_expand(self):
389+
expanded_studentt = self.studentt.expand(torch.Size([2, 2]))
390+
self.assertEqual(expanded_studentt.batch_shape, torch.Size([2, 2]))
391+
self.assertTrue(torch.equal(expanded_studentt.df, self.df.expand([2, 2])))
392+
self.assertTrue(torch.equal(expanded_studentt.loc, self.loc.expand([2, 2])))
393+
self.assertTrue(torch.equal(expanded_studentt.scale, self.scale.expand([2, 2])))
394+
395+
396+
def test_log_prob(self):
397+
value = torch.tensor([2.0, 3.0])
398+
log_prob = self.studentt.log_prob(value)
399+
y = (value - self.loc) / self.scale
400+
Z = (
401+
self.scale.log()
402+
+ 0.5 * self.df.log()
403+
+ 0.5 * math.log(math.pi)
404+
+ torch.lgamma(0.5 * self.df)
405+
- torch.lgamma(0.5 * (self.df + 1.0))
406+
)
407+
expected_log_prob = -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z
408+
self.assertTrue(torch.allclose(log_prob, expected_log_prob))
409+
410+
411+
def test_entropy(self):
412+
entropy = self.studentt.entropy()
413+
lbeta = (
414+
torch.lgamma(0.5 * self.df)
415+
+ math.lgamma(0.5)
416+
- torch.lgamma(0.5 * (self.df + 1))
417+
)
418+
expected_entropy = (
419+
self.scale.log()
420+
+ 0.5
421+
* (self.df + 1)
422+
* (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df))
423+
+ 0.5 * self.df.log()
424+
+ lbeta
425+
)
426+
self.assertTrue(torch.allclose(entropy, expected_entropy))
427+
428+
429+
def test_rsample(self):
430+
samples = self.studentt.rsample(sample_shape=torch.Size([10]))
431+
print(samples.shape)
432+
# print(self.studentt.rsample(sample_shape=torch.Size([2])))
433+
self.assertEqual(samples.shape, torch.Size([10, 2]))
434+
self.assertTrue(samples.requires_grad) # Check that gradients are tracked
435+
436+
def test_invalid_inputs(self):
437+
with self.assertRaises(ValueError):
438+
StudentT(torch.tensor([-1.0, 1.0]), self.loc, self.scale) #Negative df
439+
with self.assertRaises(ValueError):
440+
self.studentt.log_prob([1, 2])
441+
442+
353443
if __name__ == "__main__":
354-
unittest.main()
444+
unittest.main()

src/irt/distributions.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -505,11 +505,16 @@ def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
505505
sample_shape (_size): Shape of the sample.
506506
507507
Returns:
508-
torch.Tensor: Reparameterized sample, enabling gradient tracking.
508+
torch.Tensor: Reparameterized sample, enabling gradient tracking.
509509
"""
510-
shape = self._extended_shape(sample_shape)
511-
sigma = self.gamma.rsample(shape) # Sample from auxiliary Gamma distribution
512-
x = self.loc.detach() + self.scale.detach() * Normal(0, sigma).rsample(shape)
510+
loc = self.loc.expand(self._extended_shape(sample_shape))
511+
scale = self.scale.expand(self._extended_shape(sample_shape))
512+
513+
# Sample from auxiliary Gamma distribution
514+
sigma = self.gamma.rsample(sample_shape)
515+
516+
# Sample from Normal distribution (shape must match after broadcasting)
517+
x = loc + scale * Normal(0, sigma).rsample()
513518

514519
transform = self._transform(x.detach()) # Standardize the sample
515520
surrogate_x = -transform / self._d_transform_d_z().detach() # Compute surrogate gradient

0 commit comments

Comments
 (0)