From a2795f93a93d83b14d385c53ba4b56accdcb2f2e Mon Sep 17 00:00:00 2001 From: NikitinaMaria <71255897+NikitinaMaria@users.noreply.github.com> Date: Mon, 25 Nov 2024 21:18:43 +0300 Subject: [PATCH] Add files via upload --- src/irt/distributions.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/irt/distributions.py b/src/irt/distributions.py index 075e72e..1511abf 100644 --- a/src/irt/distributions.py +++ b/src/irt/distributions.py @@ -20,6 +20,8 @@ from torch.distributions.utils import broadcast_all from torch.types import _size +default_size = torch.Size() + class Beta(ExponentialFamily): r""" @@ -491,7 +493,7 @@ def _d_transform_d_z(self) -> torch.Tensor: """ return 1 / self.scale - def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor: + def rsample(self, sample_shape: _size = default_size) -> torch.Tensor: """ Generates a reparameterized sample from the Student's t-distribution. @@ -501,14 +503,13 @@ def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor: Returns: torch.Tensor: Reparameterized sample, enabling gradient tracking. """ - loc = self.loc.expand(self._extended_shape(sample_shape)) - scale = self.scale.expand(self._extended_shape(sample_shape)) + self.loc = self.loc.expand(self._extended_shape(sample_shape)) + self.scale = self.scale.expand(self._extended_shape(sample_shape)) - # Sample from auxiliary Gamma distribution - sigma = self.gamma.rsample(sample_shape) + sigma = self.gamma.rsample() # Sample from Normal distribution (shape must match after broadcasting) - x = loc + scale * Normal(0, sigma).rsample() + x = self.loc + self.scale * Normal(0, sigma).rsample(sample_shape) transform = self._transform(x.detach()) # Standardize the sample surrogate_x = -transform / self._d_transform_d_z().detach() # Compute surrogate gradient @@ -605,7 +606,7 @@ def expand(self, batch_shape: torch.Size, _instance: Optional["Gamma"] = None) - new._validate_args = self._validate_args return new - def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor: + def rsample(self, sample_shape: _size = default_size) -> torch.Tensor: """ Generates a reparameterized sample from the Gamma distribution. @@ -858,7 +859,7 @@ def _d_transform_d_z(self) -> torch.Tensor: """ return 1 / self.scale - def sample(self, sample_shape: torch.Size = torch.Size) -> torch.Tensor: + def sample(self, sample_shape: torch.Size = default_size) -> torch.Tensor: """ Generates a sample from the Normal distribution using `torch.normal`. @@ -872,7 +873,7 @@ def sample(self, sample_shape: torch.Size = torch.Size) -> torch.Tensor: with torch.no_grad(): return torch.normal(self.loc.expand(shape), self.scale.expand(shape)) - def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor: + def rsample(self, sample_shape: _size = default_size) -> torch.Tensor: """ Generates a reparameterized sample from the Normal distribution, enabling gradient backpropagation. @@ -919,7 +920,7 @@ def __init__(self, *args, **kwargs) -> None: RelaxedBernoulli, ] - def rsample(self, sample_shape: torch.Size = torch.Size) -> torch.Tensor: + def rsample(self, sample_shape: torch.Size = default_size) -> torch.Tensor: """ Generates a reparameterized sample from the mixture of distributions.