diff --git a/.ipynb_checkpoints/README-checkpoint.md b/.ipynb_checkpoints/README-checkpoint.md new file mode 100644 index 0000000..f0dfbda --- /dev/null +++ b/.ipynb_checkpoints/README-checkpoint.md @@ -0,0 +1,75 @@ +# Implicit Reparametrization Trick + +
+ +
+ + + + + + + + + + + + + + + +
Title Implicit Reparametrization Trick for BMM
Authors Matvei Kreinin, Maria Nikitina, Petr Babkin, Iryna Zabarianska
Consultant Oleg Bakhteev, PhD
+ + +![Testing](https://github.com/intsystems/implicit-reparameterization-trick/actions/workflows/testing.yml/badge.svg) +![Docs](https://github.com/intsystems/implicit-reparameterization-trick/actions/workflows/docs.yml/badge.svg) + + +## Description + +This repository contains an educational project for the Bayesian Multimodeling course. It implements algorithms for sampling from various distributions using the implicit reparameterization trick. + +## Scope +We plan to implement the following distributions in our library: +- [x] Gaussian normal distribution (*) +- [x] Dirichlet distribution (Beta distributions)(\*) +- [x] Mixture of the same family distributions (**) +- [x] Student's t-distribution (**) (\*) +- [x] VonMises distribution (***) +- [ ] Sampling from an arbitrary factorized distribution (***) + +(\*) - this distribution is already implemented in torch using the explicit reparameterization trick; we implement it for comparison + +(\*\*) - these distributions are added as a backup; their inclusion is questionable + +(\*\*\*) - the implementation of this distribution is not fully clear; its inclusion is questionable + +## Stack + +We plan to inherit from the torch.distributions.Distribution class, so we need to implement all the methods present in that class. + +## Usage +In this example, we demonstrate the application of our library using a Variational Autoencoder (VAE) model, where the latent layer is modeled by a normal distribution. +``` +>>> import torch.distributions.implicit as irt +>>> params = Encoder(inputs) +>>> gauss = irt.Normal(*params) +>>> deviated = gauss.rsample() +>>> outputs = Decoder(deviated) +``` +In this example, we demonstrate the use of a mixture of distributions from our library. +``` +>>> import irt +>>> params = Encoder(inputs) +>>> mix = irt.Mixture([irt.Normal(*params), irt.Dirichlet(*params)]) +>>> deviated = mix.rsample() +>>> outputs = Decoder(deviated) +``` + +## Links +- [LinkReview](https://github.com/intsystems/implitic-reparametrization-trick/blob/main/linkreview.md) +- [Plan of project](https://github.com/intsystems/implitic-reparametrization-trick/blob/main/planning.md) +- [BlogPost](blogpost/Blog_post_sketch.pdf) +- [Documentation](https://intsystems.github.io/implicit-reparameterization-trick/) +- [Matvei Kreinin](https://github.com/kreininmv), [Maria Nikitina](https://github.com/NikitinaMaria), [Petr Babkin](https://github.com/petr-parker), [Iryna Zabarianska](https://github.com/Akshiira) + diff --git a/README.md b/README.md index 38f922a..f0dfbda 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,10 @@ # Implicit Reparametrization Trick +
+ +
+ + @@ -15,6 +20,7 @@
Title
+ ![Testing](https://github.com/intsystems/implicit-reparameterization-trick/actions/workflows/testing.yml/badge.svg) ![Docs](https://github.com/intsystems/implicit-reparameterization-trick/actions/workflows/docs.yml/badge.svg) @@ -25,11 +31,12 @@ This repository implements an educational project for the Bayesian Multimodeling ## Scope We plan to implement the following distributions in our library: -- Gaussian normal distribution (*) -- Dirichlet distribution (Beta distributions)(\*) -- Sampling from a mixture of distributions -- Sampling from the Student's t-distribution (**) (\*) -- Sampling from an arbitrary factorized distribution (***) +- [x] Gaussian normal distribution (*) +- [x] Dirichlet distribution (Beta distributions)(\*) +- [x] Mixture of the same family distributions (**) +- [x] Student's t-distribution (**) (\*) +- [x] VonMises distribution (***) +- [ ] Sampling from an arbitrary factorized distribution (***) (\*) - this distribution is already implemented in torch using the explicit reparameterization trick, we will implement it for comparison @@ -64,4 +71,5 @@ In this example, we demonstrate the use of a mixture of distributions using our - [Plan of project](https://github.com/intsystems/implitic-reparametrization-trick/blob/main/planning.md) - [BlogPost](blogpost/Blog_post_sketch.pdf) - [Documentation](https://intsystems.github.io/implicit-reparameterization-trick/) +- [Matvei Kreinin](https://github.com/kreininmv), [Maria Nikitina](https://github.com/NikitinaMaria), [Petr Babkin](https://github.com/petr-parker), [Iryna Zabarianska](https://github.com/Akshiira) diff --git a/code/.ipynb_checkpoints/run_unittest-checkpoint.py b/code/.ipynb_checkpoints/run_unittest-checkpoint.py new file mode 100644 index 0000000..bca15b4 --- /dev/null +++ b/code/.ipynb_checkpoints/run_unittest-checkpoint.py @@ -0,0 +1,444 @@ +import unittest +import math +import torch +import sys +sys.path.append('../src') +from irt.distributions import Normal, Gamma, MixtureSameFamily, Beta, Dirichlet, StudentT +from torch.distributions import Categorical, Independent + + +class TestNormalDistribution(unittest.TestCase): + def setUp(self): + self.loc = torch.tensor([0.0, 1.0]).requires_grad_(True) + self.scale = torch.tensor([1.0, 2.0]).requires_grad_(True) + self.normal = Normal(self.loc, self.scale) + + def test_init(self): + normal = Normal(0.0, 1.0) + self.assertEqual(normal.loc, 0.0) + self.assertEqual(normal.scale, 1.0) + self.assertEqual(normal.batch_shape, torch.Size()) + + normal = Normal(torch.tensor([0.0, 1.0]), torch.tensor([1.0, 2.0])) + self.assertTrue(torch.equal(normal.loc, torch.tensor([0.0, 1.0]))) + self.assertTrue(torch.equal(normal.scale, torch.tensor([1.0, 2.0]))) + self.assertEqual(normal.batch_shape, torch.Size([2])) + + def test_properties(self): + self.assertTrue(torch.equal(self.normal.mean, self.loc)) + self.assertTrue(torch.equal(self.normal.mode, self.loc)) + self.assertTrue(torch.equal(self.normal.stddev, self.scale)) + self.assertTrue(torch.equal(self.normal.variance, self.scale**2)) + + def test_entropy(self): + entropy = self.normal.entropy() + expected_entropy = 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale) + self.assertTrue(torch.allclose(entropy, expected_entropy)) + + def test_cdf(self): + value = torch.tensor([0.0, 2.0]) + cdf = self.normal.cdf(value) + expected_cdf = 0.5 * (1 + torch.erf((value - self.loc) / (self.scale * math.sqrt(2)))) + self.assertTrue(torch.allclose(cdf, expected_cdf)) + + def test_expand(self): + expanded_normal = 
self.normal.expand(torch.Size([3, 2])) + self.assertEqual(expanded_normal.batch_shape, torch.Size([3, 2])) + self.assertTrue(torch.equal(expanded_normal.loc, self.loc.expand([3, 2]))) + self.assertTrue(torch.equal(expanded_normal.scale, self.scale.expand([3, 2]))) + + def test_icdf(self): + value = torch.tensor([0.2, 0.8]) + icdf = self.normal.icdf(value) + expected_icdf = self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2) + self.assertTrue(torch.allclose(icdf, expected_icdf)) + + def test_log_prob(self): + value = torch.tensor([0.0, 2.0]) + log_prob = self.normal.log_prob(value) + var = self.scale**2 + log_scale = self.scale.log() + expected_log_prob = -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi)) + self.assertTrue(torch.allclose(log_prob, expected_log_prob)) + + def test_sample(self): + samples = self.normal.sample(sample_shape=torch.Size([100])) + self.assertEqual(samples.shape, torch.Size([100, 2])) # Check shape + emperic_mean = samples.mean(dim=0) + self.assertTrue((emperic_mean < self.normal.mean + self.normal.scale).all()) + self.assertTrue((self.normal.mean - self.normal.scale < emperic_mean).all()) + + def test_rsample(self): + samples = self.normal.rsample(sample_shape=torch.Size([10])) + self.assertEqual(samples.shape, torch.Size([10, 2])) # Check shape + self.assertTrue(samples.requires_grad) # Check gradient tracking + + +class TestGammaDistribution(unittest.TestCase): + def setUp(self): + self.concentration = torch.tensor([1.0, 2.0]).requires_grad_(True) + self.rate = torch.tensor([1.0, 0.5]).requires_grad_(True) + self.gamma = Gamma(self.concentration, self.rate) + + def test_init(self): + gamma = Gamma(1.0, 1.0) + self.assertEqual(gamma.concentration, 1.0) + self.assertEqual(gamma.rate, 1.0) + self.assertEqual(gamma.batch_shape, torch.Size()) + + gamma = Gamma(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 0.5])) + self.assertTrue(torch.equal(gamma.concentration, torch.tensor([1.0, 2.0]))) + self.assertTrue(torch.equal(gamma.rate, torch.tensor([1.0, 0.5]))) + self.assertEqual(gamma.batch_shape, torch.Size([2])) + + def test_properties(self): + self.assertTrue(torch.allclose(self.gamma.mean, self.concentration / self.rate)) + self.assertTrue(torch.allclose(self.gamma.mode, ((self.concentration - 1) / self.rate).clamp(min=0))) + self.assertTrue(torch.allclose(self.gamma.variance, self.concentration / self.rate.pow(2))) + + def test_expand(self): + expanded_gamma = self.gamma.expand(torch.Size([3, 2])) + self.assertEqual(expanded_gamma.batch_shape, torch.Size([3, 2])) + self.assertTrue(torch.equal(expanded_gamma.concentration, self.concentration.expand([3, 2]))) + self.assertTrue(torch.equal(expanded_gamma.rate, self.rate.expand([3, 2]))) + + def test_rsample(self): + samples = self.gamma.rsample(sample_shape=torch.Size([10])) + self.assertEqual(samples.shape, torch.Size([10, 2])) # Check shape + self.assertTrue(samples.requires_grad) #Check gradient tracking + + + def test_log_prob(self): + value = torch.tensor([1.0, 2.0]) + log_prob = self.gamma.log_prob(value) + expected_log_prob = ( + torch.xlogy(self.concentration, self.rate) + + torch.xlogy(self.concentration - 1, value) + - self.rate * value + - torch.lgamma(self.concentration) + ) + self.assertTrue(torch.allclose(log_prob, expected_log_prob)) + + def test_entropy(self): + entropy = self.gamma.entropy() + expected_entropy = ( + self.concentration + - torch.log(self.rate) + + torch.lgamma(self.concentration) + + (1.0 - self.concentration) * 
torch.digamma(self.concentration) + ) + self.assertTrue(torch.allclose(entropy, expected_entropy)) + + def test_natural_params(self): + natural_params = self.gamma._natural_params + expected_natural_params = (self.concentration - 1, -self.rate) + self.assertTrue(torch.equal(natural_params[0], expected_natural_params[0])) + self.assertTrue(torch.equal(natural_params[1], expected_natural_params[1])) + + def test_log_normalizer(self): + x, y = self.gamma._natural_params + log_normalizer = self.gamma._log_normalizer(x, y) + expected_log_normalizer = torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal()) + self.assertTrue(torch.allclose(log_normalizer, expected_log_normalizer)) + + def test_cdf(self): + value = torch.tensor([1.0, 2.0]) + cdf = self.gamma.cdf(value) + expected_cdf = torch.special.gammainc(self.concentration, self.rate * value) + self.assertTrue(torch.allclose(cdf, expected_cdf)) + + + def test_invalid_inputs(self): + with self.assertRaises(ValueError): + Gamma(torch.tensor([-1.0, 1.0]), self.rate) # Negative concentration + with self.assertRaises(ValueError): + Gamma(self.concentration, torch.tensor([-1.0, 1.0])) # Negative rate + with self.assertRaises(ValueError): + self.gamma.log_prob(torch.tensor([-1.0, 1.0])) # Negative value + +class TestMixtureSameFamily(unittest.TestCase): + def setUp(self): + # Use simple distributions for testing. Replace with your desired components + component_dist = Normal(torch.tensor([0.0, 1.0]).requires_grad_(True), torch.tensor([1.0, 2.0]).requires_grad_(True)) + mixture_dist = Categorical(torch.tensor([0.6, 0.4])) + + self.mixture = MixtureSameFamily(mixture_dist, component_dist) + + def test_rsample_event_size_1(self): + samples = self.mixture.rsample(sample_shape=torch.Size([10])) + self.assertEqual(samples.shape, torch.Size([10])) + self.assertTrue(samples.requires_grad) # Ensure gradient tracking + + def test_rsample_event_size_greater_than_1(self): + #Create a mixture with event_size > 1 (e.g., using Independent(Normal(...),1) ) + component_dist = Independent(Normal(torch.tensor([[0.0, 1.0], [2., 3.]]).requires_grad_(True), torch.tensor([[1.0, 2.0], [2., 3.]]).requires_grad_(True)), 1) + mixture_dist = Categorical(torch.tensor([0.6, 0.4]).requires_grad_(True)) + mixture = MixtureSameFamily(mixture_dist, component_dist) + samples = mixture.rsample(sample_shape=torch.Size([10])) + self.assertEqual(samples.shape, torch.Size([10, 2])) # Check shape + self.assertTrue(samples.requires_grad) # Ensure gradient tracking + + def test_distributional_transform(self): + # Test cases for different input shapes and component distributions + # Add assertions to check the output of _distributional_transform. + + x = torch.randn(10,2) + transform = self.mixture._distributional_transform(x) + # ADD ASSERTION HERE. 
The transform should be a tensor of the correct shape, depending on event shape of component distribution + self.assertEqual(transform.shape,torch.Size([10, 2])) + + + def test_invalid_component(self): + # Test with a component distribution that doesn't have rsample + class NoRsampleDist(object): + has_rsample = False + with self.assertRaises(ValueError): + MixtureSameFamily(Categorical(torch.tensor([0.6, 0.4])), NoRsampleDist()) + + def test_log_cdf_multivariate(self): + # Test with a multivariate component distribution (e.g., Independent(Normal(...),1)) + component_dist = Independent(Normal(loc=torch.zeros(2, 2), scale=torch.ones(2, 2)),1) + mixture_dist = Categorical(torch.tensor([0.6, 0.4])) + mixture = MixtureSameFamily(mixture_dist, component_dist) + x = torch.tensor([[0.5, 1.0], [1.0, 0.5]]) + log_cdf = mixture._log_cdf(x) + # ADD ASSERTION(S) HERE to check the calculated log_cdf values + # self.assertTrue(torch.allclose(log_cdf, torch.tensor([expected_value_1, expected_value_2]), atol=1e-4)) + +class TestBetaDistribution(unittest.TestCase): + def setUp(self): + self.concentration1 = torch.tensor([1.0, 2.0], requires_grad=True) + self.concentration0 = torch.tensor([2.0, 1.0], requires_grad=True) + self.beta = Beta(self.concentration1, self.concentration0) + self._dirichlet = Dirichlet(torch.stack([self.concentration1, self.concentration0], -1)) # Initialize _dirichlet + + def test_init(self): + beta = Beta(torch.tensor(1.0), torch.tensor(2.0)) + self.assertEqual(beta.concentration1, 1.0) + self.assertEqual(beta.concentration0, 2.0) + self.assertEqual(beta._gamma1.concentration, 1.0) #Check Gamma parameters + self.assertEqual(beta._gamma0.concentration, 2.0) + + + def test_properties(self): + self.assertTrue(torch.allclose(self.beta.mean, self.concentration1 / (self.concentration1 + self.concentration0))) + self.assertTrue(torch.allclose(self.beta.mode, (self.concentration1 - 1) / (self.concentration1 + self.concentration0 - 2))) + total = self.concentration1 + self.concentration0 + self.assertTrue(torch.allclose(self.beta.variance, self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1)))) + + def test_expand(self): + expanded_beta = self.beta.expand(torch.Size([3, 2])) + self.assertEqual(expanded_beta.batch_shape, torch.Size([3, 2])) + self.assertTrue(torch.equal(expanded_beta._gamma1.concentration, self.concentration1.expand([3, 2]))) + self.assertTrue(torch.equal(expanded_beta._gamma0.concentration, self.concentration0.expand([3, 2]))) + + def test_rsample(self): + samples = self.beta.rsample(sample_shape=torch.Size([10])) + self.assertEqual(samples.shape, torch.Size([10, 2])) + self.assertTrue(samples.requires_grad) #check grad tracking + + def test_log_prob(self): + value = torch.tensor([0.5, 0.7]) + log_prob = self.beta.log_prob(value) + heads_tails = torch.stack([value, 1.0 - value], -1) + expected_log_prob = self._dirichlet.log_prob(heads_tails) #Use initialized _dirichlet + self.assertTrue(torch.allclose(log_prob, expected_log_prob)) + + def test_entropy(self): + entropy = self.beta.entropy() + expected_entropy = self._dirichlet.entropy() #Use initialized _dirichlet + self.assertTrue(torch.allclose(entropy, expected_entropy)) + + + def test_natural_params(self): + natural_params = self.beta._natural_params + self.assertTrue(torch.equal(natural_params[0], self.concentration1)) + self.assertTrue(torch.equal(natural_params[1], self.concentration0)) + + def test_log_normalizer(self): + x, y = self.beta._natural_params + log_normalizer = 
self.beta._log_normalizer(x, y) + expected_log_normalizer = torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y) + self.assertTrue(torch.allclose(log_normalizer, expected_log_normalizer)) + + + def test_invalid_inputs(self): + with self.assertRaises(ValueError): + Beta(torch.tensor([-1.0, 1.0]), self.concentration0) # Negative concentration1 + with self.assertRaises(ValueError): + Beta(self.concentration1, torch.tensor([-1.0, 1.0])) # Negative concentration0 + with self.assertRaises(ValueError): + self.beta.log_prob(torch.tensor([-0.1, 0.5])) # Value outside [0,1] + with self.assertRaises(ValueError): + self.beta.log_prob(torch.tensor([1.1, 0.5])) + +class TestDirichlet(unittest.TestCase): + def setUp(self): + self.concentration = torch.tensor([1.0, 2.0, 3.0]).requires_grad_(True) + self.dirichlet = Dirichlet(self.concentration) + + def test_init(self): + with self.assertRaises(ValueError): + Dirichlet(torch.tensor([]), validate_args=True) # Not enough dimensions + dirichlet = Dirichlet(self.concentration) + self.assertTrue(torch.equal(dirichlet.concentration, self.concentration)) + self.assertEqual(dirichlet.batch_shape, torch.Size()) + self.assertEqual(dirichlet.event_shape, torch.Size([3])) + + def test_expand(self): + expanded_dirichlet = self.dirichlet.expand(torch.Size([2, 3])) + self.assertEqual(expanded_dirichlet.batch_shape, torch.Size([2, 3])) + self.assertEqual(expanded_dirichlet.event_shape, torch.Size([3])) + self.assertTrue(torch.equal(expanded_dirichlet.concentration, self.concentration.expand(torch.Size([2, 3, 3])))) + + + def test_log_prob(self): + value = torch.tensor([0.2, 0.3, 0.5]) + log_prob = self.dirichlet.log_prob(value) + expected_log_prob = ( + torch.xlogy(self.concentration - 1.0, value).sum(-1) + + torch.lgamma(self.concentration.sum(-1)) + - torch.lgamma(self.concentration).sum(-1) + ) + self.assertTrue(torch.allclose(log_prob, expected_log_prob)) + + def test_mean(self): + mean = self.dirichlet.mean + expected_mean = self.concentration / self.concentration.sum(-1, True) + self.assertTrue(torch.allclose(mean, expected_mean)) + + def test_mode(self): + mode = self.dirichlet.mode + concentrationm1 = (self.concentration - 1).clamp(min=0.0) + expected_mode = concentrationm1 / concentrationm1.sum(-1, True) + self.assertTrue(torch.allclose(mode, expected_mode)) + + def test_variance(self): + variance = self.dirichlet.variance + con0 = self.concentration.sum(-1, True) + expected_variance = ( + self.concentration + * (con0 - self.concentration) + / (con0.pow(2) * (con0 + 1)) + ) + self.assertTrue(torch.allclose(variance, expected_variance)) + + def test_entropy(self): + entropy = self.dirichlet.entropy() + k = self.concentration.size(-1) + a0 = self.concentration.sum(-1) + expected_entropy = ( + torch.lgamma(self.concentration).sum(-1) + - torch.lgamma(a0) + - (k - a0) * torch.digamma(a0) + - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1) + ) + self.assertTrue(torch.allclose(entropy, expected_entropy)) + + + def test_natural_params(self): + natural_params = self.dirichlet._natural_params + self.assertTrue(torch.equal(natural_params[0], self.concentration)) + + def test_log_normalizer(self): + log_normalizer = self.dirichlet._log_normalizer(self.concentration) + expected_log_normalizer = torch.lgamma(self.concentration).sum(-1) - torch.lgamma(self.concentration.sum(-1)) + self.assertTrue(torch.allclose(log_normalizer, expected_log_normalizer)) + + def test_invalid_inputs(self): + with self.assertRaises(ValueError): + 
Dirichlet(torch.tensor([1.0, -1.0, 3.0])) #Negative Concentration + with self.assertRaises(ValueError): + self.dirichlet.log_prob(torch.tensor([0.2, 0.3, 0.6])) #Values don't sum to 1 + + +class TestStudentT(unittest.TestCase): + + def setUp(self): + self.df = torch.tensor([3.0, 5.0]).requires_grad_(True) + self.loc = torch.tensor([1.0, 2.0]).requires_grad_(True) + self.scale = torch.tensor([0.5, 1.0]).requires_grad_(True) + self.studentt = StudentT(self.df, self.loc, self.scale, validate_args=True) + + def test_init(self): + studentt = StudentT(3.0, 1.0, 0.5) + self.assertEqual(studentt.df, 3.0) + self.assertEqual(studentt.loc, 1.0) + self.assertEqual(studentt.scale, 0.5) + self.assertEqual(studentt.gamma.concentration, 1.5) #Check Gamma initialization + self.assertEqual(studentt.gamma.rate, 1.5) + + def test_properties(self): + df = torch.tensor([.3, 2.0]) + loc = torch.tensor([1.0, 2.0]) + scale = torch.tensor([0.5, 1.0]) + studentt = StudentT(df, loc, scale) + self.assertTrue(torch.equal(studentt.mode, studentt.loc)) + # Check mean (undefined for df <= 1) + # print(self.studentt.mean[0]) + self.assertTrue(torch.isnan(studentt.mean[0])) #Testing for nan values + self.assertTrue(torch.allclose(studentt.mean[1], studentt.loc[1])) #Mean should be defined for df > 1 + + # Check variance (undefined for df <= 1, infinite for 1 < df <= 2) + self.assertTrue(torch.isnan(studentt.variance[0])) + self.assertTrue(torch.isinf(studentt.variance[1])) # Should be inf for 1 < df <=2 + self.assertTrue(torch.allclose(studentt.variance[1], (scale[1].pow(2) * df[1] / (df[1] - 2)))) #Should be defined for df > 2 + + + def test_expand(self): + expanded_studentt = self.studentt.expand(torch.Size([2, 2])) + self.assertEqual(expanded_studentt.batch_shape, torch.Size([2, 2])) + self.assertTrue(torch.equal(expanded_studentt.df, self.df.expand([2, 2]))) + self.assertTrue(torch.equal(expanded_studentt.loc, self.loc.expand([2, 2]))) + self.assertTrue(torch.equal(expanded_studentt.scale, self.scale.expand([2, 2]))) + + + def test_log_prob(self): + value = torch.tensor([2.0, 3.0]) + log_prob = self.studentt.log_prob(value) + y = (value - self.loc) / self.scale + Z = ( + self.scale.log() + + 0.5 * self.df.log() + + 0.5 * math.log(math.pi) + + torch.lgamma(0.5 * self.df) + - torch.lgamma(0.5 * (self.df + 1.0)) + ) + expected_log_prob = -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z + self.assertTrue(torch.allclose(log_prob, expected_log_prob)) + + + def test_entropy(self): + entropy = self.studentt.entropy() + lbeta = ( + torch.lgamma(0.5 * self.df) + + math.lgamma(0.5) + - torch.lgamma(0.5 * (self.df + 1)) + ) + expected_entropy = ( + self.scale.log() + + 0.5 + * (self.df + 1) + * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) + + 0.5 * self.df.log() + + lbeta + ) + self.assertTrue(torch.allclose(entropy, expected_entropy)) + + + def test_rsample(self): + samples = self.studentt.rsample(sample_shape=torch.Size([10])) + print(samples.shape) + # print(self.studentt.rsample(sample_shape=torch.Size([2]))) + self.assertEqual(samples.shape, torch.Size([10, 2])) + self.assertTrue(samples.requires_grad) # Check that gradients are tracked + + def test_invalid_inputs(self): + with self.assertRaises(ValueError): + StudentT(torch.tensor([-1.0, 1.0]), self.loc, self.scale) #Negative df + with self.assertRaises(ValueError): + self.studentt.log_prob([1, 2]) + + +if __name__ == "__main__": + unittest.main() diff --git a/images/implicit.webp b/images/implicit.webp new file mode 100644 index 
0000000..06d71c8 Binary files /dev/null and b/images/implicit.webp differ diff --git a/src/irt/.ipynb_checkpoints/distributions-checkpoint.py b/src/irt/.ipynb_checkpoints/distributions-checkpoint.py new file mode 100644 index 0000000..43197d9 --- /dev/null +++ b/src/irt/.ipynb_checkpoints/distributions-checkpoint.py @@ -0,0 +1,1470 @@ +# mypy: allow-untyped-defs +import math +from numbers import Number, Real +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch.autograd.functional import jacobian +from torch.distributions import ( + Bernoulli, + Binomial, + ContinuousBernoulli, + Distribution, + Geometric, + NegativeBinomial, + RelaxedBernoulli, + constraints, +) +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import broadcast_all, lazy_property +from torch.types import _size +from torch.distributions.distribution import Distribution + +default_size = torch.Size() + + +class Beta(ExponentialFamily): + r""" + Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`. + + Example:: + + >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5])) + >>> m.sample() + tensor([0.1046]) + + Args: + concentration1 (float or Tensor): 1st concentration parameter of the distribution + (often referred to as alpha) + concentration0 (float or Tensor): 2nd concentration parameter of the distribution + (often referred to as beta) + """ + + arg_constraints = { + "concentration1": constraints.positive, + "concentration0": constraints.positive, + } + support = constraints.unit_interval + has_rsample = True + + def __init__( + self, concentration1: torch.Tensor, concentration0: torch.Tensor, validate_args: Optional[bool] = None + ) -> None: + """ + Initializes the Beta distribution with the given concentration parameters. + + Args: + concentration1: First concentration parameter (alpha). + concentration0: Second concentration parameter (beta). + validate_args: If True, validates the distribution's parameters. + """ + self.concentration1 = concentration1 + self.concentration0 = concentration0 + self._gamma1 = Gamma(self.concentration1, torch.ones_like(concentration1), validate_args=validate_args) + self._gamma0 = Gamma(self.concentration0, torch.ones_like(concentration0), validate_args=validate_args) + self._dirichlet = Dirichlet(torch.stack([self.concentration1, self.concentration0], -1)) + + super().__init__(self._gamma0._batch_shape, validate_args=validate_args) + + def expand(self, batch_shape: torch.Size, _instance: Optional["Beta"] = None) -> "Beta": + """ + Expands the Beta distribution to a new batch shape. + + Args: + batch_shape: Desired batch shape. + _instance: Instance to validate. + + Returns: + A new Beta distribution instance with expanded parameters. + """ + new = self._get_checked_instance(Beta, _instance) + batch_shape = torch.Size(batch_shape) + new._gamma1 = self._gamma1.expand(batch_shape) + new._gamma0 = self._gamma0.expand(batch_shape) + super(Beta, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @property + def mean(self) -> torch.Tensor: + """ + Computes the mean of the Beta distribution. + + Returns: + torch.Tensor: Mean of the distribution. + """ + return self.concentration1 / (self.concentration1 + self.concentration0) + + @property + def mode(self) -> torch.Tensor: + """ + Computes the mode of the Beta distribution. + + Returns: + torch.Tensor: Mode of the distribution. 
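+
+        Note:
+            The closed-form expression ``(concentration1 - 1) / (concentration1 + concentration0 - 2)``
+            is a valid mode only when both concentration parameters are greater than 1;
+            otherwise the density is maximized on the boundary of the unit interval.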
+ """ + return (self.concentration1 - 1) / (self.concentration1 + self.concentration0 - 2) + + @property + def variance(self) -> torch.Tensor: + """ + Computes the variance of the Beta distribution. + + Returns: + torch.Tensor: Variance of the distribution. + """ + total = self.concentration1 + self.concentration0 + return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1)) + + def rsample(self, sample_shape: _size = ()) -> torch.Tensor: + """ + Generates a reparameterized sample from the Beta distribution. + + Args: + sample_shape (_size): Shape of the sample. + + Returns: + torch.Tensor: Sample from the Beta distribution. + """ + z1 = self._gamma1.rsample(sample_shape) + z0 = self._gamma0.rsample(sample_shape) + return z1 / (z1 + z0) + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + """ + Computes the log probability density of a value under the Beta distribution. + + Args: + value: Value to evaluate. + + Returns: + Log probability of the value. + """ + if self._validate_args: + self._validate_sample(value) + heads_tails = torch.stack([value, 1.0 - value], -1) + return self._dirichlet.log_prob(heads_tails) + + def entropy(self) -> torch.Tensor: + """ + Computes the entropy of the Beta distribution. + + Returns: + Entropy of the distribution. + """ + return self._dirichlet.entropy() + + @property + def _natural_params(self) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Returns the natural parameters of the distribution. + + Returns: + Natural parameters. + """ + return self.concentration1, self.concentration0 + + def _log_normalizer(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Computes the log normalizer for the natural parameters. + + Args: + x: Parameter 1. + y: Parameter 2. + + Returns: + Log normalizer value. + """ + return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y) + + +class Dirichlet(ExponentialFamily): + """ + Dirichlet distribution parameterized by a concentration vector. + + The Dirichlet distribution is a multivariate generalization of the Beta distribution. It + is commonly used in Bayesian statistics, particularly for modeling proportions. + """ + + arg_constraints = {"concentration": constraints.independent(constraints.positive, 1)} + support = constraints.simplex + has_rsample = True + + def __init__(self, concentration: torch.Tensor, validate_args: Optional[bool] = None) -> None: + """ + Initializes the Dirichlet distribution. + + Args: + concentration: Positive concentration parameter vector (alpha). + validate_args: If True, validates the distribution's parameters. + """ + if torch.numel(concentration) < 1: + raise ValueError("`concentration` parameter must be at least one-dimensional.") + self.concentration = concentration + self.gamma = Gamma(self.concentration, torch.ones_like(self.concentration)) + batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + @property + def mean(self) -> torch.Tensor: + """ + Computes the mean of the Dirichlet distribution. + + Returns: + Mean vector, calculated as `concentration / concentration.sum(-1, keepdim=True)`. + """ + return self.concentration / self.concentration.sum(-1, keepdim=True) + + @property + def mode(self) -> torch.Tensor: + """ + Computes the mode of the Dirichlet distribution. + + Note: + - The mode is defined only when all concentration values are > 1. + - For concentrations ≤ 1, the mode vector is clamped to enforce positivity. 
+ + Returns: + Mode vector. + """ + concentration_minus_one = (self.concentration - 1).clamp(min=0.0) + mode = concentration_minus_one / concentration_minus_one.sum(-1, keepdim=True) + mask = (self.concentration < 1).all(dim=-1) + mode[mask] = F.one_hot(mode[mask].argmax(dim=-1), concentration_minus_one.shape[-1]).to(mode) + return mode + + @property + def variance(self) -> torch.Tensor: + """ + Computes the variance of the Dirichlet distribution. + + Returns: + Variance vector for each component. + """ + total_concentration = self.concentration.sum(-1, keepdim=True) + return ( + self.concentration + * (total_concentration - self.concentration) + / (total_concentration.pow(2) * (total_concentration + 1)) + ) + + def rsample(self, sample_shape: _size = ()) -> torch.Tensor: + """ + Generates a reparameterized sample from the Dirichlet distribution. + + Args: + sample_shape (_size): Desired sample shape. + + Returns: + torch.Tensor: A reparameterized sample. + """ + z = self.gamma.rsample(sample_shape) # Sample from underlying Gamma distribution + + return z / torch.sum(z, dim=-1, keepdims=True) + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + """ + Computes the log probability density for a given value. + + Args: + value (torch.Tensor): Value to evaluate the log probability at. + + Returns: + torch.Tensor: Log probability density of the value. + """ + if self._validate_args: + self._validate_sample(value) + return ( + torch.xlogy(self.concentration - 1.0, value).sum(-1) + + torch.lgamma(self.concentration.sum(-1)) + - torch.lgamma(self.concentration).sum(-1) + ) + + def entropy(self) -> torch.Tensor: + """ + Computes the entropy of the Dirichlet distribution. + + Returns: + torch.Tensor: Entropy of the distribution. + """ + k = self.concentration.size(-1) + total_concentration = self.concentration.sum(-1) + return ( + torch.lgamma(self.concentration).sum(-1) + - torch.lgamma(total_concentration) + - (k - total_concentration) * torch.digamma(total_concentration) + - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1) + ) + + def expand(self, batch_shape: torch.Size, _instance: Optional["Dirichlet"] = None) -> "Dirichlet": + """ + Expands the distribution parameters to a new batch shape. + + Args: + batch_shape (torch.Size): Desired batch shape. + _instance (Optional): Instance to validate. + + Returns: + A new Dirichlet distribution instance with expanded parameters. + """ + new = self._get_checked_instance(Dirichlet, _instance) + batch_shape = torch.Size(batch_shape) + new.concentration = self.concentration.expand(batch_shape + self.event_shape) + super(Dirichlet, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @property + def _natural_params(self) -> tuple: + """ + Returns the natural parameters of the distribution. + + Returns: + tuple: Natural parameter tuple `(concentration,)`. + """ + return (self.concentration,) + + def _log_normalizer(self, x: torch.Tensor) -> torch.Tensor: + """ + Computes the log normalizer for the natural parameters. + + Args: + x (torch.Tensor): Natural parameter. + + Returns: + torch.Tensor: Log normalizer value. + """ + return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1)) + + +class StudentT(Distribution): + """ + Student's t-distribution parameterized by degrees of freedom (df), location (loc), and scale (scale). 
+ + This distribution is commonly used for robust statistical modeling, particularly when the data + may have outliers or heavier tails than a Normal distribution. + """ + + arg_constraints = { + "df": constraints.positive, + "loc": constraints.real, + "scale": constraints.positive, + } + support = constraints.real + has_rsample = True + + def __init__( + self, df: torch.Tensor, loc: float = 0.0, scale: float = 1.0, validate_args: Optional[bool] = None + ) -> None: + """ + Initializes the Student's t-distribution. + + Args: + df (torch.Tensor): Degrees of freedom (must be positive). + loc (float or torch.Tensor): Location parameter (default: 0.0). + scale (float or torch.Tensor): Scale parameter (default: 1.0). + validate_args (Optional[bool]): If True, validates distribution parameters. + """ + self.df, self.loc, self.scale = broadcast_all(df, loc, scale) + self.gamma = Gamma(self.df * 0.5, self.df * 0.5) + batch_shape = self.df.size() + super().__init__(batch_shape, validate_args=validate_args) + + @property + def mean(self) -> torch.Tensor: + """ + Computes the mean of the distribution. + + Note: The mean is undefined when `df <= 1`. + + Returns: + torch.Tensor: Mean of the distribution, or NaN for undefined cases. + """ + m = self.loc.clone(memory_format=torch.contiguous_format) + m[self.df <= 1] = float("nan") # Mean is undefined for df <= 1 + return m + + @property + def mode(self) -> torch.Tensor: + """ + Computes the mode of the distribution. + + Returns: + torch.Tensor: Mode of the distribution, which is equal to `loc`. + """ + return self.loc + + @property + def variance(self) -> torch.Tensor: + """ + Computes the variance of the distribution. + + Note: + - Variance is infinite for 1 < df <= 2. + - Variance is undefined (NaN) for df <= 1. + + Returns: + torch.Tensor: Variance of the distribution, or appropriate values for edge cases. + """ + m = self.df.clone(memory_format=torch.contiguous_format) + # Variance for df > 2 + m[self.df > 2] = self.scale[self.df > 2].pow(2) * self.df[self.df > 2] / (self.df[self.df > 2] - 2) + # Infinite variance for 1 < df <= 2 + m[(self.df <= 2) & (self.df > 1)] = float("inf") + # Undefined variance for df <= 1 + m[self.df <= 1] = float("nan") + return m + + def expand(self, batch_shape: torch.Size, _instance: Optional["StudentT"] = None) -> "StudentT": + """ + Expands the distribution parameters to a new batch shape. + + Args: + batch_shape (torch.Size): Desired batch size for the expanded distribution. + _instance (Optional): Instance to validate. + + Returns: + StudentT: A new StudentT distribution with expanded parameters. + """ + new = self._get_checked_instance(StudentT, _instance) + batch_shape = torch.Size(batch_shape) + new.df = self.df.expand(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + super(StudentT, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + """ + Computes the log probability density for a given value. + + Args: + value (torch.Tensor): Value to evaluate the log probability at. + + Returns: + torch.Tensor: Log probability density of the given value. 
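+
+        Note:
+            With ``y = (value - loc) / scale``, the log-density equals
+            ``-0.5 * (df + 1) * log1p(y ** 2 / df) - Z``, where
+            ``Z = log(scale) + 0.5 * log(df) + 0.5 * log(pi) + lgamma(df / 2) - lgamma((df + 1) / 2)``
+            is the log normalization constant.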
+ """ + if self._validate_args: + self._validate_sample(value) + y = (value - self.loc) / self.scale + Z = ( + self.scale.log() + + 0.5 * self.df.log() + + 0.5 * math.log(math.pi) + + torch.lgamma(0.5 * self.df) + - torch.lgamma(0.5 * (self.df + 1.0)) + ) + return -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z + + def entropy(self) -> torch.Tensor: + """ + Computes the entropy of the Student's t-distribution. + + Returns: + torch.Tensor: Entropy of the distribution. + """ + lbeta = torch.lgamma(0.5 * self.df) + math.lgamma(0.5) - torch.lgamma(0.5 * (self.df + 1)) + return ( + self.scale.log() + + 0.5 * (self.df + 1) * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) + + 0.5 * self.df.log() + + lbeta + ) + + def _transform(self, z: torch.Tensor) -> torch.Tensor: + """ + Transforms an input tensor `z` to a standardized form based on the location and scale. + + Args: + z (torch.Tensor): Input tensor to transform. + + Returns: + torch.Tensor: Transformed tensor representing the standardized form. + """ + return (z - self.loc) / self.scale + + def _d_transform_d_z(self) -> torch.Tensor: + """ + Computes the derivative of the transform function with respect to `z`. + + Returns: + torch.Tensor: Reciprocal of the scale, representing the gradient for reparameterization. + """ + return 1 / self.scale + + def rsample(self, sample_shape: _size = default_size) -> torch.Tensor: + """ + Generates a reparameterized sample from the Student's t-distribution. + + Args: + sample_shape (_size): Shape of the sample. + + Returns: + torch.Tensor: Reparameterized sample, enabling gradient tracking. + """ + self.loc = self.loc.expand(self._extended_shape(sample_shape)) + self.scale = self.scale.expand(self._extended_shape(sample_shape)) + + sigma = self.gamma.rsample() + + # Sample from Normal distribution (shape must match after broadcasting) + x = self.loc + self.scale * Normal(0, sigma).rsample(sample_shape) + + transform = self._transform(x.detach()) # Standardize the sample + surrogate_x = -transform / self._d_transform_d_z().detach() # Compute surrogate gradient + + return x + (surrogate_x - surrogate_x.detach()) + + +class Gamma(ExponentialFamily): + """ + Gamma distribution parameterized by `concentration` (shape) and `rate` (inverse scale). + The Gamma distribution is often used to model the time until an event occurs, + and it is a continuous probability distribution defined for non-negative real values. + """ + + arg_constraints = { + "concentration": constraints.positive, + "rate": constraints.positive, + } + support = constraints.nonnegative + has_rsample = True + _mean_carrier_measure = 0 + + def __init__( + self, + concentration: torch.Tensor, + rate: torch.Tensor, + validate_args: Optional[bool] = None, + ) -> None: + """ + Initializes the Gamma distribution. + + Args: + concentration (torch.Tensor): Shape parameter of the distribution (often referred to as alpha). + rate (torch.Tensor): Rate parameter (inverse of scale, often referred to as beta). + validate_args (Optional[bool]): If True, validates the distribution's parameters. + """ + self.concentration, self.rate = broadcast_all(concentration, rate) + if isinstance(concentration, Number) and isinstance(rate, Number): + batch_shape = torch.Size() + else: + batch_shape = self.concentration.size() + super().__init__(batch_shape, validate_args=validate_args) + + @property + def mean(self) -> torch.Tensor: + """ + Computes the mean of the Gamma distribution. 
+ + Returns: + torch.Tensor: Mean of the distribution, calculated as `concentration / rate`. + """ + return self.concentration / self.rate + + @property + def mode(self) -> torch.Tensor: + """ + Computes the mode of the Gamma distribution. + + Note: + - The mode is defined only for `concentration > 1`. For `concentration <= 1`, + the mode is clamped to 0. + + Returns: + torch.Tensor: Mode of the distribution. + """ + return ((self.concentration - 1) / self.rate).clamp(min=0) + + @property + def variance(self) -> torch.Tensor: + """ + Computes the variance of the Gamma distribution. + + Returns: + torch.Tensor: Variance of the distribution, calculated as `concentration / rate^2`. + """ + return self.concentration / self.rate.pow(2) + + def expand(self, batch_shape: torch.Size, _instance: Optional["Gamma"] = None) -> "Gamma": + """ + Expands the distribution parameters to a new batch shape. + + Args: + batch_shape (torch.Size): Desired batch shape. + _instance (Optional): Instance to validate. + + Returns: + Gamma: A new Gamma distribution instance with expanded parameters. + """ + new = self._get_checked_instance(Gamma, _instance) + batch_shape = torch.Size(batch_shape) + new.concentration = self.concentration.expand(batch_shape) + new.rate = self.rate.expand(batch_shape) + super(Gamma, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape: _size = default_size) -> torch.Tensor: + """ + Generates a reparameterized sample from the Gamma distribution. + + Args: + sample_shape (_size): Shape of the sample. + + Returns: + torch.Tensor: A reparameterized sample. + """ + shape = self._extended_shape(sample_shape) + concentration = self.concentration.expand(shape) + rate = self.rate.expand(shape) + + # Generate a sample using the underlying C++ implementation for efficiency + value = torch._standard_gamma(concentration) / rate.detach() + + # Detach u for surrogate computation + u = value.detach() * rate.detach() / rate + value = value + (u - u.detach()) + + # Ensure numerical stability for gradients + value.detach().clamp_(min=torch.finfo(value.dtype).tiny) + return value + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + """ + Computes the log probability density for a given value. + + Args: + value (torch.Tensor): Value to evaluate the log probability at. + + Returns: + torch.Tensor: Log probability density of the given value. + """ + value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device) + if self._validate_args: + self._validate_sample(value) + return ( + torch.xlogy(self.concentration, self.rate) + + torch.xlogy(self.concentration - 1, value) + - self.rate * value + - torch.lgamma(self.concentration) + ) + + def entropy(self) -> torch.Tensor: + """ + Computes the entropy of the Gamma distribution. + + Returns: + torch.Tensor: Entropy of the distribution. + """ + return ( + self.concentration + - torch.log(self.rate) + + torch.lgamma(self.concentration) + + (1.0 - self.concentration) * torch.digamma(self.concentration) + ) + + @property + def _natural_params(self) -> tuple: + """ + Returns the natural parameters of the distribution. + + Returns: + tuple: Tuple of natural parameters `(concentration - 1, -rate)`. + """ + return self.concentration - 1, -self.rate + + def _log_normalizer(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Computes the log normalizer for the natural parameters. + + Args: + x (torch.Tensor): First natural parameter. 
+ y (torch.Tensor): Second natural parameter. + + Returns: + torch.Tensor: Log normalizer value. + """ + return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal()) + + def cdf(self, value: torch.Tensor) -> torch.Tensor: + """ + Computes the cumulative distribution function (CDF) for the Gamma distribution. + + Args: + value (torch.Tensor): Value to evaluate the CDF at. + + Returns: + torch.Tensor: CDF of the given value. + """ + if self._validate_args: + self._validate_sample(value) + return torch.special.gammainc(self.concentration, self.rate * value) + + +class Normal(ExponentialFamily): + """ + Represents the Normal (Gaussian) distribution with specified mean (loc) and standard deviation (scale). + Inherits from PyTorch's ExponentialFamily distribution class. + """ + + has_rsample = True + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.real + + def __init__( + self, + loc: torch.Tensor, + scale: torch.Tensor, + validate_args: Optional[bool] = None, + ) -> None: + """ + Initializes the Normal distribution. + + Args: + loc (torch.Tensor): Mean (location) parameter of the distribution. + scale (torch.Tensor): Standard deviation (scale) parameter of the distribution. + validate_args (Optional[bool]): If True, checks the distribution parameters for validity. + """ + self.loc, self.scale = broadcast_all(loc, scale) + # Determine batch shape based on the type of `loc` and `scale`. + batch_shape = torch.Size() if isinstance(loc, Number) and isinstance(scale, Number) else self.loc.size() + super().__init__(batch_shape, validate_args=validate_args) + + @property + def mean(self) -> torch.Tensor: + """ + Returns the mean of the distribution. + + Returns: + torch.Tensor: The mean (location) parameter `loc`. + """ + return self.loc + + @property + def mode(self) -> torch.Tensor: + """ + Returns the mode of the distribution. + + Returns: + torch.Tensor: The mode (equal to `loc` in a Normal distribution). + """ + return self.loc + + @property + def stddev(self) -> torch.Tensor: + """ + Returns the standard deviation of the distribution. + + Returns: + torch.Tensor: The standard deviation (scale) parameter `scale`. + """ + return self.scale + + @property + def variance(self) -> torch.Tensor: + """ + Returns the variance of the distribution. + + Returns: + torch.Tensor: The variance, computed as `scale ** 2`. + """ + return self.stddev.pow(2) + + def entropy(self) -> torch.Tensor: + """ + Computes the entropy of the distribution. + + Returns: + torch.Tensor: The entropy of the Normal distribution, which is a measure of uncertainty. + """ + return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale) + + def cdf(self, value: torch.Tensor) -> torch.Tensor: + """ + Computes the cumulative distribution function (CDF) of the distribution at a given value. + + Args: + value (torch.Tensor): The value at which to evaluate the CDF. + + Returns: + torch.Tensor: The probability that a random variable from the distribution is less than or equal to `value`. + """ + return 0.5 * (1 + torch.erf((value - self.loc) / (self.scale * math.sqrt(2)))) + + def expand(self, batch_shape: torch.Size, _instance: Optional["Normal"] = None) -> "Normal": + """ + Expands the distribution parameters to a new batch shape. + + Args: + batch_shape (torch.Size): Desired batch size for the expanded distribution. + _instance (Optional): Instance to check for validity. + + Returns: + Normal: A new Normal distribution with parameters expanded to the specified batch shape. 
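+
+        Example::
+
+            >>> normal = Normal(torch.zeros(2), torch.ones(2))
+            >>> normal.expand(torch.Size([3, 2])).batch_shape
+            torch.Size([3, 2])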
+ """ + new = self._get_checked_instance(Normal, _instance) + batch_shape = torch.Size(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + super(Normal, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def icdf(self, value: torch.Tensor) -> torch.Tensor: + """ + Computes the inverse cumulative distribution function (quantile function) at a given value. + + Args: + value (torch.Tensor): The probability value at which to evaluate the inverse CDF. + + Returns: + torch.Tensor: The quantile corresponding to `value`. + """ + return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2) + + def log_prob(self, value: torch.Tensor) -> torch.Tensor: + """ + Computes the log probability density of the distribution at a given value. + + Args: + value (torch.Tensor): The value at which to evaluate the log probability. + + Returns: + torch.Tensor: The log probability density at `value`. + """ + var = self.scale**2 + log_scale = self.scale.log() if not isinstance(self.scale, Real) else math.log(self.scale) + return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi)) + + def _transform(self, z: torch.Tensor) -> torch.Tensor: + """ + Transforms an input tensor `z` to a standardized form based on the mean and scale. + + Args: + z (torch.Tensor): Input tensor to transform. + + Returns: + torch.Tensor: The transformed tensor, representing the standardized normal form. + """ + return (z - self.loc) / self.scale + + def _d_transform_d_z(self) -> torch.Tensor: + """ + Computes the derivative of the transform function with respect to `z`. + + Returns: + torch.Tensor: The reciprocal of the scale, representing the gradient for reparameterization. + """ + return 1 / self.scale + + def sample(self, sample_shape: torch.Size = default_size) -> torch.Tensor: + """ + Generates a sample from the Normal distribution using `torch.normal`. + + Args: + sample_shape (torch.Size): Shape of the sample to generate. + + Returns: + torch.Tensor: A tensor with samples from the Normal distribution, detached from the computation graph. + """ + shape = self._extended_shape(sample_shape) + with torch.no_grad(): + return torch.normal(self.loc.expand(shape), self.scale.expand(shape)) + + def rsample(self, sample_shape: _size = default_size) -> torch.Tensor: + """ + Generates a reparameterized sample from the Normal distribution, enabling gradient backpropagation. + + Returns: + torch.Tensor: A tensor containing a reparameterized sample, useful for gradient-based optimization. + """ + # Sample a point from the distribution + x = self.sample(sample_shape) + # Transform the sample to standard normal form + transform = self._transform(x) + # Compute a surrogate value for backpropagation + surrogate_x = -transform / self._d_transform_d_z().detach() + # Return the sample with gradient tracking enabled + return x + (surrogate_x - surrogate_x.detach()) + + +class MixtureSameFamily(torch.distributions.MixtureSameFamily): + """ + Represents a mixture of distributions from the same family. + Supporting reparameterized sampling for gradient-based optimization. + """ + + has_rsample = True + + def __init__(self, *args, **kwargs) -> None: + """ + Initializes the MixtureSameFamily distribution and checks if the component distributions. + Support reparameterized sampling (required for `rsample`). + + Raises: + ValueError: If the component distributions do not support reparameterized sampling. 
+ """ + super().__init__(*args, **kwargs) + if not self._component_distribution.has_rsample: + raise ValueError("Cannot reparameterize a mixture of non-reparameterizable components.") + + # Define a list of discrete distributions for checking in `_log_cdf` + self.discrete_distributions: List[Distribution] = [ + Bernoulli, + Binomial, + ContinuousBernoulli, + Geometric, + NegativeBinomial, + RelaxedBernoulli, + ] + + def rsample(self, sample_shape: torch.Size = default_size) -> torch.Tensor: + """ + Generates a reparameterized sample from the mixture of distributions. + + This method generates a sample, applies a distributional transformation, + and computes a surrogate sample to enable gradient flow during optimization. + + Args: + sample_shape (torch.Size): The shape of the sample to generate. + + Returns: + torch.Tensor: A reparameterized sample with gradients enabled. + """ + # Generate a sample from the mixture distribution + x = self.sample(sample_shape=sample_shape) + event_size = math.prod(self.event_shape) + + if event_size != 1: + # For multi-dimensional events, use reshaped distributional transformations + def reshaped_dist_trans(input_x: torch.Tensor) -> torch.Tensor: + return torch.reshape(self._distributional_transform(input_x), (-1, event_size)) + + def reshaped_dist_trans_summed(x_2d: torch.Tensor) -> torch.Tensor: + return torch.sum(reshaped_dist_trans(x_2d), dim=0) + + x_2d = x.reshape((-1, event_size)) + transform_2d = reshaped_dist_trans(x) + jac = jacobian(reshaped_dist_trans_summed, x_2d).detach().movedim(1, 0) + surrogate_x_2d = -torch.linalg.solve_triangular(jac.detach(), transform_2d[..., None], upper=False) + surrogate_x = surrogate_x_2d.reshape(x.shape) + else: + # For one-dimensional events, apply the standard distributional transformation + transform = self._distributional_transform(x) + log_prob_x = self.log_prob(x) + + if self._event_ndims > 1: + log_prob_x = log_prob_x.reshape(log_prob_x.shape + (1,) * self._event_ndims) + + surrogate_x = -transform * torch.exp(-log_prob_x.detach()) + + return x + (surrogate_x - surrogate_x.detach()) + + def _distributional_transform(self, x: torch.Tensor) -> torch.Tensor: + """ + Applies a distributional transformation to the input sample `x`, using cumulative + distribution functions (CDFs) and posterior weights. + + Args: + x (torch.Tensor): The input sample to transform. + + Returns: + torch.Tensor: The transformed tensor based on the mixture model's CDFs. 
+ """ + if isinstance(self._component_distribution, torch.distributions.Independent): + univariate_components = self._component_distribution.base_dist + else: + univariate_components = self._component_distribution + + # Expand input tensor and compute log-probabilities in each component + x = self._pad(x) # [S, B, 1, E] + log_prob_x = univariate_components.log_prob(x) # [S, B, K, E] + + event_size = math.prod(self.event_shape) + + if event_size != 1: + # CDF transformation for multi-dimensional events + cumsum_log_prob_x = log_prob_x.reshape(-1, event_size) + cumsum_log_prob_x = torch.cumsum(cumsum_log_prob_x, dim=-1) + cumsum_log_prob_x = cumsum_log_prob_x.roll(shifts=1, dims=-1) + cumsum_log_prob_x[:, 0] = 0 + cumsum_log_prob_x = cumsum_log_prob_x.reshape(log_prob_x.shape) + + logits_mix_prob = self._pad_mixture_dimensions(self._mixture_distribution.logits) + log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x + + component_axis = -self._event_ndims - 1 + cdf_x = univariate_components.cdf(x) + posterior_weights_x = torch.softmax(log_posterior_weights_x, dim=component_axis) + else: + # CDF transformation for one-dimensional events + log_posterior_weights_x = self._mixture_distribution.logits + component_axis = -self._event_ndims - 1 + cdf_x = univariate_components.cdf(x) + posterior_weights_x = torch.softmax(log_posterior_weights_x, dim=-1) + posterior_weights_x = self._pad_mixture_dimensions(posterior_weights_x) + + return torch.sum(posterior_weights_x * cdf_x, dim=component_axis) + + def _log_cdf(self, x: torch.Tensor) -> torch.Tensor: + """ + Computes the logarithm of the cumulative distribution function (CDF) for the mixture distribution. + + Args: + x (torch.Tensor): The input tensor for which to compute the log CDF. + + Returns: + torch.Tensor: The log CDF values. + """ + x = self._pad(x) + if isinstance(self._component_distribution, torch.distributions.Independent): + univariate_components = self._component_distribution.base_dist + else: + univariate_components = self._component_distribution + + if callable(getattr(univariate_components, "_log_cdf", None)): + log_cdf_x = univariate_components._log_cdf(x) + else: + log_cdf_x = torch.log(univariate_components.cdf(x)) + + if isinstance(univariate_components, tuple(self.discrete_distributions)): + log_mix_prob = torch.sigmoid(self._mixture_distribution.logits) + else: + log_mix_prob = F.log_softmax(self._mixture_distribution.logits, dim=-1) + + return torch.logsumexp(log_cdf_x + log_mix_prob, dim=-1) + + +def _eval_poly(y: torch.Tensor, coef: torch.Tensor) -> torch.Tensor: + """ + Evaluate a polynomial at given points. + + Args: + y: Input tensor. + coeffs: Polynomial coefficients. + + Returns: + Evaluated polynomial tensor. 
+ """ + coef = list(coef) + result = coef.pop() + while coef: + result = coef.pop() + y * result + return result + + +_I0_COEF_SMALL = [ + 1.0, + 3.5156229, + 3.0899424, + 1.2067492, + 0.2659732, + 0.360768e-1, + 0.45813e-2, +] +_I0_COEF_LARGE = [ + 0.39894228, + 0.1328592e-1, + 0.225319e-2, + -0.157565e-2, + 0.916281e-2, + -0.2057706e-1, + 0.2635537e-1, + -0.1647633e-1, + 0.392377e-2, +] +_I1_COEF_SMALL = [ + 0.5, + 0.87890594, + 0.51498869, + 0.15084934, + 0.2658733e-1, + 0.301532e-2, + 0.32411e-3, +] +_I1_COEF_LARGE = [ + 0.39894228, + -0.3988024e-1, + -0.362018e-2, + 0.163801e-2, + -0.1031555e-1, + 0.2282967e-1, + -0.2895312e-1, + 0.1787654e-1, + -0.420059e-2, +] + +_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL] +_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE] + + +def _log_modified_bessel_fn(x: torch.Tensor, order: int = 0) -> torch.Tensor: + """ + Compute the logarithm of the modified Bessel function of the first kind. + + Args: + x: Input tensor, must be positive. + order: Order of the Bessel function (0 or 1). + + Returns: + Logarithm of the Bessel function. + """ + if order not in {0, 1}: + raise ValueError("Order must be 0 or 1.") + + # compute small solution + y = x / 3.75 + y = y * y + small = _eval_poly(y, _COEF_SMALL[order]) + if order == 1: + small = x.abs() * small + small = small.log() + + # compute large solution + y = 3.75 / x + large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log() + + result = torch.where(x < 3.75, small, large) + return result + + +@torch.jit.script_if_tracing +def _rejection_sample( + loc: torch.Tensor, + concentration: torch.Tensor, + proposal_r: torch.Tensor, + x: torch.Tensor +) -> torch.Tensor: + """ + Perform rejection sampling for the von Mises distribution. + + Args: + loc: Location parameter. + concentration: Concentration parameter. + proposal_r: Precomputed proposal parameter. + x: Tensor to fill with samples. + + Returns: + Tensor of samples. + """ + done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device) + while not done.all(): + u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device) + u1, u2, u3 = u.unbind() + z = torch.cos(math.pi * u1) + f = (1 + proposal_r * z) / (proposal_r + z) + c = concentration * (proposal_r - f) + accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0) + if accept.any(): + x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x) + done = done | accept + return (x + math.pi + loc) % (2 * math.pi) - math.pi + + +class VonMises(Distribution): + """ + Von Mises distribution class for circular data. + """ + + arg_constraints = { + "loc": constraints.real, + "concentration": constraints.positive, + } + support = constraints.real + has_rsample = True + + def __init__( + self, + loc: torch.Tensor, + concentration: torch.Tensor, + validate_args: bool = None, + ): + self.loc, self.concentration = broadcast_all(loc, concentration) + batch_shape = self.loc.shape + super().__init__(batch_shape, torch.Size(), validate_args) + + @lazy_property + @torch.no_grad() + def _proposal_r(self) -> torch.Tensor: + """ + Compute the proposal parameter for sampling. + """ + kappa = self._concentration + tau = 1 + (1 + 4 * kappa**2).sqrt() + rho = (tau - (2 * tau).sqrt()) / (2 * kappa) + _proposal_r = (1 + rho**2) / (2 * rho) + + # second order Taylor expansion around 0 for small kappa + _proposal_r_taylor = 1 / kappa + kappa + return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r) + + def log_prob(self, value): + """ + Compute the log probability of the given value. 
+ + Args: + value: Tensor of values. + + Returns: + Tensor of log probabilities. + """ + if self._validate_args: + self._validate_sample(value) + log_prob = self.concentration * torch.cos(value - self.loc) + log_prob = log_prob - math.log(2 * math.pi) - _log_modified_bessel_fn(self.concentration, order=0) + return log_prob + + @lazy_property + def _loc(self): + return self.loc.to(torch.double) + + @lazy_property + def _concentration(self): + return self.concentration.to(torch.double) + + @torch.no_grad() + def sample(self, sample_shape=torch.Size()): + """ + The sampling algorithm for the von Mises distribution is based on the + following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the + von Mises distribution." Applied Statistics (1979): 152-157. + + Sampling is always done in double precision internally to avoid a hang + in _rejection_sample() for small values of the concentration, which + starts to happen for single precision around 1e-4 (see issue #88443). + """ + shape = self._extended_shape(sample_shape) + x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device) + return _rejection_sample( + self._loc, self._concentration, self._proposal_r, x + ).to(self.loc.dtype) + + def rsample(self, sample_shape=torch.Size()): + """ + Generate reparameterized samples from the distribution. + """ + shape = self._extended_shape(sample_shape) + samples = _VonMisesSampler.apply(self.concentration, self._proposal_r, shape) + samples = samples + self.loc + + # Map the samples to [-pi, pi]. + return samples - 2. * torch.pi * torch.round(samples / (2. * torch.pi)) + + @property + def mean(self): + """Mean of the distribution.""" + return self.loc + + @property + def variance(self): + """Variance of the distribution.""" + return 1 - ( + _log_modified_bessel_fn(self.concentration, order=1) + - _log_modified_bessel_fn(self.concentration, order=0) + ).exp() + +@torch.jit.script_if_tracing +@torch.no_grad() +def _rejection_rsample(concentration: torch.Tensor, proposal_r: torch.Tensor, shape: torch.Size) -> torch.Tensor: + """ + Perform rejection sampling to draw samples from the von Mises distribution. + + Args: + concentration (torch.Tensor): Concentration parameter (kappa) of the distribution. + proposal_r (torch.Tensor): Proposal distribution parameter. + shape (torch.Size): Desired shape of the samples. + + Returns: + torch.Tensor: Samples from the von Mises distribution. + """ + x = torch.empty(shape, dtype=concentration.dtype, device=concentration.device) + done = torch.zeros(x.shape, dtype=torch.bool, device=concentration.device) + + while not done.all(): + u = torch.rand((3,) + x.shape, dtype=concentration.dtype, device=concentration.device) + u1, u2, u3 = u.unbind() + z = torch.cos(math.pi * u1) + f = (1 + proposal_r * z) / (proposal_r + z) + c = concentration * (proposal_r - f) + accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0) + if accept.any(): + x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x) + done = done | accept + return x + +def cosxm1(x: torch.Tensor) -> torch.Tensor: + """ + Compute cos(x) - 1 using a numerically stable formula. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor, `cos(x) - 1`. 
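+
+    Note:
+        The identity ``cos(x) - 1 = -2 * sin(x / 2)**2`` avoids the catastrophic
+        cancellation that computing ``torch.cos(x) - 1`` directly suffers for small ``x``.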
+ """ + return -2 * torch.square(torch.sin(x / 2.0)) + +class _VonMisesSampler(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + concentration: torch.Tensor, + proposal_r: torch.Tensor, + shape: torch.Size, + ) -> torch.Tensor: + """ + Perform forward sampling using rejection sampling. + + Args: + ctx (torch.autograd.function.FunctionCtx): Context object for saving tensors. + concentration (torch.Tensor): Concentration parameter (kappa). + proposal_r (torch.Tensor): Proposal distribution parameter. + shape (torch.Size): Desired shape of the samples. + + Returns: + torch.Tensor: Samples from the von Mises distribution. + """ + samples = _rejection_rsample(concentration, proposal_r, shape) + ctx.save_for_backward(concentration, proposal_r, samples) + + return samples + + @staticmethod + @torch.autograd.function.once_differentiable + def backward( + ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor + ) -> Tuple[torch.Tensor, None, None]: + """ + Compute gradients for backward pass using implicit reparameterization. + + Args: + ctx (torch.autograd.function.FunctionCtx): Context object containing saved tensors. + grad_output (torch.Tensor): Gradient of the loss with respect to the output. + + Returns: + Tuple[torch.Tensor, None, None]: Gradients with respect to the input tensors. + """ + concentration, proposal_r, samples = ctx.saved_tensors + + num_periods = torch.round(samples / (2. * torch.pi)) + x_mapped = samples - (2. * torch.pi) * num_periods + + ## Parameters from the paper + ck = 10.5 + num_terms = 20 + + ## Compute series and normal approximation + cdf_series, dcdf_dconcentration_series = von_mises_cdf_series(x_mapped, concentration, num_terms) + cdf_normal, dcdf_dconcentration_normal = von_mises_cdf_normal(x_mapped, concentration) + use_series = concentration < ck + cdf = torch.where(use_series, cdf_series, cdf_normal) + num_periods + dcdf_dconcentration = torch.where(use_series, dcdf_dconcentration_series, dcdf_dconcentration_normal) + + ## Compute CDF gradient terms + inv_prob = torch.exp(concentration * cosxm1(samples)) / ( + 2 * math.pi * torch.special.i0e(concentration) + ) + grad_concentration = grad_output*(-dcdf_dconcentration / inv_prob) + + return grad_concentration, None, None + + +def von_mises_cdf_series( + x: torch.Tensor, concentration: torch.Tensor, num_terms: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the CDF of the von Mises distribution using a series approximation. + + Args: + x (torch.Tensor): Input tensor. + concentration (torch.Tensor): Concentration parameter (kappa). + num_terms (int): Number of terms in the series. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: CDF and its gradient with respect to concentration. + """ + vn = torch.zeros_like(x) + dvn_dconcentration = torch.zeros_like(x) + + n = torch.tensor(num_terms, dtype=x.dtype, device=x.device) + rn = torch.zeros_like(x) + drn_dconcentration = torch.zeros_like(x) + + while n > 0: + denominator = 2. * n / concentration + rn + ddenominator_dk = -2. * n / concentration ** 2 + drn_dconcentration + rn = 1. / denominator + drn_dconcentration = -ddenominator_dk / denominator ** 2 + + multiplier = torch.sin(n * x) / n + vn + vn = rn * multiplier + dvn_dconcentration = (drn_dconcentration * multiplier + rn * dvn_dconcentration) + + n -= 1 + + cdf = 0.5 + x / (2. * torch.pi) + vn / torch.pi + dcdf_dconcentration = dvn_dconcentration / torch.pi + + cdf_clipped = torch.clamp(cdf, 0., 1.) + dcdf_dconcentration *= (cdf >= 0.) 
& (cdf <= 1.) + + return cdf_clipped, dcdf_dconcentration + +def cdf_func(concentration: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Approximate the CDF of the von Mises distribution. + + Args: + concentration (torch.Tensor): Concentration parameter (kappa). + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Approximate CDF values. + """ + + # Calculate the z value based on the approximation + z = (torch.sqrt(torch.tensor(2. / torch.pi)) / torch.special.i0e(concentration)) * torch.sin(0.5 * x) + # Apply corrections to z to improve the approximation + z2 = z ** 2 + z3 = z2 * z + z4 = z2 ** 2 + c = 24. * concentration + c1 = 56. + + xi = z - z3 / ( + ((c - 2. * z2 - 16.) / 3.) - + (z4 + (7. / 4.) * z2 + 167. / 2.) / (c - c1 - z2 + 3.) + ) ** 2 + + # Use the standard normal distribution for the approximation + distrib = torch.distributions.Normal( + torch.tensor(0., dtype=x.dtype, device=x.device), + torch.tensor(1., dtype=x.dtype, device=x.device) + ) + + return distrib.cdf(xi) + +def von_mises_cdf_normal( + x: torch.Tensor, concentration: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the CDF of the von Mises distribution using a normal approximation. + + Args: + x (torch.Tensor): Input tensor. + concentration (torch.Tensor): Concentration parameter (kappa). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: CDF and its gradient with respect to concentration. + """ + with torch.enable_grad(): + concentration_ = concentration.detach().clone().requires_grad_(True) + cdf = cdf_func(concentration_, x) + cdf.backward(torch.ones_like(cdf)) # Compute gradients + dcdf_dconcentration = concentration_.grad.clone() # Copy the gradient + # Detach gradients to prevent further autograd tracking + concentration_.grad = None + return cdf, dcdf_dconcentration \ No newline at end of file diff --git a/src/irt/distributions.py b/src/irt/distributions.py index 1511abf..43197d9 100644 --- a/src/irt/distributions.py +++ b/src/irt/distributions.py @@ -17,8 +17,9 @@ constraints, ) from torch.distributions.exp_family import ExponentialFamily -from torch.distributions.utils import broadcast_all +from torch.distributions.utils import broadcast_all, lazy_property from torch.types import _size +from torch.distributions.distribution import Distribution default_size = torch.Size() @@ -1035,3 +1036,435 @@ def _log_cdf(self, x: torch.Tensor) -> torch.Tensor: log_mix_prob = F.log_softmax(self._mixture_distribution.logits, dim=-1) return torch.logsumexp(log_cdf_x + log_mix_prob, dim=-1) + + +def _eval_poly(y: torch.Tensor, coef: torch.Tensor) -> torch.Tensor: + """ + Evaluate a polynomial at given points. + + Args: + y: Input tensor. + coeffs: Polynomial coefficients. + + Returns: + Evaluated polynomial tensor. 
+ """ + coef = list(coef) + result = coef.pop() + while coef: + result = coef.pop() + y * result + return result + + +_I0_COEF_SMALL = [ + 1.0, + 3.5156229, + 3.0899424, + 1.2067492, + 0.2659732, + 0.360768e-1, + 0.45813e-2, +] +_I0_COEF_LARGE = [ + 0.39894228, + 0.1328592e-1, + 0.225319e-2, + -0.157565e-2, + 0.916281e-2, + -0.2057706e-1, + 0.2635537e-1, + -0.1647633e-1, + 0.392377e-2, +] +_I1_COEF_SMALL = [ + 0.5, + 0.87890594, + 0.51498869, + 0.15084934, + 0.2658733e-1, + 0.301532e-2, + 0.32411e-3, +] +_I1_COEF_LARGE = [ + 0.39894228, + -0.3988024e-1, + -0.362018e-2, + 0.163801e-2, + -0.1031555e-1, + 0.2282967e-1, + -0.2895312e-1, + 0.1787654e-1, + -0.420059e-2, +] + +_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL] +_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE] + + +def _log_modified_bessel_fn(x: torch.Tensor, order: int = 0) -> torch.Tensor: + """ + Compute the logarithm of the modified Bessel function of the first kind. + + Args: + x: Input tensor, must be positive. + order: Order of the Bessel function (0 or 1). + + Returns: + Logarithm of the Bessel function. + """ + if order not in {0, 1}: + raise ValueError("Order must be 0 or 1.") + + # compute small solution + y = x / 3.75 + y = y * y + small = _eval_poly(y, _COEF_SMALL[order]) + if order == 1: + small = x.abs() * small + small = small.log() + + # compute large solution + y = 3.75 / x + large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log() + + result = torch.where(x < 3.75, small, large) + return result + + +@torch.jit.script_if_tracing +def _rejection_sample( + loc: torch.Tensor, + concentration: torch.Tensor, + proposal_r: torch.Tensor, + x: torch.Tensor +) -> torch.Tensor: + """ + Perform rejection sampling for the von Mises distribution. + + Args: + loc: Location parameter. + concentration: Concentration parameter. + proposal_r: Precomputed proposal parameter. + x: Tensor to fill with samples. + + Returns: + Tensor of samples. + """ + done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device) + while not done.all(): + u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device) + u1, u2, u3 = u.unbind() + z = torch.cos(math.pi * u1) + f = (1 + proposal_r * z) / (proposal_r + z) + c = concentration * (proposal_r - f) + accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0) + if accept.any(): + x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x) + done = done | accept + return (x + math.pi + loc) % (2 * math.pi) - math.pi + + +class VonMises(Distribution): + """ + Von Mises distribution class for circular data. + """ + + arg_constraints = { + "loc": constraints.real, + "concentration": constraints.positive, + } + support = constraints.real + has_rsample = True + + def __init__( + self, + loc: torch.Tensor, + concentration: torch.Tensor, + validate_args: bool = None, + ): + self.loc, self.concentration = broadcast_all(loc, concentration) + batch_shape = self.loc.shape + super().__init__(batch_shape, torch.Size(), validate_args) + + @lazy_property + @torch.no_grad() + def _proposal_r(self) -> torch.Tensor: + """ + Compute the proposal parameter for sampling. + """ + kappa = self._concentration + tau = 1 + (1 + 4 * kappa**2).sqrt() + rho = (tau - (2 * tau).sqrt()) / (2 * kappa) + _proposal_r = (1 + rho**2) / (2 * rho) + + # second order Taylor expansion around 0 for small kappa + _proposal_r_taylor = 1 / kappa + kappa + return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r) + + def log_prob(self, value): + """ + Compute the log probability of the given value. 
+ + Args: + value: Tensor of values. + + Returns: + Tensor of log probabilities. + """ + if self._validate_args: + self._validate_sample(value) + log_prob = self.concentration * torch.cos(value - self.loc) + log_prob = log_prob - math.log(2 * math.pi) - _log_modified_bessel_fn(self.concentration, order=0) + return log_prob + + @lazy_property + def _loc(self): + return self.loc.to(torch.double) + + @lazy_property + def _concentration(self): + return self.concentration.to(torch.double) + + @torch.no_grad() + def sample(self, sample_shape=torch.Size()): + """ + The sampling algorithm for the von Mises distribution is based on the + following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the + von Mises distribution." Applied Statistics (1979): 152-157. + + Sampling is always done in double precision internally to avoid a hang + in _rejection_sample() for small values of the concentration, which + starts to happen for single precision around 1e-4 (see issue #88443). + """ + shape = self._extended_shape(sample_shape) + x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device) + return _rejection_sample( + self._loc, self._concentration, self._proposal_r, x + ).to(self.loc.dtype) + + def rsample(self, sample_shape=torch.Size()): + """ + Generate reparameterized samples from the distribution. + """ + shape = self._extended_shape(sample_shape) + samples = _VonMisesSampler.apply(self.concentration, self._proposal_r, shape) + samples = samples + self.loc + + # Map the samples to [-pi, pi]. + return samples - 2. * torch.pi * torch.round(samples / (2. * torch.pi)) + + @property + def mean(self): + """Mean of the distribution.""" + return self.loc + + @property + def variance(self): + """Variance of the distribution.""" + return 1 - ( + _log_modified_bessel_fn(self.concentration, order=1) + - _log_modified_bessel_fn(self.concentration, order=0) + ).exp() + +@torch.jit.script_if_tracing +@torch.no_grad() +def _rejection_rsample(concentration: torch.Tensor, proposal_r: torch.Tensor, shape: torch.Size) -> torch.Tensor: + """ + Perform rejection sampling to draw samples from the von Mises distribution. + + Args: + concentration (torch.Tensor): Concentration parameter (kappa) of the distribution. + proposal_r (torch.Tensor): Proposal distribution parameter. + shape (torch.Size): Desired shape of the samples. + + Returns: + torch.Tensor: Samples from the von Mises distribution. + """ + x = torch.empty(shape, dtype=concentration.dtype, device=concentration.device) + done = torch.zeros(x.shape, dtype=torch.bool, device=concentration.device) + + while not done.all(): + u = torch.rand((3,) + x.shape, dtype=concentration.dtype, device=concentration.device) + u1, u2, u3 = u.unbind() + z = torch.cos(math.pi * u1) + f = (1 + proposal_r * z) / (proposal_r + z) + c = concentration * (proposal_r - f) + accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0) + if accept.any(): + x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x) + done = done | accept + return x + +def cosxm1(x: torch.Tensor) -> torch.Tensor: + """ + Compute cos(x) - 1 using a numerically stable formula. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor, `cos(x) - 1`. 
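+
+    Note:
+        This form stays accurate where ``torch.cos(x) - 1`` loses precision: for
+        example, ``cosxm1(torch.tensor(1e-4))`` is approximately ``-5e-9``.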
+ """ + return -2 * torch.square(torch.sin(x / 2.0)) + +class _VonMisesSampler(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + concentration: torch.Tensor, + proposal_r: torch.Tensor, + shape: torch.Size, + ) -> torch.Tensor: + """ + Perform forward sampling using rejection sampling. + + Args: + ctx (torch.autograd.function.FunctionCtx): Context object for saving tensors. + concentration (torch.Tensor): Concentration parameter (kappa). + proposal_r (torch.Tensor): Proposal distribution parameter. + shape (torch.Size): Desired shape of the samples. + + Returns: + torch.Tensor: Samples from the von Mises distribution. + """ + samples = _rejection_rsample(concentration, proposal_r, shape) + ctx.save_for_backward(concentration, proposal_r, samples) + + return samples + + @staticmethod + @torch.autograd.function.once_differentiable + def backward( + ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor + ) -> Tuple[torch.Tensor, None, None]: + """ + Compute gradients for backward pass using implicit reparameterization. + + Args: + ctx (torch.autograd.function.FunctionCtx): Context object containing saved tensors. + grad_output (torch.Tensor): Gradient of the loss with respect to the output. + + Returns: + Tuple[torch.Tensor, None, None]: Gradients with respect to the input tensors. + """ + concentration, proposal_r, samples = ctx.saved_tensors + + num_periods = torch.round(samples / (2. * torch.pi)) + x_mapped = samples - (2. * torch.pi) * num_periods + + ## Parameters from the paper + ck = 10.5 + num_terms = 20 + + ## Compute series and normal approximation + cdf_series, dcdf_dconcentration_series = von_mises_cdf_series(x_mapped, concentration, num_terms) + cdf_normal, dcdf_dconcentration_normal = von_mises_cdf_normal(x_mapped, concentration) + use_series = concentration < ck + cdf = torch.where(use_series, cdf_series, cdf_normal) + num_periods + dcdf_dconcentration = torch.where(use_series, dcdf_dconcentration_series, dcdf_dconcentration_normal) + + ## Compute CDF gradient terms + inv_prob = torch.exp(concentration * cosxm1(samples)) / ( + 2 * math.pi * torch.special.i0e(concentration) + ) + grad_concentration = grad_output*(-dcdf_dconcentration / inv_prob) + + return grad_concentration, None, None + + +def von_mises_cdf_series( + x: torch.Tensor, concentration: torch.Tensor, num_terms: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the CDF of the von Mises distribution using a series approximation. + + Args: + x (torch.Tensor): Input tensor. + concentration (torch.Tensor): Concentration parameter (kappa). + num_terms (int): Number of terms in the series. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: CDF and its gradient with respect to concentration. + """ + vn = torch.zeros_like(x) + dvn_dconcentration = torch.zeros_like(x) + + n = torch.tensor(num_terms, dtype=x.dtype, device=x.device) + rn = torch.zeros_like(x) + drn_dconcentration = torch.zeros_like(x) + + while n > 0: + denominator = 2. * n / concentration + rn + ddenominator_dk = -2. * n / concentration ** 2 + drn_dconcentration + rn = 1. / denominator + drn_dconcentration = -ddenominator_dk / denominator ** 2 + + multiplier = torch.sin(n * x) / n + vn + vn = rn * multiplier + dvn_dconcentration = (drn_dconcentration * multiplier + rn * dvn_dconcentration) + + n -= 1 + + cdf = 0.5 + x / (2. * torch.pi) + vn / torch.pi + dcdf_dconcentration = dvn_dconcentration / torch.pi + + cdf_clipped = torch.clamp(cdf, 0., 1.) + dcdf_dconcentration *= (cdf >= 0.) 
& (cdf <= 1.) + + return cdf_clipped, dcdf_dconcentration + +def cdf_func(concentration: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Approximate the CDF of the von Mises distribution. + + Args: + concentration (torch.Tensor): Concentration parameter (kappa). + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Approximate CDF values. + """ + + # Calculate the z value based on the approximation + z = (torch.sqrt(torch.tensor(2. / torch.pi)) / torch.special.i0e(concentration)) * torch.sin(0.5 * x) + # Apply corrections to z to improve the approximation + z2 = z ** 2 + z3 = z2 * z + z4 = z2 ** 2 + c = 24. * concentration + c1 = 56. + + xi = z - z3 / ( + ((c - 2. * z2 - 16.) / 3.) - + (z4 + (7. / 4.) * z2 + 167. / 2.) / (c - c1 - z2 + 3.) + ) ** 2 + + # Use the standard normal distribution for the approximation + distrib = torch.distributions.Normal( + torch.tensor(0., dtype=x.dtype, device=x.device), + torch.tensor(1., dtype=x.dtype, device=x.device) + ) + + return distrib.cdf(xi) + +def von_mises_cdf_normal( + x: torch.Tensor, concentration: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute the CDF of the von Mises distribution using a normal approximation. + + Args: + x (torch.Tensor): Input tensor. + concentration (torch.Tensor): Concentration parameter (kappa). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: CDF and its gradient with respect to concentration. + """ + with torch.enable_grad(): + concentration_ = concentration.detach().clone().requires_grad_(True) + cdf = cdf_func(concentration_, x) + cdf.backward(torch.ones_like(cdf)) # Compute gradients + dcdf_dconcentration = concentration_.grad.clone() # Copy the gradient + # Detach gradients to prevent further autograd tracking + concentration_.grad = None + return cdf, dcdf_dconcentration \ No newline at end of file