diff --git a/.ipynb_checkpoints/README-checkpoint.md b/.ipynb_checkpoints/README-checkpoint.md
new file mode 100644
index 0000000..f0dfbda
--- /dev/null
+++ b/.ipynb_checkpoints/README-checkpoint.md
@@ -0,0 +1,75 @@
+# Implicit Reparametrization Trick
+
+
+| Title | Implicit Reparametrization Trick for BMM |
+| :--- | :--- |
+| Authors | Matvei Kreinin, Maria Nikitina, Petr Babkin, Iryna Zabarianska |
+| Consultant | Oleg Bakhteev, PhD |
+
+## Description
+
+This repository contains an educational project for the Bayesian Multimodeling course. It implements algorithms for sampling from various distributions using the implicit reparameterization trick.
+
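+The core idea, in a minimal sketch (the helper below is illustrative only and is not part of the library API): instead of differentiating an explicit sampling path, the gradient of a sample `z` with respect to the distribution parameters is obtained implicitly from the CDF, `dz/dtheta = -(dF/dtheta) / (dF/dz)`, and attached to the sample through a surrogate term.
+
+```
+import torch
+
+def implicit_rsample(dist, sample_shape=torch.Size()):
+    # Illustrative sketch for a univariate distribution with a differentiable CDF.
+    z = dist.sample(sample_shape)                # plain sample, no gradient through z itself
+    cdf = dist.cdf(z)                            # F_theta(z), differentiable w.r.t. theta
+    pdf = dist.log_prob(z).exp().detach()        # dF/dz, treated as a constant
+    surrogate = -cdf / pdf                       # its gradient equals -(dF/dtheta) / (dF/dz)
+    return z + (surrogate - surrogate.detach())  # value of z, gradient of the surrogate
+```
+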
+## Scope
+We plan to implement the following distributions in our library:
+- [x] Gaussian normal distribution (*)
+- [x] Dirichlet distribution (Beta distributions)(\*)
+- [x] Mixture of the same family distributions (**)
+- [x] Student's t-distribution (**) (\*)
+- [x] VonMises distribution (***)
+- [ ] Sampling from an arbitrary factorized distribution (***)
+
+(\*) - this distribution is already implemented in torch using the explicit reparameterization trick; we implement it for comparison
+
+(\*\*) - these distributions are added as a backup; their inclusion is questionable
+
+(\*\*\*) - the implementation of this distribution is not yet clear; its inclusion is questionable
+
+## Stack
+
+We plan to inherit from the `torch.distributions.Distribution` class, so we need to implement all the methods that are present in that class. A minimal sketch of such a subclass is shown below.
+
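+The class name `ImplicitNormal` below is hypothetical and only illustrates the planned structure; the actual implementations live in `src/irt/distributions.py`.
+
+```
+import torch
+from torch.distributions import Distribution, constraints
+from torch.distributions.utils import broadcast_all
+
+class ImplicitNormal(Distribution):
+    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
+    support = constraints.real
+    has_rsample = True
+
+    def __init__(self, loc, scale, validate_args=None):
+        self.loc, self.scale = broadcast_all(loc, scale)
+        super().__init__(self.loc.size(), validate_args=validate_args)
+
+    def sample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        with torch.no_grad():
+            return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
+
+    def rsample(self, sample_shape=torch.Size()):
+        x = self.sample(sample_shape)
+        # Implicit reparameterization: the surrogate's gradient equals -(dF/dtheta) / (dF/dz).
+        transform = (x - self.loc) / self.scale
+        surrogate = -transform * self.scale.detach()
+        return x + (surrogate - surrogate.detach())
+```
+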
+## Usage
+In this example, we demonstrate the application of our library in a Variational Autoencoder (VAE) model, where the latent layer is modeled with a normal distribution.
+```
+>>> import torch.distributions.implicit as irt
+>>> params = Encoder(inputs)
+>>> gauss = irt.Normal(*params)
+>>> deviated = gauss.rsample()
+>>> outputs = Decoder(deviated)
+```
+In this example, we demonstrate sampling from a mixture of distributions using our library.
+```
+>>> import irt
+>>> params = Encoder(inputs)
+>>> mix = irt.Mixture([irt.Normal(*params), irt.Dirichlet(*params)])
+>>> deviated = mix.rsample()
+>>> outputs = Decoder(deviated)
+```
+
+## Links
+- [LinkReview](https://github.com/intsystems/implitic-reparametrization-trick/blob/main/linkreview.md)
+- [Plan of project](https://github.com/intsystems/implitic-reparametrization-trick/blob/main/planning.md)
+- [BlogPost](blogpost/Blog_post_sketch.pdf)
+- [Documentation](https://intsystems.github.io/implicit-reparameterization-trick/)
+- [Matvei Kreinin](https://github.com/kreininmv), [Maria Nikitina](https://github.com/NikitinaMaria), [Petr Babkin](https://github.com/petr-parker), [Iryna Zabarianska](https://github.com/Akshiira)
+
diff --git a/README.md b/README.md
index 38f922a..f0dfbda 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,10 @@
# Implicit Reparametrization Trick
+
+

+
+
+
+


@@ -25,11 +31,12 @@ This repository implements an educational project for the Bayesian Multimodeling
## Scope
We plan to implement the following distributions in our library:
-- Gaussian normal distribution (*)
-- Dirichlet distribution (Beta distributions)(\*)
-- Sampling from a mixture of distributions
-- Sampling from the Student's t-distribution (**) (\*)
-- Sampling from an arbitrary factorized distribution (***)
+- [x] Gaussian normal distribution (*)
+- [x] Dirichlet distribution (Beta distributions)(\*)
+- [x] Mixture of the same family distributions (**)
+- [x] Student's t-distribution (**) (\*)
+- [x] VonMises distribution (***)
+- [ ] Sampling from an arbitrary factorized distribution (***)
(\*) - this distribution is already implemented in torch using the explicit reparameterization trick, we will implement it for comparison
@@ -64,4 +71,5 @@ In this example, we demonstrate the use of a mixture of distributions using our
- [Plan of project](https://github.com/intsystems/implitic-reparametrization-trick/blob/main/planning.md)
- [BlogPost](blogpost/Blog_post_sketch.pdf)
- [Documentation](https://intsystems.github.io/implicit-reparameterization-trick/)
+- [Matvei Kreinin](https://github.com/kreininmv), [Maria Nikitina](https://github.com/NikitinaMaria), [Petr Babkin](https://github.com/petr-parker), [Iryna Zabarianska](https://github.com/Akshiira)
diff --git a/code/.ipynb_checkpoints/run_unittest-checkpoint.py b/code/.ipynb_checkpoints/run_unittest-checkpoint.py
new file mode 100644
index 0000000..bca15b4
--- /dev/null
+++ b/code/.ipynb_checkpoints/run_unittest-checkpoint.py
@@ -0,0 +1,444 @@
+import unittest
+import math
+import torch
+import sys
+sys.path.append('../src')
+from irt.distributions import Normal, Gamma, MixtureSameFamily, Beta, Dirichlet, StudentT
+from torch.distributions import Categorical, Independent
+
+
+class TestNormalDistribution(unittest.TestCase):
+ def setUp(self):
+ self.loc = torch.tensor([0.0, 1.0]).requires_grad_(True)
+ self.scale = torch.tensor([1.0, 2.0]).requires_grad_(True)
+ self.normal = Normal(self.loc, self.scale)
+
+ def test_init(self):
+ normal = Normal(0.0, 1.0)
+ self.assertEqual(normal.loc, 0.0)
+ self.assertEqual(normal.scale, 1.0)
+ self.assertEqual(normal.batch_shape, torch.Size())
+
+ normal = Normal(torch.tensor([0.0, 1.0]), torch.tensor([1.0, 2.0]))
+ self.assertTrue(torch.equal(normal.loc, torch.tensor([0.0, 1.0])))
+ self.assertTrue(torch.equal(normal.scale, torch.tensor([1.0, 2.0])))
+ self.assertEqual(normal.batch_shape, torch.Size([2]))
+
+ def test_properties(self):
+ self.assertTrue(torch.equal(self.normal.mean, self.loc))
+ self.assertTrue(torch.equal(self.normal.mode, self.loc))
+ self.assertTrue(torch.equal(self.normal.stddev, self.scale))
+ self.assertTrue(torch.equal(self.normal.variance, self.scale**2))
+
+ def test_entropy(self):
+ entropy = self.normal.entropy()
+ expected_entropy = 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)
+ self.assertTrue(torch.allclose(entropy, expected_entropy))
+
+ def test_cdf(self):
+ value = torch.tensor([0.0, 2.0])
+ cdf = self.normal.cdf(value)
+ expected_cdf = 0.5 * (1 + torch.erf((value - self.loc) / (self.scale * math.sqrt(2))))
+ self.assertTrue(torch.allclose(cdf, expected_cdf))
+
+ def test_expand(self):
+ expanded_normal = self.normal.expand(torch.Size([3, 2]))
+ self.assertEqual(expanded_normal.batch_shape, torch.Size([3, 2]))
+ self.assertTrue(torch.equal(expanded_normal.loc, self.loc.expand([3, 2])))
+ self.assertTrue(torch.equal(expanded_normal.scale, self.scale.expand([3, 2])))
+
+ def test_icdf(self):
+ value = torch.tensor([0.2, 0.8])
+ icdf = self.normal.icdf(value)
+ expected_icdf = self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)
+ self.assertTrue(torch.allclose(icdf, expected_icdf))
+
+ def test_log_prob(self):
+ value = torch.tensor([0.0, 2.0])
+ log_prob = self.normal.log_prob(value)
+ var = self.scale**2
+ log_scale = self.scale.log()
+ expected_log_prob = -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))
+ self.assertTrue(torch.allclose(log_prob, expected_log_prob))
+
+ def test_sample(self):
+ samples = self.normal.sample(sample_shape=torch.Size([100]))
+ self.assertEqual(samples.shape, torch.Size([100, 2])) # Check shape
+        empirical_mean = samples.mean(dim=0)
+        self.assertTrue((empirical_mean < self.normal.mean + self.normal.scale).all())
+        self.assertTrue((self.normal.mean - self.normal.scale < empirical_mean).all())
+
+ def test_rsample(self):
+ samples = self.normal.rsample(sample_shape=torch.Size([10]))
+ self.assertEqual(samples.shape, torch.Size([10, 2])) # Check shape
+ self.assertTrue(samples.requires_grad) # Check gradient tracking
+
+
+class TestGammaDistribution(unittest.TestCase):
+ def setUp(self):
+ self.concentration = torch.tensor([1.0, 2.0]).requires_grad_(True)
+ self.rate = torch.tensor([1.0, 0.5]).requires_grad_(True)
+ self.gamma = Gamma(self.concentration, self.rate)
+
+ def test_init(self):
+ gamma = Gamma(1.0, 1.0)
+ self.assertEqual(gamma.concentration, 1.0)
+ self.assertEqual(gamma.rate, 1.0)
+ self.assertEqual(gamma.batch_shape, torch.Size())
+
+ gamma = Gamma(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 0.5]))
+ self.assertTrue(torch.equal(gamma.concentration, torch.tensor([1.0, 2.0])))
+ self.assertTrue(torch.equal(gamma.rate, torch.tensor([1.0, 0.5])))
+ self.assertEqual(gamma.batch_shape, torch.Size([2]))
+
+ def test_properties(self):
+ self.assertTrue(torch.allclose(self.gamma.mean, self.concentration / self.rate))
+ self.assertTrue(torch.allclose(self.gamma.mode, ((self.concentration - 1) / self.rate).clamp(min=0)))
+ self.assertTrue(torch.allclose(self.gamma.variance, self.concentration / self.rate.pow(2)))
+
+ def test_expand(self):
+ expanded_gamma = self.gamma.expand(torch.Size([3, 2]))
+ self.assertEqual(expanded_gamma.batch_shape, torch.Size([3, 2]))
+ self.assertTrue(torch.equal(expanded_gamma.concentration, self.concentration.expand([3, 2])))
+ self.assertTrue(torch.equal(expanded_gamma.rate, self.rate.expand([3, 2])))
+
+ def test_rsample(self):
+ samples = self.gamma.rsample(sample_shape=torch.Size([10]))
+ self.assertEqual(samples.shape, torch.Size([10, 2])) # Check shape
+ self.assertTrue(samples.requires_grad) #Check gradient tracking
+
+
+ def test_log_prob(self):
+ value = torch.tensor([1.0, 2.0])
+ log_prob = self.gamma.log_prob(value)
+ expected_log_prob = (
+ torch.xlogy(self.concentration, self.rate)
+ + torch.xlogy(self.concentration - 1, value)
+ - self.rate * value
+ - torch.lgamma(self.concentration)
+ )
+ self.assertTrue(torch.allclose(log_prob, expected_log_prob))
+
+ def test_entropy(self):
+ entropy = self.gamma.entropy()
+ expected_entropy = (
+ self.concentration
+ - torch.log(self.rate)
+ + torch.lgamma(self.concentration)
+ + (1.0 - self.concentration) * torch.digamma(self.concentration)
+ )
+ self.assertTrue(torch.allclose(entropy, expected_entropy))
+
+ def test_natural_params(self):
+ natural_params = self.gamma._natural_params
+ expected_natural_params = (self.concentration - 1, -self.rate)
+ self.assertTrue(torch.equal(natural_params[0], expected_natural_params[0]))
+ self.assertTrue(torch.equal(natural_params[1], expected_natural_params[1]))
+
+ def test_log_normalizer(self):
+ x, y = self.gamma._natural_params
+ log_normalizer = self.gamma._log_normalizer(x, y)
+ expected_log_normalizer = torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())
+ self.assertTrue(torch.allclose(log_normalizer, expected_log_normalizer))
+
+ def test_cdf(self):
+ value = torch.tensor([1.0, 2.0])
+ cdf = self.gamma.cdf(value)
+ expected_cdf = torch.special.gammainc(self.concentration, self.rate * value)
+ self.assertTrue(torch.allclose(cdf, expected_cdf))
+
+
+ def test_invalid_inputs(self):
+ with self.assertRaises(ValueError):
+ Gamma(torch.tensor([-1.0, 1.0]), self.rate) # Negative concentration
+ with self.assertRaises(ValueError):
+ Gamma(self.concentration, torch.tensor([-1.0, 1.0])) # Negative rate
+ with self.assertRaises(ValueError):
+ self.gamma.log_prob(torch.tensor([-1.0, 1.0])) # Negative value
+
+class TestMixtureSameFamily(unittest.TestCase):
+ def setUp(self):
+        # Use simple Normal components and a two-component Categorical as the mixing distribution.
+ component_dist = Normal(torch.tensor([0.0, 1.0]).requires_grad_(True), torch.tensor([1.0, 2.0]).requires_grad_(True))
+ mixture_dist = Categorical(torch.tensor([0.6, 0.4]))
+
+ self.mixture = MixtureSameFamily(mixture_dist, component_dist)
+
+ def test_rsample_event_size_1(self):
+ samples = self.mixture.rsample(sample_shape=torch.Size([10]))
+ self.assertEqual(samples.shape, torch.Size([10]))
+ self.assertTrue(samples.requires_grad) # Ensure gradient tracking
+
+ def test_rsample_event_size_greater_than_1(self):
+ #Create a mixture with event_size > 1 (e.g., using Independent(Normal(...),1) )
+ component_dist = Independent(Normal(torch.tensor([[0.0, 1.0], [2., 3.]]).requires_grad_(True), torch.tensor([[1.0, 2.0], [2., 3.]]).requires_grad_(True)), 1)
+ mixture_dist = Categorical(torch.tensor([0.6, 0.4]).requires_grad_(True))
+ mixture = MixtureSameFamily(mixture_dist, component_dist)
+ samples = mixture.rsample(sample_shape=torch.Size([10]))
+ self.assertEqual(samples.shape, torch.Size([10, 2])) # Check shape
+ self.assertTrue(samples.requires_grad) # Ensure gradient tracking
+
+    def test_distributional_transform(self):
+        # For a univariate component distribution, the transform should have the same shape as the input.
+        x = torch.randn(10, 2)
+        transform = self.mixture._distributional_transform(x)
+        self.assertEqual(transform.shape, torch.Size([10, 2]))
+
+
+ def test_invalid_component(self):
+ # Test with a component distribution that doesn't have rsample
+ class NoRsampleDist(object):
+ has_rsample = False
+ with self.assertRaises(ValueError):
+ MixtureSameFamily(Categorical(torch.tensor([0.6, 0.4])), NoRsampleDist())
+
+    def test_log_cdf_multivariate(self):
+        # Test with a multivariate component distribution (e.g., Independent(Normal(...), 1)).
+        component_dist = Independent(Normal(loc=torch.zeros(2, 2), scale=torch.ones(2, 2)), 1)
+        mixture_dist = Categorical(torch.tensor([0.6, 0.4]))
+        mixture = MixtureSameFamily(mixture_dist, component_dist)
+        x = torch.tensor([[0.5, 1.0], [1.0, 0.5]])
+        log_cdf = mixture._log_cdf(x)
+        # The log-CDF is the log of a probability, so it must be finite and non-positive.
+        self.assertTrue(torch.isfinite(log_cdf).all())
+        self.assertTrue((log_cdf <= 0).all())
+
+class TestBetaDistribution(unittest.TestCase):
+ def setUp(self):
+ self.concentration1 = torch.tensor([1.0, 2.0], requires_grad=True)
+ self.concentration0 = torch.tensor([2.0, 1.0], requires_grad=True)
+ self.beta = Beta(self.concentration1, self.concentration0)
+ self._dirichlet = Dirichlet(torch.stack([self.concentration1, self.concentration0], -1)) # Initialize _dirichlet
+
+ def test_init(self):
+ beta = Beta(torch.tensor(1.0), torch.tensor(2.0))
+ self.assertEqual(beta.concentration1, 1.0)
+ self.assertEqual(beta.concentration0, 2.0)
+ self.assertEqual(beta._gamma1.concentration, 1.0) #Check Gamma parameters
+ self.assertEqual(beta._gamma0.concentration, 2.0)
+
+
+ def test_properties(self):
+ self.assertTrue(torch.allclose(self.beta.mean, self.concentration1 / (self.concentration1 + self.concentration0)))
+ self.assertTrue(torch.allclose(self.beta.mode, (self.concentration1 - 1) / (self.concentration1 + self.concentration0 - 2)))
+ total = self.concentration1 + self.concentration0
+ self.assertTrue(torch.allclose(self.beta.variance, self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1))))
+
+ def test_expand(self):
+ expanded_beta = self.beta.expand(torch.Size([3, 2]))
+ self.assertEqual(expanded_beta.batch_shape, torch.Size([3, 2]))
+ self.assertTrue(torch.equal(expanded_beta._gamma1.concentration, self.concentration1.expand([3, 2])))
+ self.assertTrue(torch.equal(expanded_beta._gamma0.concentration, self.concentration0.expand([3, 2])))
+
+ def test_rsample(self):
+ samples = self.beta.rsample(sample_shape=torch.Size([10]))
+ self.assertEqual(samples.shape, torch.Size([10, 2]))
+ self.assertTrue(samples.requires_grad) #check grad tracking
+
+ def test_log_prob(self):
+ value = torch.tensor([0.5, 0.7])
+ log_prob = self.beta.log_prob(value)
+ heads_tails = torch.stack([value, 1.0 - value], -1)
+ expected_log_prob = self._dirichlet.log_prob(heads_tails) #Use initialized _dirichlet
+ self.assertTrue(torch.allclose(log_prob, expected_log_prob))
+
+ def test_entropy(self):
+ entropy = self.beta.entropy()
+ expected_entropy = self._dirichlet.entropy() #Use initialized _dirichlet
+ self.assertTrue(torch.allclose(entropy, expected_entropy))
+
+
+ def test_natural_params(self):
+ natural_params = self.beta._natural_params
+ self.assertTrue(torch.equal(natural_params[0], self.concentration1))
+ self.assertTrue(torch.equal(natural_params[1], self.concentration0))
+
+ def test_log_normalizer(self):
+ x, y = self.beta._natural_params
+ log_normalizer = self.beta._log_normalizer(x, y)
+ expected_log_normalizer = torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
+ self.assertTrue(torch.allclose(log_normalizer, expected_log_normalizer))
+
+
+ def test_invalid_inputs(self):
+ with self.assertRaises(ValueError):
+ Beta(torch.tensor([-1.0, 1.0]), self.concentration0) # Negative concentration1
+ with self.assertRaises(ValueError):
+ Beta(self.concentration1, torch.tensor([-1.0, 1.0])) # Negative concentration0
+ with self.assertRaises(ValueError):
+ self.beta.log_prob(torch.tensor([-0.1, 0.5])) # Value outside [0,1]
+ with self.assertRaises(ValueError):
+ self.beta.log_prob(torch.tensor([1.1, 0.5]))
+
+class TestDirichlet(unittest.TestCase):
+ def setUp(self):
+ self.concentration = torch.tensor([1.0, 2.0, 3.0]).requires_grad_(True)
+ self.dirichlet = Dirichlet(self.concentration)
+
+ def test_init(self):
+ with self.assertRaises(ValueError):
+ Dirichlet(torch.tensor([]), validate_args=True) # Not enough dimensions
+ dirichlet = Dirichlet(self.concentration)
+ self.assertTrue(torch.equal(dirichlet.concentration, self.concentration))
+ self.assertEqual(dirichlet.batch_shape, torch.Size())
+ self.assertEqual(dirichlet.event_shape, torch.Size([3]))
+
+ def test_expand(self):
+ expanded_dirichlet = self.dirichlet.expand(torch.Size([2, 3]))
+ self.assertEqual(expanded_dirichlet.batch_shape, torch.Size([2, 3]))
+ self.assertEqual(expanded_dirichlet.event_shape, torch.Size([3]))
+ self.assertTrue(torch.equal(expanded_dirichlet.concentration, self.concentration.expand(torch.Size([2, 3, 3]))))
+
+
+ def test_log_prob(self):
+ value = torch.tensor([0.2, 0.3, 0.5])
+ log_prob = self.dirichlet.log_prob(value)
+ expected_log_prob = (
+ torch.xlogy(self.concentration - 1.0, value).sum(-1)
+ + torch.lgamma(self.concentration.sum(-1))
+ - torch.lgamma(self.concentration).sum(-1)
+ )
+ self.assertTrue(torch.allclose(log_prob, expected_log_prob))
+
+ def test_mean(self):
+ mean = self.dirichlet.mean
+ expected_mean = self.concentration / self.concentration.sum(-1, True)
+ self.assertTrue(torch.allclose(mean, expected_mean))
+
+ def test_mode(self):
+ mode = self.dirichlet.mode
+ concentrationm1 = (self.concentration - 1).clamp(min=0.0)
+ expected_mode = concentrationm1 / concentrationm1.sum(-1, True)
+ self.assertTrue(torch.allclose(mode, expected_mode))
+
+ def test_variance(self):
+ variance = self.dirichlet.variance
+ con0 = self.concentration.sum(-1, True)
+ expected_variance = (
+ self.concentration
+ * (con0 - self.concentration)
+ / (con0.pow(2) * (con0 + 1))
+ )
+ self.assertTrue(torch.allclose(variance, expected_variance))
+
+ def test_entropy(self):
+ entropy = self.dirichlet.entropy()
+ k = self.concentration.size(-1)
+ a0 = self.concentration.sum(-1)
+ expected_entropy = (
+ torch.lgamma(self.concentration).sum(-1)
+ - torch.lgamma(a0)
+ - (k - a0) * torch.digamma(a0)
+ - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1)
+ )
+ self.assertTrue(torch.allclose(entropy, expected_entropy))
+
+
+ def test_natural_params(self):
+ natural_params = self.dirichlet._natural_params
+ self.assertTrue(torch.equal(natural_params[0], self.concentration))
+
+ def test_log_normalizer(self):
+ log_normalizer = self.dirichlet._log_normalizer(self.concentration)
+ expected_log_normalizer = torch.lgamma(self.concentration).sum(-1) - torch.lgamma(self.concentration.sum(-1))
+ self.assertTrue(torch.allclose(log_normalizer, expected_log_normalizer))
+
+ def test_invalid_inputs(self):
+ with self.assertRaises(ValueError):
+ Dirichlet(torch.tensor([1.0, -1.0, 3.0])) #Negative Concentration
+ with self.assertRaises(ValueError):
+ self.dirichlet.log_prob(torch.tensor([0.2, 0.3, 0.6])) #Values don't sum to 1
+
+
+class TestStudentT(unittest.TestCase):
+
+ def setUp(self):
+ self.df = torch.tensor([3.0, 5.0]).requires_grad_(True)
+ self.loc = torch.tensor([1.0, 2.0]).requires_grad_(True)
+ self.scale = torch.tensor([0.5, 1.0]).requires_grad_(True)
+ self.studentt = StudentT(self.df, self.loc, self.scale, validate_args=True)
+
+ def test_init(self):
+ studentt = StudentT(3.0, 1.0, 0.5)
+ self.assertEqual(studentt.df, 3.0)
+ self.assertEqual(studentt.loc, 1.0)
+ self.assertEqual(studentt.scale, 0.5)
+ self.assertEqual(studentt.gamma.concentration, 1.5) #Check Gamma initialization
+ self.assertEqual(studentt.gamma.rate, 1.5)
+
+ def test_properties(self):
+ df = torch.tensor([.3, 2.0])
+ loc = torch.tensor([1.0, 2.0])
+ scale = torch.tensor([0.5, 1.0])
+ studentt = StudentT(df, loc, scale)
+ self.assertTrue(torch.equal(studentt.mode, studentt.loc))
+ # Check mean (undefined for df <= 1)
+        self.assertTrue(torch.isnan(studentt.mean[0]))  # mean is NaN for df <= 1
+        self.assertTrue(torch.allclose(studentt.mean[1], studentt.loc[1]))  # mean equals loc for df > 1
+
+        # Check variance (undefined for df <= 1, infinite for 1 < df <= 2)
+        self.assertTrue(torch.isnan(studentt.variance[0]))
+        self.assertTrue(torch.isinf(studentt.variance[1]))  # infinite for 1 < df <= 2
+        # With df == 2 the closed-form expression scale^2 * df / (df - 2) is also infinite, so the values agree.
+        self.assertTrue(torch.allclose(studentt.variance[1], (scale[1].pow(2) * df[1] / (df[1] - 2))))
+
+
+ def test_expand(self):
+ expanded_studentt = self.studentt.expand(torch.Size([2, 2]))
+ self.assertEqual(expanded_studentt.batch_shape, torch.Size([2, 2]))
+ self.assertTrue(torch.equal(expanded_studentt.df, self.df.expand([2, 2])))
+ self.assertTrue(torch.equal(expanded_studentt.loc, self.loc.expand([2, 2])))
+ self.assertTrue(torch.equal(expanded_studentt.scale, self.scale.expand([2, 2])))
+
+
+ def test_log_prob(self):
+ value = torch.tensor([2.0, 3.0])
+ log_prob = self.studentt.log_prob(value)
+ y = (value - self.loc) / self.scale
+ Z = (
+ self.scale.log()
+ + 0.5 * self.df.log()
+ + 0.5 * math.log(math.pi)
+ + torch.lgamma(0.5 * self.df)
+ - torch.lgamma(0.5 * (self.df + 1.0))
+ )
+ expected_log_prob = -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z
+ self.assertTrue(torch.allclose(log_prob, expected_log_prob))
+
+
+ def test_entropy(self):
+ entropy = self.studentt.entropy()
+ lbeta = (
+ torch.lgamma(0.5 * self.df)
+ + math.lgamma(0.5)
+ - torch.lgamma(0.5 * (self.df + 1))
+ )
+ expected_entropy = (
+ self.scale.log()
+ + 0.5
+ * (self.df + 1)
+ * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df))
+ + 0.5 * self.df.log()
+ + lbeta
+ )
+ self.assertTrue(torch.allclose(entropy, expected_entropy))
+
+
+    def test_rsample(self):
+        samples = self.studentt.rsample(sample_shape=torch.Size([10]))
+        self.assertEqual(samples.shape, torch.Size([10, 2]))
+        self.assertTrue(samples.requires_grad)  # Check that gradients are tracked
+
+ def test_invalid_inputs(self):
+ with self.assertRaises(ValueError):
+ StudentT(torch.tensor([-1.0, 1.0]), self.loc, self.scale) #Negative df
+ with self.assertRaises(ValueError):
+ self.studentt.log_prob([1, 2])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/images/implicit.webp b/images/implicit.webp
new file mode 100644
index 0000000..06d71c8
Binary files /dev/null and b/images/implicit.webp differ
diff --git a/src/irt/.ipynb_checkpoints/distributions-checkpoint.py b/src/irt/.ipynb_checkpoints/distributions-checkpoint.py
new file mode 100644
index 0000000..43197d9
--- /dev/null
+++ b/src/irt/.ipynb_checkpoints/distributions-checkpoint.py
@@ -0,0 +1,1470 @@
+# mypy: allow-untyped-defs
+import math
+from numbers import Number, Real
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch.autograd.functional import jacobian
+from torch.distributions import (
+ Bernoulli,
+ Binomial,
+ ContinuousBernoulli,
+ Distribution,
+ Geometric,
+ NegativeBinomial,
+ RelaxedBernoulli,
+ constraints,
+)
+from torch.distributions.exp_family import ExponentialFamily
+from torch.distributions.utils import broadcast_all, lazy_property
+from torch.types import _size
+
+default_size = torch.Size()
+
+
+class Beta(ExponentialFamily):
+ r"""
+ Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.
+
+ Example::
+
+ >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
+ >>> m.sample()
+ tensor([0.1046])
+
+ Args:
+ concentration1 (float or Tensor): 1st concentration parameter of the distribution
+ (often referred to as alpha)
+ concentration0 (float or Tensor): 2nd concentration parameter of the distribution
+ (often referred to as beta)
+ """
+
+ arg_constraints = {
+ "concentration1": constraints.positive,
+ "concentration0": constraints.positive,
+ }
+ support = constraints.unit_interval
+ has_rsample = True
+
+ def __init__(
+ self, concentration1: torch.Tensor, concentration0: torch.Tensor, validate_args: Optional[bool] = None
+ ) -> None:
+ """
+ Initializes the Beta distribution with the given concentration parameters.
+
+ Args:
+ concentration1: First concentration parameter (alpha).
+ concentration0: Second concentration parameter (beta).
+ validate_args: If True, validates the distribution's parameters.
+ """
+ self.concentration1 = concentration1
+ self.concentration0 = concentration0
+ self._gamma1 = Gamma(self.concentration1, torch.ones_like(concentration1), validate_args=validate_args)
+ self._gamma0 = Gamma(self.concentration0, torch.ones_like(concentration0), validate_args=validate_args)
+ self._dirichlet = Dirichlet(torch.stack([self.concentration1, self.concentration0], -1))
+
+ super().__init__(self._gamma0._batch_shape, validate_args=validate_args)
+
+ def expand(self, batch_shape: torch.Size, _instance: Optional["Beta"] = None) -> "Beta":
+ """
+ Expands the Beta distribution to a new batch shape.
+
+ Args:
+ batch_shape: Desired batch shape.
+ _instance: Instance to validate.
+
+ Returns:
+ A new Beta distribution instance with expanded parameters.
+ """
+ new = self._get_checked_instance(Beta, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new._gamma1 = self._gamma1.expand(batch_shape)
+ new._gamma0 = self._gamma0.expand(batch_shape)
+ super(Beta, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ @property
+ def mean(self) -> torch.Tensor:
+ """
+ Computes the mean of the Beta distribution.
+
+ Returns:
+ torch.Tensor: Mean of the distribution.
+ """
+ return self.concentration1 / (self.concentration1 + self.concentration0)
+
+ @property
+ def mode(self) -> torch.Tensor:
+ """
+ Computes the mode of the Beta distribution.
+
+ Returns:
+ torch.Tensor: Mode of the distribution.
+ """
+ return (self.concentration1 - 1) / (self.concentration1 + self.concentration0 - 2)
+
+ @property
+ def variance(self) -> torch.Tensor:
+ """
+ Computes the variance of the Beta distribution.
+
+ Returns:
+ torch.Tensor: Variance of the distribution.
+ """
+ total = self.concentration1 + self.concentration0
+ return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1))
+
+ def rsample(self, sample_shape: _size = ()) -> torch.Tensor:
+ """
+ Generates a reparameterized sample from the Beta distribution.
+
+ Args:
+ sample_shape (_size): Shape of the sample.
+
+ Returns:
+ torch.Tensor: Sample from the Beta distribution.
+ """
+ z1 = self._gamma1.rsample(sample_shape)
+ z0 = self._gamma0.rsample(sample_shape)
+ return z1 / (z1 + z0)
+
+ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the log probability density of a value under the Beta distribution.
+
+ Args:
+ value: Value to evaluate.
+
+ Returns:
+ Log probability of the value.
+ """
+ if self._validate_args:
+ self._validate_sample(value)
+ heads_tails = torch.stack([value, 1.0 - value], -1)
+ return self._dirichlet.log_prob(heads_tails)
+
+ def entropy(self) -> torch.Tensor:
+ """
+ Computes the entropy of the Beta distribution.
+
+ Returns:
+ Entropy of the distribution.
+ """
+ return self._dirichlet.entropy()
+
+ @property
+ def _natural_params(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Returns the natural parameters of the distribution.
+
+ Returns:
+ Natural parameters.
+ """
+ return self.concentration1, self.concentration0
+
+ def _log_normalizer(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the log normalizer for the natural parameters.
+
+ Args:
+ x: Parameter 1.
+ y: Parameter 2.
+
+ Returns:
+ Log normalizer value.
+ """
+ return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
+
+
+class Dirichlet(ExponentialFamily):
+ """
+ Dirichlet distribution parameterized by a concentration vector.
+
+ The Dirichlet distribution is a multivariate generalization of the Beta distribution. It
+ is commonly used in Bayesian statistics, particularly for modeling proportions.
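+
+    Example (illustrative; sampled values are random)::
+
+        >>> m = Dirichlet(torch.tensor([0.5, 0.5]))
+        >>> m.rsample()  # reparameterized draw on the simplex, e.g. tensor([0.3, 0.7])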
+ """
+
+ arg_constraints = {"concentration": constraints.independent(constraints.positive, 1)}
+ support = constraints.simplex
+ has_rsample = True
+
+ def __init__(self, concentration: torch.Tensor, validate_args: Optional[bool] = None) -> None:
+ """
+ Initializes the Dirichlet distribution.
+
+ Args:
+ concentration: Positive concentration parameter vector (alpha).
+ validate_args: If True, validates the distribution's parameters.
+ """
+ if torch.numel(concentration) < 1:
+ raise ValueError("`concentration` parameter must be at least one-dimensional.")
+ self.concentration = concentration
+ self.gamma = Gamma(self.concentration, torch.ones_like(self.concentration))
+ batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
+ super().__init__(batch_shape, event_shape, validate_args=validate_args)
+
+ @property
+ def mean(self) -> torch.Tensor:
+ """
+ Computes the mean of the Dirichlet distribution.
+
+ Returns:
+ Mean vector, calculated as `concentration / concentration.sum(-1, keepdim=True)`.
+ """
+ return self.concentration / self.concentration.sum(-1, keepdim=True)
+
+ @property
+ def mode(self) -> torch.Tensor:
+ """
+ Computes the mode of the Dirichlet distribution.
+
+ Note:
+ - The mode is defined only when all concentration values are > 1.
+ - For concentrations ≤ 1, the mode vector is clamped to enforce positivity.
+
+ Returns:
+ Mode vector.
+ """
+ concentration_minus_one = (self.concentration - 1).clamp(min=0.0)
+ mode = concentration_minus_one / concentration_minus_one.sum(-1, keepdim=True)
+ mask = (self.concentration < 1).all(dim=-1)
+ mode[mask] = F.one_hot(mode[mask].argmax(dim=-1), concentration_minus_one.shape[-1]).to(mode)
+ return mode
+
+ @property
+ def variance(self) -> torch.Tensor:
+ """
+ Computes the variance of the Dirichlet distribution.
+
+ Returns:
+ Variance vector for each component.
+ """
+ total_concentration = self.concentration.sum(-1, keepdim=True)
+ return (
+ self.concentration
+ * (total_concentration - self.concentration)
+ / (total_concentration.pow(2) * (total_concentration + 1))
+ )
+
+ def rsample(self, sample_shape: _size = ()) -> torch.Tensor:
+ """
+ Generates a reparameterized sample from the Dirichlet distribution.
+
+ Args:
+ sample_shape (_size): Desired sample shape.
+
+ Returns:
+ torch.Tensor: A reparameterized sample.
+ """
+ z = self.gamma.rsample(sample_shape) # Sample from underlying Gamma distribution
+
+        return z / torch.sum(z, dim=-1, keepdim=True)
+
+ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the log probability density for a given value.
+
+ Args:
+ value (torch.Tensor): Value to evaluate the log probability at.
+
+ Returns:
+ torch.Tensor: Log probability density of the value.
+ """
+ if self._validate_args:
+ self._validate_sample(value)
+ return (
+ torch.xlogy(self.concentration - 1.0, value).sum(-1)
+ + torch.lgamma(self.concentration.sum(-1))
+ - torch.lgamma(self.concentration).sum(-1)
+ )
+
+ def entropy(self) -> torch.Tensor:
+ """
+ Computes the entropy of the Dirichlet distribution.
+
+ Returns:
+ torch.Tensor: Entropy of the distribution.
+ """
+ k = self.concentration.size(-1)
+ total_concentration = self.concentration.sum(-1)
+ return (
+ torch.lgamma(self.concentration).sum(-1)
+ - torch.lgamma(total_concentration)
+ - (k - total_concentration) * torch.digamma(total_concentration)
+ - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1)
+ )
+
+ def expand(self, batch_shape: torch.Size, _instance: Optional["Dirichlet"] = None) -> "Dirichlet":
+ """
+ Expands the distribution parameters to a new batch shape.
+
+ Args:
+ batch_shape (torch.Size): Desired batch shape.
+ _instance (Optional): Instance to validate.
+
+ Returns:
+ A new Dirichlet distribution instance with expanded parameters.
+ """
+ new = self._get_checked_instance(Dirichlet, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.concentration = self.concentration.expand(batch_shape + self.event_shape)
+ super(Dirichlet, new).__init__(batch_shape, self.event_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ @property
+ def _natural_params(self) -> tuple:
+ """
+ Returns the natural parameters of the distribution.
+
+ Returns:
+ tuple: Natural parameter tuple `(concentration,)`.
+ """
+ return (self.concentration,)
+
+ def _log_normalizer(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the log normalizer for the natural parameters.
+
+ Args:
+ x (torch.Tensor): Natural parameter.
+
+ Returns:
+ torch.Tensor: Log normalizer value.
+ """
+ return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
+
+
+class StudentT(Distribution):
+ """
+ Student's t-distribution parameterized by degrees of freedom (df), location (loc), and scale (scale).
+
+ This distribution is commonly used for robust statistical modeling, particularly when the data
+ may have outliers or heavier tails than a Normal distribution.
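+
+    Example (illustrative; sampled values are random)::
+
+        >>> m = StudentT(torch.tensor([3.0]))
+        >>> m.rsample()  # reparameterized draw with df=3, loc=0, scale=1, e.g. tensor([0.8])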
+ """
+
+ arg_constraints = {
+ "df": constraints.positive,
+ "loc": constraints.real,
+ "scale": constraints.positive,
+ }
+ support = constraints.real
+ has_rsample = True
+
+ def __init__(
+ self, df: torch.Tensor, loc: float = 0.0, scale: float = 1.0, validate_args: Optional[bool] = None
+ ) -> None:
+ """
+ Initializes the Student's t-distribution.
+
+ Args:
+ df (torch.Tensor): Degrees of freedom (must be positive).
+ loc (float or torch.Tensor): Location parameter (default: 0.0).
+ scale (float or torch.Tensor): Scale parameter (default: 1.0).
+ validate_args (Optional[bool]): If True, validates distribution parameters.
+ """
+ self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
+ self.gamma = Gamma(self.df * 0.5, self.df * 0.5)
+ batch_shape = self.df.size()
+ super().__init__(batch_shape, validate_args=validate_args)
+
+ @property
+ def mean(self) -> torch.Tensor:
+ """
+ Computes the mean of the distribution.
+
+ Note: The mean is undefined when `df <= 1`.
+
+ Returns:
+ torch.Tensor: Mean of the distribution, or NaN for undefined cases.
+ """
+ m = self.loc.clone(memory_format=torch.contiguous_format)
+ m[self.df <= 1] = float("nan") # Mean is undefined for df <= 1
+ return m
+
+ @property
+ def mode(self) -> torch.Tensor:
+ """
+ Computes the mode of the distribution.
+
+ Returns:
+ torch.Tensor: Mode of the distribution, which is equal to `loc`.
+ """
+ return self.loc
+
+ @property
+ def variance(self) -> torch.Tensor:
+ """
+ Computes the variance of the distribution.
+
+ Note:
+ - Variance is infinite for 1 < df <= 2.
+ - Variance is undefined (NaN) for df <= 1.
+
+ Returns:
+ torch.Tensor: Variance of the distribution, or appropriate values for edge cases.
+ """
+ m = self.df.clone(memory_format=torch.contiguous_format)
+ # Variance for df > 2
+ m[self.df > 2] = self.scale[self.df > 2].pow(2) * self.df[self.df > 2] / (self.df[self.df > 2] - 2)
+ # Infinite variance for 1 < df <= 2
+ m[(self.df <= 2) & (self.df > 1)] = float("inf")
+ # Undefined variance for df <= 1
+ m[self.df <= 1] = float("nan")
+ return m
+
+ def expand(self, batch_shape: torch.Size, _instance: Optional["StudentT"] = None) -> "StudentT":
+ """
+ Expands the distribution parameters to a new batch shape.
+
+ Args:
+ batch_shape (torch.Size): Desired batch size for the expanded distribution.
+ _instance (Optional): Instance to validate.
+
+ Returns:
+ StudentT: A new StudentT distribution with expanded parameters.
+ """
+ new = self._get_checked_instance(StudentT, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.df = self.df.expand(batch_shape)
+ new.loc = self.loc.expand(batch_shape)
+ new.scale = self.scale.expand(batch_shape)
+ super(StudentT, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the log probability density for a given value.
+
+ Args:
+ value (torch.Tensor): Value to evaluate the log probability at.
+
+ Returns:
+ torch.Tensor: Log probability density of the given value.
+ """
+ if self._validate_args:
+ self._validate_sample(value)
+ y = (value - self.loc) / self.scale
+ Z = (
+ self.scale.log()
+ + 0.5 * self.df.log()
+ + 0.5 * math.log(math.pi)
+ + torch.lgamma(0.5 * self.df)
+ - torch.lgamma(0.5 * (self.df + 1.0))
+ )
+ return -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z
+
+ def entropy(self) -> torch.Tensor:
+ """
+ Computes the entropy of the Student's t-distribution.
+
+ Returns:
+ torch.Tensor: Entropy of the distribution.
+ """
+ lbeta = torch.lgamma(0.5 * self.df) + math.lgamma(0.5) - torch.lgamma(0.5 * (self.df + 1))
+ return (
+ self.scale.log()
+ + 0.5 * (self.df + 1) * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df))
+ + 0.5 * self.df.log()
+ + lbeta
+ )
+
+ def _transform(self, z: torch.Tensor) -> torch.Tensor:
+ """
+ Transforms an input tensor `z` to a standardized form based on the location and scale.
+
+ Args:
+ z (torch.Tensor): Input tensor to transform.
+
+ Returns:
+ torch.Tensor: Transformed tensor representing the standardized form.
+ """
+ return (z - self.loc) / self.scale
+
+ def _d_transform_d_z(self) -> torch.Tensor:
+ """
+ Computes the derivative of the transform function with respect to `z`.
+
+ Returns:
+ torch.Tensor: Reciprocal of the scale, representing the gradient for reparameterization.
+ """
+ return 1 / self.scale
+
+ def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
+ """
+ Generates a reparameterized sample from the Student's t-distribution.
+
+ Args:
+ sample_shape (_size): Shape of the sample.
+
+ Returns:
+ torch.Tensor: Reparameterized sample, enabling gradient tracking.
+ """
+ self.loc = self.loc.expand(self._extended_shape(sample_shape))
+ self.scale = self.scale.expand(self._extended_shape(sample_shape))
+
+ sigma = self.gamma.rsample()
+
+ # Sample from Normal distribution (shape must match after broadcasting)
+ x = self.loc + self.scale * Normal(0, sigma).rsample(sample_shape)
+
+ transform = self._transform(x.detach()) # Standardize the sample
+ surrogate_x = -transform / self._d_transform_d_z().detach() # Compute surrogate gradient
+
+ return x + (surrogate_x - surrogate_x.detach())
+
+
+class Gamma(ExponentialFamily):
+ """
+    Gamma distribution parameterized by `concentration` (shape) and `rate` (inverse scale).
+
+    The Gamma distribution is a continuous probability distribution defined for non-negative
+    real values and is often used to model waiting times until an event occurs.
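+
+    Example (illustrative; sampled values are random)::
+
+        >>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0]))
+        >>> m.rsample()  # reparameterized draw, e.g. tensor([0.7])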
+ """
+
+ arg_constraints = {
+ "concentration": constraints.positive,
+ "rate": constraints.positive,
+ }
+ support = constraints.nonnegative
+ has_rsample = True
+ _mean_carrier_measure = 0
+
+ def __init__(
+ self,
+ concentration: torch.Tensor,
+ rate: torch.Tensor,
+ validate_args: Optional[bool] = None,
+ ) -> None:
+ """
+ Initializes the Gamma distribution.
+
+ Args:
+ concentration (torch.Tensor): Shape parameter of the distribution (often referred to as alpha).
+ rate (torch.Tensor): Rate parameter (inverse of scale, often referred to as beta).
+ validate_args (Optional[bool]): If True, validates the distribution's parameters.
+ """
+ self.concentration, self.rate = broadcast_all(concentration, rate)
+ if isinstance(concentration, Number) and isinstance(rate, Number):
+ batch_shape = torch.Size()
+ else:
+ batch_shape = self.concentration.size()
+ super().__init__(batch_shape, validate_args=validate_args)
+
+ @property
+ def mean(self) -> torch.Tensor:
+ """
+ Computes the mean of the Gamma distribution.
+
+ Returns:
+ torch.Tensor: Mean of the distribution, calculated as `concentration / rate`.
+ """
+ return self.concentration / self.rate
+
+ @property
+ def mode(self) -> torch.Tensor:
+ """
+ Computes the mode of the Gamma distribution.
+
+ Note:
+ - The mode is defined only for `concentration > 1`. For `concentration <= 1`,
+ the mode is clamped to 0.
+
+ Returns:
+ torch.Tensor: Mode of the distribution.
+ """
+ return ((self.concentration - 1) / self.rate).clamp(min=0)
+
+ @property
+ def variance(self) -> torch.Tensor:
+ """
+ Computes the variance of the Gamma distribution.
+
+ Returns:
+ torch.Tensor: Variance of the distribution, calculated as `concentration / rate^2`.
+ """
+ return self.concentration / self.rate.pow(2)
+
+ def expand(self, batch_shape: torch.Size, _instance: Optional["Gamma"] = None) -> "Gamma":
+ """
+ Expands the distribution parameters to a new batch shape.
+
+ Args:
+ batch_shape (torch.Size): Desired batch shape.
+ _instance (Optional): Instance to validate.
+
+ Returns:
+ Gamma: A new Gamma distribution instance with expanded parameters.
+ """
+ new = self._get_checked_instance(Gamma, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.concentration = self.concentration.expand(batch_shape)
+ new.rate = self.rate.expand(batch_shape)
+ super(Gamma, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
+ """
+ Generates a reparameterized sample from the Gamma distribution.
+
+ Args:
+ sample_shape (_size): Shape of the sample.
+
+ Returns:
+ torch.Tensor: A reparameterized sample.
+ """
+ shape = self._extended_shape(sample_shape)
+ concentration = self.concentration.expand(shape)
+ rate = self.rate.expand(shape)
+
+ # Generate a sample using the underlying C++ implementation for efficiency
+ value = torch._standard_gamma(concentration) / rate.detach()
+
+ # Detach u for surrogate computation
+ u = value.detach() * rate.detach() / rate
+ value = value + (u - u.detach())
+
+ # Ensure numerical stability for gradients
+ value.detach().clamp_(min=torch.finfo(value.dtype).tiny)
+ return value
+
+ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the log probability density for a given value.
+
+ Args:
+ value (torch.Tensor): Value to evaluate the log probability at.
+
+ Returns:
+ torch.Tensor: Log probability density of the given value.
+ """
+ value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
+ if self._validate_args:
+ self._validate_sample(value)
+ return (
+ torch.xlogy(self.concentration, self.rate)
+ + torch.xlogy(self.concentration - 1, value)
+ - self.rate * value
+ - torch.lgamma(self.concentration)
+ )
+
+ def entropy(self) -> torch.Tensor:
+ """
+ Computes the entropy of the Gamma distribution.
+
+ Returns:
+ torch.Tensor: Entropy of the distribution.
+ """
+ return (
+ self.concentration
+ - torch.log(self.rate)
+ + torch.lgamma(self.concentration)
+ + (1.0 - self.concentration) * torch.digamma(self.concentration)
+ )
+
+ @property
+ def _natural_params(self) -> tuple:
+ """
+ Returns the natural parameters of the distribution.
+
+ Returns:
+ tuple: Tuple of natural parameters `(concentration - 1, -rate)`.
+ """
+ return self.concentration - 1, -self.rate
+
+ def _log_normalizer(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the log normalizer for the natural parameters.
+
+ Args:
+ x (torch.Tensor): First natural parameter.
+ y (torch.Tensor): Second natural parameter.
+
+ Returns:
+ torch.Tensor: Log normalizer value.
+ """
+ return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())
+
+ def cdf(self, value: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the cumulative distribution function (CDF) for the Gamma distribution.
+
+ Args:
+ value (torch.Tensor): Value to evaluate the CDF at.
+
+ Returns:
+ torch.Tensor: CDF of the given value.
+ """
+ if self._validate_args:
+ self._validate_sample(value)
+ return torch.special.gammainc(self.concentration, self.rate * value)
+
+
+class Normal(ExponentialFamily):
+ """
+ Represents the Normal (Gaussian) distribution with specified mean (loc) and standard deviation (scale).
+ Inherits from PyTorch's ExponentialFamily distribution class.
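+
+    Example (illustrative; sampled values are random)::
+
+        >>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
+        >>> m.rsample()  # reparameterized draw using the implicit surrogate, e.g. tensor([0.3])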
+ """
+
+ has_rsample = True
+ arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
+ support = constraints.real
+
+ def __init__(
+ self,
+ loc: torch.Tensor,
+ scale: torch.Tensor,
+ validate_args: Optional[bool] = None,
+ ) -> None:
+ """
+ Initializes the Normal distribution.
+
+ Args:
+ loc (torch.Tensor): Mean (location) parameter of the distribution.
+ scale (torch.Tensor): Standard deviation (scale) parameter of the distribution.
+ validate_args (Optional[bool]): If True, checks the distribution parameters for validity.
+ """
+ self.loc, self.scale = broadcast_all(loc, scale)
+ # Determine batch shape based on the type of `loc` and `scale`.
+ batch_shape = torch.Size() if isinstance(loc, Number) and isinstance(scale, Number) else self.loc.size()
+ super().__init__(batch_shape, validate_args=validate_args)
+
+ @property
+ def mean(self) -> torch.Tensor:
+ """
+ Returns the mean of the distribution.
+
+ Returns:
+ torch.Tensor: The mean (location) parameter `loc`.
+ """
+ return self.loc
+
+ @property
+ def mode(self) -> torch.Tensor:
+ """
+ Returns the mode of the distribution.
+
+ Returns:
+ torch.Tensor: The mode (equal to `loc` in a Normal distribution).
+ """
+ return self.loc
+
+ @property
+ def stddev(self) -> torch.Tensor:
+ """
+ Returns the standard deviation of the distribution.
+
+ Returns:
+ torch.Tensor: The standard deviation (scale) parameter `scale`.
+ """
+ return self.scale
+
+ @property
+ def variance(self) -> torch.Tensor:
+ """
+ Returns the variance of the distribution.
+
+ Returns:
+ torch.Tensor: The variance, computed as `scale ** 2`.
+ """
+ return self.stddev.pow(2)
+
+ def entropy(self) -> torch.Tensor:
+ """
+ Computes the entropy of the distribution.
+
+ Returns:
+ torch.Tensor: The entropy of the Normal distribution, which is a measure of uncertainty.
+ """
+ return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)
+
+ def cdf(self, value: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the cumulative distribution function (CDF) of the distribution at a given value.
+
+ Args:
+ value (torch.Tensor): The value at which to evaluate the CDF.
+
+ Returns:
+ torch.Tensor: The probability that a random variable from the distribution is less than or equal to `value`.
+ """
+ return 0.5 * (1 + torch.erf((value - self.loc) / (self.scale * math.sqrt(2))))
+
+ def expand(self, batch_shape: torch.Size, _instance: Optional["Normal"] = None) -> "Normal":
+ """
+ Expands the distribution parameters to a new batch shape.
+
+ Args:
+ batch_shape (torch.Size): Desired batch size for the expanded distribution.
+ _instance (Optional): Instance to check for validity.
+
+ Returns:
+ Normal: A new Normal distribution with parameters expanded to the specified batch shape.
+ """
+ new = self._get_checked_instance(Normal, _instance)
+ batch_shape = torch.Size(batch_shape)
+ new.loc = self.loc.expand(batch_shape)
+ new.scale = self.scale.expand(batch_shape)
+ super(Normal, new).__init__(batch_shape, validate_args=False)
+ new._validate_args = self._validate_args
+ return new
+
+ def icdf(self, value: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the inverse cumulative distribution function (quantile function) at a given value.
+
+ Args:
+ value (torch.Tensor): The probability value at which to evaluate the inverse CDF.
+
+ Returns:
+ torch.Tensor: The quantile corresponding to `value`.
+ """
+ return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)
+
+ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the log probability density of the distribution at a given value.
+
+ Args:
+ value (torch.Tensor): The value at which to evaluate the log probability.
+
+ Returns:
+ torch.Tensor: The log probability density at `value`.
+ """
+ var = self.scale**2
+ log_scale = self.scale.log() if not isinstance(self.scale, Real) else math.log(self.scale)
+ return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))
+
+ def _transform(self, z: torch.Tensor) -> torch.Tensor:
+ """
+ Transforms an input tensor `z` to a standardized form based on the mean and scale.
+
+ Args:
+ z (torch.Tensor): Input tensor to transform.
+
+ Returns:
+ torch.Tensor: The transformed tensor, representing the standardized normal form.
+ """
+ return (z - self.loc) / self.scale
+
+ def _d_transform_d_z(self) -> torch.Tensor:
+ """
+ Computes the derivative of the transform function with respect to `z`.
+
+ Returns:
+ torch.Tensor: The reciprocal of the scale, representing the gradient for reparameterization.
+ """
+ return 1 / self.scale
+
+ def sample(self, sample_shape: torch.Size = default_size) -> torch.Tensor:
+ """
+ Generates a sample from the Normal distribution using `torch.normal`.
+
+ Args:
+ sample_shape (torch.Size): Shape of the sample to generate.
+
+ Returns:
+ torch.Tensor: A tensor with samples from the Normal distribution, detached from the computation graph.
+ """
+ shape = self._extended_shape(sample_shape)
+ with torch.no_grad():
+ return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
+
+ def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
+ """
+ Generates a reparameterized sample from the Normal distribution, enabling gradient backpropagation.
+
+ Returns:
+ torch.Tensor: A tensor containing a reparameterized sample, useful for gradient-based optimization.
+ """
+ # Sample a point from the distribution
+ x = self.sample(sample_shape)
+ # Transform the sample to standard normal form
+ transform = self._transform(x)
+ # Compute a surrogate value for backpropagation
+ surrogate_x = -transform / self._d_transform_d_z().detach()
+ # Return the sample with gradient tracking enabled
+ return x + (surrogate_x - surrogate_x.detach())
+
+
+class MixtureSameFamily(torch.distributions.MixtureSameFamily):
+ """
+    Represents a mixture of distributions from the same family,
+    supporting reparameterized sampling for gradient-based optimization.
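+
+    Example (illustrative; sampled values are random)::
+
+        >>> mix = torch.distributions.Categorical(torch.tensor([0.5, 0.5]))
+        >>> comp = Normal(torch.tensor([-1.0, 1.0]), torch.tensor([0.5, 0.5]))
+        >>> m = MixtureSameFamily(mix, comp)
+        >>> m.rsample()  # reparameterized draw from the mixture, e.g. tensor(0.9)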
+ """
+
+ has_rsample = True
+
+ def __init__(self, *args, **kwargs) -> None:
+ """
+        Initializes the MixtureSameFamily distribution and checks that the component
+        distributions support reparameterized sampling (required for `rsample`).
+
+ Raises:
+ ValueError: If the component distributions do not support reparameterized sampling.
+ """
+ super().__init__(*args, **kwargs)
+ if not self._component_distribution.has_rsample:
+ raise ValueError("Cannot reparameterize a mixture of non-reparameterizable components.")
+
+ # Define a list of discrete distributions for checking in `_log_cdf`
+ self.discrete_distributions: List[Distribution] = [
+ Bernoulli,
+ Binomial,
+ ContinuousBernoulli,
+ Geometric,
+ NegativeBinomial,
+ RelaxedBernoulli,
+ ]
+
+ def rsample(self, sample_shape: torch.Size = default_size) -> torch.Tensor:
+ """
+ Generates a reparameterized sample from the mixture of distributions.
+
+ This method generates a sample, applies a distributional transformation,
+ and computes a surrogate sample to enable gradient flow during optimization.
+
+ Args:
+ sample_shape (torch.Size): The shape of the sample to generate.
+
+ Returns:
+ torch.Tensor: A reparameterized sample with gradients enabled.
+ """
+ # Generate a sample from the mixture distribution
+ x = self.sample(sample_shape=sample_shape)
+ event_size = math.prod(self.event_shape)
+
+ if event_size != 1:
+ # For multi-dimensional events, use reshaped distributional transformations
+ def reshaped_dist_trans(input_x: torch.Tensor) -> torch.Tensor:
+ return torch.reshape(self._distributional_transform(input_x), (-1, event_size))
+
+ def reshaped_dist_trans_summed(x_2d: torch.Tensor) -> torch.Tensor:
+ return torch.sum(reshaped_dist_trans(x_2d), dim=0)
+
+ x_2d = x.reshape((-1, event_size))
+ transform_2d = reshaped_dist_trans(x)
+ jac = jacobian(reshaped_dist_trans_summed, x_2d).detach().movedim(1, 0)
+ surrogate_x_2d = -torch.linalg.solve_triangular(jac.detach(), transform_2d[..., None], upper=False)
+ surrogate_x = surrogate_x_2d.reshape(x.shape)
+ else:
+ # For one-dimensional events, apply the standard distributional transformation
+ transform = self._distributional_transform(x)
+ log_prob_x = self.log_prob(x)
+
+ if self._event_ndims > 1:
+ log_prob_x = log_prob_x.reshape(log_prob_x.shape + (1,) * self._event_ndims)
+
+ surrogate_x = -transform * torch.exp(-log_prob_x.detach())
+
+ return x + (surrogate_x - surrogate_x.detach())
+
+ def _distributional_transform(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Applies a distributional transformation to the input sample `x`, using cumulative
+ distribution functions (CDFs) and posterior weights.
+
+ Args:
+ x (torch.Tensor): The input sample to transform.
+
+ Returns:
+ torch.Tensor: The transformed tensor based on the mixture model's CDFs.
+ """
+ if isinstance(self._component_distribution, torch.distributions.Independent):
+ univariate_components = self._component_distribution.base_dist
+ else:
+ univariate_components = self._component_distribution
+
+ # Expand input tensor and compute log-probabilities in each component
+ x = self._pad(x) # [S, B, 1, E]
+ log_prob_x = univariate_components.log_prob(x) # [S, B, K, E]
+
+ event_size = math.prod(self.event_shape)
+
+ if event_size != 1:
+ # CDF transformation for multi-dimensional events
+ cumsum_log_prob_x = log_prob_x.reshape(-1, event_size)
+ cumsum_log_prob_x = torch.cumsum(cumsum_log_prob_x, dim=-1)
+ cumsum_log_prob_x = cumsum_log_prob_x.roll(shifts=1, dims=-1)
+ cumsum_log_prob_x[:, 0] = 0
+ cumsum_log_prob_x = cumsum_log_prob_x.reshape(log_prob_x.shape)
+
+ logits_mix_prob = self._pad_mixture_dimensions(self._mixture_distribution.logits)
+ log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x
+
+ component_axis = -self._event_ndims - 1
+ cdf_x = univariate_components.cdf(x)
+ posterior_weights_x = torch.softmax(log_posterior_weights_x, dim=component_axis)
+ else:
+ # CDF transformation for one-dimensional events
+ log_posterior_weights_x = self._mixture_distribution.logits
+ component_axis = -self._event_ndims - 1
+ cdf_x = univariate_components.cdf(x)
+ posterior_weights_x = torch.softmax(log_posterior_weights_x, dim=-1)
+ posterior_weights_x = self._pad_mixture_dimensions(posterior_weights_x)
+
+ return torch.sum(posterior_weights_x * cdf_x, dim=component_axis)
+
+ def _log_cdf(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the logarithm of the cumulative distribution function (CDF) for the mixture distribution.
+
+ Args:
+ x (torch.Tensor): The input tensor for which to compute the log CDF.
+
+ Returns:
+ torch.Tensor: The log CDF values.
+ """
+ x = self._pad(x)
+ if isinstance(self._component_distribution, torch.distributions.Independent):
+ univariate_components = self._component_distribution.base_dist
+ else:
+ univariate_components = self._component_distribution
+
+ if callable(getattr(univariate_components, "_log_cdf", None)):
+ log_cdf_x = univariate_components._log_cdf(x)
+ else:
+ log_cdf_x = torch.log(univariate_components.cdf(x))
+
+ if isinstance(univariate_components, tuple(self.discrete_distributions)):
+            log_mix_prob = F.logsigmoid(self._mixture_distribution.logits)
+ else:
+ log_mix_prob = F.log_softmax(self._mixture_distribution.logits, dim=-1)
+
+ return torch.logsumexp(log_cdf_x + log_mix_prob, dim=-1)
+
+
+def _eval_poly(y: torch.Tensor, coef: List[float]) -> torch.Tensor:
+    """
+    Evaluate a polynomial at given points using Horner's scheme.
+
+    Args:
+        y: Input tensor.
+        coef: Polynomial coefficients, ordered from lowest to highest degree.
+
+    Returns:
+        Evaluated polynomial tensor.
+    """
+ coef = list(coef)
+ result = coef.pop()
+ while coef:
+ result = coef.pop() + y * result
+ return result
+
+
+_I0_COEF_SMALL = [
+ 1.0,
+ 3.5156229,
+ 3.0899424,
+ 1.2067492,
+ 0.2659732,
+ 0.360768e-1,
+ 0.45813e-2,
+]
+_I0_COEF_LARGE = [
+ 0.39894228,
+ 0.1328592e-1,
+ 0.225319e-2,
+ -0.157565e-2,
+ 0.916281e-2,
+ -0.2057706e-1,
+ 0.2635537e-1,
+ -0.1647633e-1,
+ 0.392377e-2,
+]
+_I1_COEF_SMALL = [
+ 0.5,
+ 0.87890594,
+ 0.51498869,
+ 0.15084934,
+ 0.2658733e-1,
+ 0.301532e-2,
+ 0.32411e-3,
+]
+_I1_COEF_LARGE = [
+ 0.39894228,
+ -0.3988024e-1,
+ -0.362018e-2,
+ 0.163801e-2,
+ -0.1031555e-1,
+ 0.2282967e-1,
+ -0.2895312e-1,
+ 0.1787654e-1,
+ -0.420059e-2,
+]
+
+_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
+_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]
+
+
+def _log_modified_bessel_fn(x: torch.Tensor, order: int = 0) -> torch.Tensor:
+ """
+ Compute the logarithm of the modified Bessel function of the first kind.
+
+ Args:
+ x: Input tensor, must be positive.
+ order: Order of the Bessel function (0 or 1).
+
+ Returns:
+ Logarithm of the Bessel function.
+ """
+ if order not in {0, 1}:
+ raise ValueError("Order must be 0 or 1.")
+
+ # compute small solution
+ y = x / 3.75
+ y = y * y
+ small = _eval_poly(y, _COEF_SMALL[order])
+ if order == 1:
+ small = x.abs() * small
+ small = small.log()
+
+ # compute large solution
+ y = 3.75 / x
+ large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()
+
+ result = torch.where(x < 3.75, small, large)
+ return result
+
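+# The two branches above implement the standard Abramowitz & Stegun polynomial approximations
+# of I0 and I1: a series in (x / 3.75)**2 for small x and an asymptotic exp(x) / sqrt(x) form
+# for large x. A quick sanity check one could run, using torch.special.i0e(x) = exp(-x) * I0(x):
+#
+#     x = torch.linspace(0.1, 20.0, 50)
+#     torch.testing.assert_close(
+#         _log_modified_bessel_fn(x, order=0),
+#         torch.special.i0e(x).log() + x,
+#         atol=1e-4, rtol=1e-4,
+#     )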
+
+@torch.jit.script_if_tracing
+def _rejection_sample(
+ loc: torch.Tensor,
+ concentration: torch.Tensor,
+ proposal_r: torch.Tensor,
+ x: torch.Tensor
+) -> torch.Tensor:
+ """
+ Perform rejection sampling for the von Mises distribution.
+
+ Args:
+ loc: Location parameter.
+ concentration: Concentration parameter.
+ proposal_r: Precomputed proposal parameter.
+ x: Tensor to fill with samples.
+
+ Returns:
+ Tensor of samples.
+ """
+ done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
+ while not done.all():
+ u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
+ u1, u2, u3 = u.unbind()
+ z = torch.cos(math.pi * u1)
+ f = (1 + proposal_r * z) / (proposal_r + z)
+ c = concentration * (proposal_r - f)
+ accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
+ if accept.any():
+ x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
+ done = done | accept
+ return (x + math.pi + loc) % (2 * math.pi) - math.pi
+
+
+class VonMises(Distribution):
+ """
+ Von Mises distribution class for circular data.
+ """
+
+ arg_constraints = {
+ "loc": constraints.real,
+ "concentration": constraints.positive,
+ }
+ support = constraints.real
+ has_rsample = True
+
+ def __init__(
+ self,
+ loc: torch.Tensor,
+ concentration: torch.Tensor,
+ validate_args: bool = None,
+ ):
+ self.loc, self.concentration = broadcast_all(loc, concentration)
+ batch_shape = self.loc.shape
+ super().__init__(batch_shape, torch.Size(), validate_args)
+
+ @lazy_property
+ @torch.no_grad()
+ def _proposal_r(self) -> torch.Tensor:
+ """
+ Compute the proposal parameter for sampling.
+ """
+ kappa = self._concentration
+ tau = 1 + (1 + 4 * kappa**2).sqrt()
+ rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
+ _proposal_r = (1 + rho**2) / (2 * rho)
+
+ # second order Taylor expansion around 0 for small kappa
+ _proposal_r_taylor = 1 / kappa + kappa
+ return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r)
+
+ def log_prob(self, value):
+ """
+ Compute the log probability of the given value.
+
+ Args:
+ value: Tensor of values.
+
+ Returns:
+ Tensor of log probabilities.
+ """
+ if self._validate_args:
+ self._validate_sample(value)
+ log_prob = self.concentration * torch.cos(value - self.loc)
+ log_prob = log_prob - math.log(2 * math.pi) - _log_modified_bessel_fn(self.concentration, order=0)
+ return log_prob
+
+ @lazy_property
+ def _loc(self):
+ return self.loc.to(torch.double)
+
+ @lazy_property
+ def _concentration(self):
+ return self.concentration.to(torch.double)
+
+ @torch.no_grad()
+ def sample(self, sample_shape=torch.Size()):
+ """
+ The sampling algorithm for the von Mises distribution is based on the
+ following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the
+ von Mises distribution." Applied Statistics (1979): 152-157.
+
+ Sampling is always done in double precision internally to avoid a hang
+ in _rejection_sample() for small values of the concentration, which
+ starts to happen for single precision around 1e-4 (see issue #88443).
+ """
+ shape = self._extended_shape(sample_shape)
+ x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
+ return _rejection_sample(
+ self._loc, self._concentration, self._proposal_r, x
+ ).to(self.loc.dtype)
+
+ def rsample(self, sample_shape=torch.Size()):
+ """
+ Generate reparameterized samples from the distribution.
+ """
+ shape = self._extended_shape(sample_shape)
+ samples = _VonMisesSampler.apply(self.concentration, self._proposal_r, shape)
+ samples = samples + self.loc
+
+ # Map the samples to [-pi, pi].
+ return samples - 2. * torch.pi * torch.round(samples / (2. * torch.pi))
+
+ @property
+ def mean(self):
+ """Mean of the distribution."""
+ return self.loc
+
+ @property
+ def variance(self):
+ """Variance of the distribution."""
+ return 1 - (
+ _log_modified_bessel_fn(self.concentration, order=1)
+ - _log_modified_bessel_fn(self.concentration, order=0)
+ ).exp()
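+
+# Illustrative usage sketch: rsample is differentiable w.r.t. the concentration through
+# _VonMisesSampler, so a loss on the samples propagates gradients into kappa. Double
+# precision keeps the dtypes consistent with the internal double-precision sampler:
+#
+#     kappa = torch.tensor(2.0, dtype=torch.double, requires_grad=True)
+#     vm = VonMises(torch.tensor(0.0, dtype=torch.double), kappa)
+#     loss = vm.rsample((1000,)).cos().mean()
+#     loss.backward()  # kappa.grad holds the implicit reparameterization gradient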
+
+@torch.jit.script_if_tracing
+@torch.no_grad()
+def _rejection_rsample(concentration: torch.Tensor, proposal_r: torch.Tensor, shape: torch.Size) -> torch.Tensor:
+ """
+ Perform rejection sampling to draw samples from the von Mises distribution.
+
+ Args:
+ concentration (torch.Tensor): Concentration parameter (kappa) of the distribution.
+ proposal_r (torch.Tensor): Proposal distribution parameter.
+ shape (torch.Size): Desired shape of the samples.
+
+ Returns:
+ torch.Tensor: Samples from the von Mises distribution.
+ """
+ x = torch.empty(shape, dtype=concentration.dtype, device=concentration.device)
+ done = torch.zeros(x.shape, dtype=torch.bool, device=concentration.device)
+
+ while not done.all():
+ u = torch.rand((3,) + x.shape, dtype=concentration.dtype, device=concentration.device)
+ u1, u2, u3 = u.unbind()
+ z = torch.cos(math.pi * u1)
+ f = (1 + proposal_r * z) / (proposal_r + z)
+ c = concentration * (proposal_r - f)
+ accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
+ if accept.any():
+ x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
+ done = done | accept
+ return x
+
+def cosxm1(x: torch.Tensor) -> torch.Tensor:
+ """
+ Compute cos(x) - 1 using a numerically stable formula.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Output tensor, `cos(x) - 1`.
+ """
+ return -2 * torch.square(torch.sin(x / 2.0))
+
+class _VonMisesSampler(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ concentration: torch.Tensor,
+ proposal_r: torch.Tensor,
+ shape: torch.Size,
+ ) -> torch.Tensor:
+ """
+ Perform forward sampling using rejection sampling.
+
+ Args:
+ ctx (torch.autograd.function.FunctionCtx): Context object for saving tensors.
+ concentration (torch.Tensor): Concentration parameter (kappa).
+ proposal_r (torch.Tensor): Proposal distribution parameter.
+ shape (torch.Size): Desired shape of the samples.
+
+ Returns:
+ torch.Tensor: Samples from the von Mises distribution.
+ """
+ samples = _rejection_rsample(concentration, proposal_r, shape)
+ ctx.save_for_backward(concentration, proposal_r, samples)
+
+ return samples
+
+ @staticmethod
+ @torch.autograd.function.once_differentiable
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
+ ) -> Tuple[torch.Tensor, None, None]:
+ """
+ Compute gradients for backward pass using implicit reparameterization.
+
+ Args:
+ ctx (torch.autograd.function.FunctionCtx): Context object containing saved tensors.
+ grad_output (torch.Tensor): Gradient of the loss with respect to the output.
+
+ Returns:
+ Tuple[torch.Tensor, None, None]: Gradients with respect to the input tensors.
+ """
+ concentration, proposal_r, samples = ctx.saved_tensors
+
+ num_periods = torch.round(samples / (2. * torch.pi))
+ x_mapped = samples - (2. * torch.pi) * num_periods
+
+ ## Parameters from the paper
+ ck = 10.5
+ num_terms = 20
+
+ ## Compute series and normal approximation
+ cdf_series, dcdf_dconcentration_series = von_mises_cdf_series(x_mapped, concentration, num_terms)
+ cdf_normal, dcdf_dconcentration_normal = von_mises_cdf_normal(x_mapped, concentration)
+ use_series = concentration < ck
+ cdf = torch.where(use_series, cdf_series, cdf_normal) + num_periods
+ dcdf_dconcentration = torch.where(use_series, dcdf_dconcentration_series, dcdf_dconcentration_normal)
+
+        ## Implicit reparameterization gradient: dx/dkappa = -(dF/dkappa) / p(x | kappa),
+        ## where p is the von Mises density, written via the exponentially scaled Bessel i0e
+        prob = torch.exp(concentration * cosxm1(samples)) / (
+            2 * math.pi * torch.special.i0e(concentration)
+        )
+        grad_concentration = grad_output * (-dcdf_dconcentration / prob)
+
+ return grad_concentration, None, None
+
+
+def von_mises_cdf_series(
+ x: torch.Tensor, concentration: torch.Tensor, num_terms: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Compute the CDF of the von Mises distribution using a series approximation.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+ concentration (torch.Tensor): Concentration parameter (kappa).
+ num_terms (int): Number of terms in the series.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: CDF and its gradient with respect to concentration.
+ """
+ vn = torch.zeros_like(x)
+ dvn_dconcentration = torch.zeros_like(x)
+
+ n = torch.tensor(num_terms, dtype=x.dtype, device=x.device)
+ rn = torch.zeros_like(x)
+ drn_dconcentration = torch.zeros_like(x)
+
+ while n > 0:
+ denominator = 2. * n / concentration + rn
+ ddenominator_dk = -2. * n / concentration ** 2 + drn_dconcentration
+ rn = 1. / denominator
+ drn_dconcentration = -ddenominator_dk / denominator ** 2
+
+ multiplier = torch.sin(n * x) / n + vn
+ vn = rn * multiplier
+ dvn_dconcentration = (drn_dconcentration * multiplier + rn * dvn_dconcentration)
+
+ n -= 1
+
+ cdf = 0.5 + x / (2. * torch.pi) + vn / torch.pi
+ dcdf_dconcentration = dvn_dconcentration / torch.pi
+
+ cdf_clipped = torch.clamp(cdf, 0., 1.)
+ dcdf_dconcentration *= (cdf >= 0.) & (cdf <= 1.)
+
+ return cdf_clipped, dcdf_dconcentration
+
+def cdf_func(concentration: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+ """
+ Approximate the CDF of the von Mises distribution.
+
+ Args:
+ concentration (torch.Tensor): Concentration parameter (kappa).
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Approximate CDF values.
+ """
+
+ # Calculate the z value based on the approximation
+    z = (math.sqrt(2. / math.pi) / torch.special.i0e(concentration)) * torch.sin(0.5 * x)
+ # Apply corrections to z to improve the approximation
+ z2 = z ** 2
+ z3 = z2 * z
+ z4 = z2 ** 2
+ c = 24. * concentration
+ c1 = 56.
+
+ xi = z - z3 / (
+ ((c - 2. * z2 - 16.) / 3.) -
+ (z4 + (7. / 4.) * z2 + 167. / 2.) / (c - c1 - z2 + 3.)
+ ) ** 2
+
+ # Use the standard normal distribution for the approximation
+ distrib = torch.distributions.Normal(
+ torch.tensor(0., dtype=x.dtype, device=x.device),
+ torch.tensor(1., dtype=x.dtype, device=x.device)
+ )
+
+ return distrib.cdf(xi)
+
+def von_mises_cdf_normal(
+ x: torch.Tensor, concentration: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Compute the CDF of the von Mises distribution using a normal approximation.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+ concentration (torch.Tensor): Concentration parameter (kappa).
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: CDF and its gradient with respect to concentration.
+ """
+ with torch.enable_grad():
+ concentration_ = concentration.detach().clone().requires_grad_(True)
+ cdf = cdf_func(concentration_, x)
+ cdf.backward(torch.ones_like(cdf)) # Compute gradients
+ dcdf_dconcentration = concentration_.grad.clone() # Copy the gradient
+ # Detach gradients to prevent further autograd tracking
+ concentration_.grad = None
+ return cdf, dcdf_dconcentration
\ No newline at end of file
diff --git a/src/irt/distributions.py b/src/irt/distributions.py
index 1511abf..43197d9 100644
--- a/src/irt/distributions.py
+++ b/src/irt/distributions.py
@@ -17,8 +17,9 @@
constraints,
)
from torch.distributions.exp_family import ExponentialFamily
-from torch.distributions.utils import broadcast_all
+from torch.distributions.utils import broadcast_all, lazy_property
from torch.types import _size
+from torch.distributions.distribution import Distribution
default_size = torch.Size()
@@ -1035,3 +1036,435 @@ def _log_cdf(self, x: torch.Tensor) -> torch.Tensor:
log_mix_prob = F.log_softmax(self._mixture_distribution.logits, dim=-1)
return torch.logsumexp(log_cdf_x + log_mix_prob, dim=-1)
+
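+# Note on _log_cdf above: the mixture CDF is evaluated in log-space as
+# logsumexp(log F_k(x) + log w_k), i.e. the log of sum_k w_k * F_k(x), which is more
+# stable than summing the component CDFs directly when they are very small.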
+
+def _eval_poly(y: torch.Tensor, coef: list) -> torch.Tensor:
+ """
+ Evaluate a polynomial at given points.
+
+ Args:
+ y: Input tensor.
+        coef: Polynomial coefficients, ordered from lowest to highest degree.
+
+ Returns:
+ Evaluated polynomial tensor.
+ """
+ coef = list(coef)
+ result = coef.pop()
+ while coef:
+ result = coef.pop() + y * result
+ return result
+
+
+_I0_COEF_SMALL = [
+ 1.0,
+ 3.5156229,
+ 3.0899424,
+ 1.2067492,
+ 0.2659732,
+ 0.360768e-1,
+ 0.45813e-2,
+]
+_I0_COEF_LARGE = [
+ 0.39894228,
+ 0.1328592e-1,
+ 0.225319e-2,
+ -0.157565e-2,
+ 0.916281e-2,
+ -0.2057706e-1,
+ 0.2635537e-1,
+ -0.1647633e-1,
+ 0.392377e-2,
+]
+_I1_COEF_SMALL = [
+ 0.5,
+ 0.87890594,
+ 0.51498869,
+ 0.15084934,
+ 0.2658733e-1,
+ 0.301532e-2,
+ 0.32411e-3,
+]
+_I1_COEF_LARGE = [
+ 0.39894228,
+ -0.3988024e-1,
+ -0.362018e-2,
+ 0.163801e-2,
+ -0.1031555e-1,
+ 0.2282967e-1,
+ -0.2895312e-1,
+ 0.1787654e-1,
+ -0.420059e-2,
+]
+
+_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
+_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]
+
+
+def _log_modified_bessel_fn(x: torch.Tensor, order: int = 0) -> torch.Tensor:
+ """
+ Compute the logarithm of the modified Bessel function of the first kind.
+
+ Args:
+ x: Input tensor, must be positive.
+ order: Order of the Bessel function (0 or 1).
+
+ Returns:
+ Logarithm of the Bessel function.
+ """
+ if order not in {0, 1}:
+ raise ValueError("Order must be 0 or 1.")
+
+ # compute small solution
+ y = x / 3.75
+ y = y * y
+ small = _eval_poly(y, _COEF_SMALL[order])
+ if order == 1:
+ small = x.abs() * small
+ small = small.log()
+
+ # compute large solution
+ y = 3.75 / x
+ large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()
+
+ result = torch.where(x < 3.75, small, large)
+ return result
+
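+# The two branches above implement the standard Abramowitz & Stegun polynomial approximations
+# of I0 and I1: a series in (x / 3.75)**2 for small x and an asymptotic exp(x) / sqrt(x) form
+# for large x. A quick sanity check one could run, using torch.special.i0e(x) = exp(-x) * I0(x):
+#
+#     x = torch.linspace(0.1, 20.0, 50)
+#     torch.testing.assert_close(
+#         _log_modified_bessel_fn(x, order=0),
+#         torch.special.i0e(x).log() + x,
+#         atol=1e-4, rtol=1e-4,
+#     )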
+
+@torch.jit.script_if_tracing
+def _rejection_sample(
+ loc: torch.Tensor,
+ concentration: torch.Tensor,
+ proposal_r: torch.Tensor,
+ x: torch.Tensor
+) -> torch.Tensor:
+ """
+ Perform rejection sampling for the von Mises distribution.
+
+ Args:
+ loc: Location parameter.
+ concentration: Concentration parameter.
+ proposal_r: Precomputed proposal parameter.
+ x: Tensor to fill with samples.
+
+ Returns:
+ Tensor of samples.
+ """
+ done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
+ while not done.all():
+ u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
+ u1, u2, u3 = u.unbind()
+ z = torch.cos(math.pi * u1)
+ f = (1 + proposal_r * z) / (proposal_r + z)
+ c = concentration * (proposal_r - f)
+ accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
+ if accept.any():
+ x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
+ done = done | accept
+ return (x + math.pi + loc) % (2 * math.pi) - math.pi
+
+
+class VonMises(Distribution):
+ """
+ Von Mises distribution class for circular data.
+ """
+
+ arg_constraints = {
+ "loc": constraints.real,
+ "concentration": constraints.positive,
+ }
+ support = constraints.real
+ has_rsample = True
+
+ def __init__(
+ self,
+ loc: torch.Tensor,
+ concentration: torch.Tensor,
+ validate_args: bool = None,
+ ):
+ self.loc, self.concentration = broadcast_all(loc, concentration)
+ batch_shape = self.loc.shape
+ super().__init__(batch_shape, torch.Size(), validate_args)
+
+ @lazy_property
+ @torch.no_grad()
+ def _proposal_r(self) -> torch.Tensor:
+ """
+ Compute the proposal parameter for sampling.
+ """
+ kappa = self._concentration
+ tau = 1 + (1 + 4 * kappa**2).sqrt()
+ rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
+ _proposal_r = (1 + rho**2) / (2 * rho)
+
+ # second order Taylor expansion around 0 for small kappa
+ _proposal_r_taylor = 1 / kappa + kappa
+ return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r)
+
+ def log_prob(self, value):
+ """
+ Compute the log probability of the given value.
+
+ Args:
+ value: Tensor of values.
+
+ Returns:
+ Tensor of log probabilities.
+ """
+ if self._validate_args:
+ self._validate_sample(value)
+ log_prob = self.concentration * torch.cos(value - self.loc)
+ log_prob = log_prob - math.log(2 * math.pi) - _log_modified_bessel_fn(self.concentration, order=0)
+ return log_prob
+
+ @lazy_property
+ def _loc(self):
+ return self.loc.to(torch.double)
+
+ @lazy_property
+ def _concentration(self):
+ return self.concentration.to(torch.double)
+
+ @torch.no_grad()
+ def sample(self, sample_shape=torch.Size()):
+ """
+ The sampling algorithm for the von Mises distribution is based on the
+ following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the
+ von Mises distribution." Applied Statistics (1979): 152-157.
+
+ Sampling is always done in double precision internally to avoid a hang
+ in _rejection_sample() for small values of the concentration, which
+ starts to happen for single precision around 1e-4 (see issue #88443).
+ """
+ shape = self._extended_shape(sample_shape)
+ x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
+ return _rejection_sample(
+ self._loc, self._concentration, self._proposal_r, x
+ ).to(self.loc.dtype)
+
+ def rsample(self, sample_shape=torch.Size()):
+ """
+ Generate reparameterized samples from the distribution.
+ """
+ shape = self._extended_shape(sample_shape)
+ samples = _VonMisesSampler.apply(self.concentration, self._proposal_r, shape)
+ samples = samples + self.loc
+
+ # Map the samples to [-pi, pi].
+ return samples - 2. * torch.pi * torch.round(samples / (2. * torch.pi))
+
+ @property
+ def mean(self):
+ """Mean of the distribution."""
+ return self.loc
+
+ @property
+ def variance(self):
+ """Variance of the distribution."""
+ return 1 - (
+ _log_modified_bessel_fn(self.concentration, order=1)
+ - _log_modified_bessel_fn(self.concentration, order=0)
+ ).exp()
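+
+# Illustrative usage sketch: rsample is differentiable w.r.t. the concentration through
+# _VonMisesSampler, so a loss on the samples propagates gradients into kappa. Double
+# precision keeps the dtypes consistent with the internal double-precision sampler:
+#
+#     kappa = torch.tensor(2.0, dtype=torch.double, requires_grad=True)
+#     vm = VonMises(torch.tensor(0.0, dtype=torch.double), kappa)
+#     loss = vm.rsample((1000,)).cos().mean()
+#     loss.backward()  # kappa.grad holds the implicit reparameterization gradient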
+
+@torch.jit.script_if_tracing
+@torch.no_grad()
+def _rejection_rsample(concentration: torch.Tensor, proposal_r: torch.Tensor, shape: torch.Size) -> torch.Tensor:
+ """
+ Perform rejection sampling to draw samples from the von Mises distribution.
+
+ Args:
+ concentration (torch.Tensor): Concentration parameter (kappa) of the distribution.
+ proposal_r (torch.Tensor): Proposal distribution parameter.
+ shape (torch.Size): Desired shape of the samples.
+
+ Returns:
+ torch.Tensor: Samples from the von Mises distribution.
+ """
+ x = torch.empty(shape, dtype=concentration.dtype, device=concentration.device)
+ done = torch.zeros(x.shape, dtype=torch.bool, device=concentration.device)
+
+ while not done.all():
+ u = torch.rand((3,) + x.shape, dtype=concentration.dtype, device=concentration.device)
+ u1, u2, u3 = u.unbind()
+ z = torch.cos(math.pi * u1)
+ f = (1 + proposal_r * z) / (proposal_r + z)
+ c = concentration * (proposal_r - f)
+ accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
+ if accept.any():
+ x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
+ done = done | accept
+ return x
+
+def cosxm1(x: torch.Tensor) -> torch.Tensor:
+ """
+ Compute cos(x) - 1 using a numerically stable formula.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Output tensor, `cos(x) - 1`.
+ """
+ return -2 * torch.square(torch.sin(x / 2.0))
+
+class _VonMisesSampler(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ concentration: torch.Tensor,
+ proposal_r: torch.Tensor,
+ shape: torch.Size,
+ ) -> torch.Tensor:
+ """
+ Perform forward sampling using rejection sampling.
+
+ Args:
+ ctx (torch.autograd.function.FunctionCtx): Context object for saving tensors.
+ concentration (torch.Tensor): Concentration parameter (kappa).
+ proposal_r (torch.Tensor): Proposal distribution parameter.
+ shape (torch.Size): Desired shape of the samples.
+
+ Returns:
+ torch.Tensor: Samples from the von Mises distribution.
+ """
+ samples = _rejection_rsample(concentration, proposal_r, shape)
+ ctx.save_for_backward(concentration, proposal_r, samples)
+
+ return samples
+
+ @staticmethod
+ @torch.autograd.function.once_differentiable
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
+ ) -> Tuple[torch.Tensor, None, None]:
+ """
+ Compute gradients for backward pass using implicit reparameterization.
+
+ Args:
+ ctx (torch.autograd.function.FunctionCtx): Context object containing saved tensors.
+ grad_output (torch.Tensor): Gradient of the loss with respect to the output.
+
+ Returns:
+ Tuple[torch.Tensor, None, None]: Gradients with respect to the input tensors.
+ """
+ concentration, proposal_r, samples = ctx.saved_tensors
+
+ num_periods = torch.round(samples / (2. * torch.pi))
+ x_mapped = samples - (2. * torch.pi) * num_periods
+
+ ## Parameters from the paper
+ ck = 10.5
+ num_terms = 20
+
+ ## Compute series and normal approximation
+ cdf_series, dcdf_dconcentration_series = von_mises_cdf_series(x_mapped, concentration, num_terms)
+ cdf_normal, dcdf_dconcentration_normal = von_mises_cdf_normal(x_mapped, concentration)
+ use_series = concentration < ck
+ cdf = torch.where(use_series, cdf_series, cdf_normal) + num_periods
+ dcdf_dconcentration = torch.where(use_series, dcdf_dconcentration_series, dcdf_dconcentration_normal)
+
+        ## Implicit reparameterization gradient: dx/dkappa = -(dF/dkappa) / p(x | kappa),
+        ## where p is the von Mises density, written via the exponentially scaled Bessel i0e
+        prob = torch.exp(concentration * cosxm1(samples)) / (
+            2 * math.pi * torch.special.i0e(concentration)
+        )
+        grad_concentration = grad_output * (-dcdf_dconcentration / prob)
+
+ return grad_concentration, None, None
+
+
+def von_mises_cdf_series(
+ x: torch.Tensor, concentration: torch.Tensor, num_terms: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Compute the CDF of the von Mises distribution using a series approximation.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+ concentration (torch.Tensor): Concentration parameter (kappa).
+ num_terms (int): Number of terms in the series.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: CDF and its gradient with respect to concentration.
+ """
+ vn = torch.zeros_like(x)
+ dvn_dconcentration = torch.zeros_like(x)
+
+ n = torch.tensor(num_terms, dtype=x.dtype, device=x.device)
+ rn = torch.zeros_like(x)
+ drn_dconcentration = torch.zeros_like(x)
+
+ while n > 0:
+ denominator = 2. * n / concentration + rn
+ ddenominator_dk = -2. * n / concentration ** 2 + drn_dconcentration
+ rn = 1. / denominator
+ drn_dconcentration = -ddenominator_dk / denominator ** 2
+
+ multiplier = torch.sin(n * x) / n + vn
+ vn = rn * multiplier
+ dvn_dconcentration = (drn_dconcentration * multiplier + rn * dvn_dconcentration)
+
+ n -= 1
+
+ cdf = 0.5 + x / (2. * torch.pi) + vn / torch.pi
+ dcdf_dconcentration = dvn_dconcentration / torch.pi
+
+ cdf_clipped = torch.clamp(cdf, 0., 1.)
+ dcdf_dconcentration *= (cdf >= 0.) & (cdf <= 1.)
+
+ return cdf_clipped, dcdf_dconcentration
+
+def cdf_func(concentration: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+ """
+ Approximate the CDF of the von Mises distribution.
+
+ Args:
+ concentration (torch.Tensor): Concentration parameter (kappa).
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Approximate CDF values.
+ """
+
+ # Calculate the z value based on the approximation
+    z = (math.sqrt(2. / math.pi) / torch.special.i0e(concentration)) * torch.sin(0.5 * x)
+ # Apply corrections to z to improve the approximation
+ z2 = z ** 2
+ z3 = z2 * z
+ z4 = z2 ** 2
+ c = 24. * concentration
+ c1 = 56.
+
+ xi = z - z3 / (
+ ((c - 2. * z2 - 16.) / 3.) -
+ (z4 + (7. / 4.) * z2 + 167. / 2.) / (c - c1 - z2 + 3.)
+ ) ** 2
+
+ # Use the standard normal distribution for the approximation
+ distrib = torch.distributions.Normal(
+ torch.tensor(0., dtype=x.dtype, device=x.device),
+ torch.tensor(1., dtype=x.dtype, device=x.device)
+ )
+
+ return distrib.cdf(xi)
+
+def von_mises_cdf_normal(
+ x: torch.Tensor, concentration: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Compute the CDF of the von Mises distribution using a normal approximation.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+ concentration (torch.Tensor): Concentration parameter (kappa).
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: CDF and its gradient with respect to concentration.
+ """
+ with torch.enable_grad():
+ concentration_ = concentration.detach().clone().requires_grad_(True)
+ cdf = cdf_func(concentration_, x)
+ cdf.backward(torch.ones_like(cdf)) # Compute gradients
+ dcdf_dconcentration = concentration_.grad.clone() # Copy the gradient
+ # Detach gradients to prevent further autograd tracking
+ concentration_.grad = None
+ return cdf, dcdf_dconcentration
\ No newline at end of file