diff --git a/src/irt/distributions.py b/src/irt/distributions.py
index 075e72e..1511abf 100644
--- a/src/irt/distributions.py
+++ b/src/irt/distributions.py
@@ -20,6 +20,8 @@
 from torch.distributions.utils import broadcast_all
 from torch.types import _size
 
+default_size = torch.Size()
+
 
 class Beta(ExponentialFamily):
     r"""
@@ -491,7 +493,7 @@ def _d_transform_d_z(self) -> torch.Tensor:
         """
         return 1 / self.scale
 
-    def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor:
+    def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
         """
         Generates a reparameterized sample from the Student's t-distribution.
 
@@ -501,14 +503,13 @@ def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor:
         Returns:
             torch.Tensor: Reparameterized sample, enabling gradient tracking.
         """
-        loc = self.loc.expand(self._extended_shape(sample_shape))
-        scale = self.scale.expand(self._extended_shape(sample_shape))
+        self.loc = self.loc.expand(self._extended_shape(sample_shape))
+        self.scale = self.scale.expand(self._extended_shape(sample_shape))
 
-        # Sample from auxiliary Gamma distribution
-        sigma = self.gamma.rsample(sample_shape)
+        sigma = self.gamma.rsample()
 
         # Sample from Normal distribution (shape must match after broadcasting)
-        x = loc + scale * Normal(0, sigma).rsample()
+        x = self.loc + self.scale * Normal(0, sigma).rsample(sample_shape)
 
         transform = self._transform(x.detach())  # Standardize the sample
         surrogate_x = -transform / self._d_transform_d_z().detach()  # Compute surrogate gradient
@@ -605,7 +606,7 @@ def expand(self, batch_shape: torch.Size, _instance: Optional["Gamma"] = None) -
         new._validate_args = self._validate_args
         return new
 
-    def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor:
+    def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
         """
         Generates a reparameterized sample from the Gamma distribution.
 
@@ -858,7 +859,7 @@ def _d_transform_d_z(self) -> torch.Tensor:
         """
         return 1 / self.scale
 
-    def sample(self, sample_shape: torch.Size = torch.Size) -> torch.Tensor:
+    def sample(self, sample_shape: torch.Size = default_size) -> torch.Tensor:
         """
         Generates a sample from the Normal distribution using `torch.normal`.
 
@@ -872,7 +873,7 @@ def sample(self, sample_shape: torch.Size = torch.Size) -> torch.Tensor:
         with torch.no_grad():
             return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
 
-    def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor:
+    def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
         """
         Generates a reparameterized sample from the Normal distribution, enabling gradient backpropagation.
 
@@ -919,7 +920,7 @@ def __init__(self, *args, **kwargs) -> None:
             RelaxedBernoulli,
         ]
 
-    def rsample(self, sample_shape: torch.Size = torch.Size) -> torch.Tensor:
+    def rsample(self, sample_shape: torch.Size = default_size) -> torch.Tensor:
         """
         Generates a reparameterized sample from the mixture of distributions.
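
Note on the `default_size` change: the old signatures used `torch.Size` (the class object) as the default `sample_shape`, so calling `rsample()` with no argument passed a type rather than a shape, and any shape arithmetic on it (such as `_extended_shape`) raises. A minimal sketch of the difference, using only stock PyTorch; the concatenation below is an approximation of what `torch.distributions.Distribution._extended_shape` does with the argument:

```python
import torch

bad_default = torch.Size     # the class itself, not a shape
good_default = torch.Size()  # an empty shape: "draw a single, unbatched sample"

print(isinstance(bad_default, tuple))   # False - a type, not a tuple of dims
print(isinstance(good_default, tuple))  # True  - torch.Size subclasses tuple

batch_shape = torch.Size([3])

# Roughly what _extended_shape does: sample_shape + batch_shape (+ event_shape)
try:
    torch.Size(bad_default) + batch_shape
except TypeError as err:
    print("old default:", err)          # 'type' object is not iterable

print("new default:", torch.Size(good_default) + batch_shape)  # torch.Size([3])
```

Since `torch.Size()` is immutable, a per-signature default (as in `torch.distributions` itself) would also be safe; the shared module-level `default_size` constant simply keeps the signatures uniform.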
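For context on the `StudentT.rsample` hunk: the class draws an auxiliary scale from `self.gamma`, reparameterizes a Normal around it, and applies a surrogate-gradient correction via `_transform`. The sketch below is the textbook Gamma-mixture construction of a pathwise Student's t sample, not this repo's implementation (the parameterization of `self.gamma` and the surrogate-gradient step differ); the helper name and its arguments are hypothetical and only illustrate the sampling path the diff reshapes:

```python
import torch
from torch.distributions import Gamma, Normal

def student_t_rsample(df: torch.Tensor,
                      loc: torch.Tensor,
                      scale: torch.Tensor,
                      sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
    """Pathwise Student's t sample via the classic Gamma mixture of Normals.

    Illustration only; irt.distributions.StudentT draws sigma from its own
    self.gamma and uses a surrogate-gradient standardization instead.
    """
    df, loc, scale = torch.broadcast_tensors(df, loc, scale)
    # g ~ Gamma(df/2, rate=df/2); PyTorch's Gamma.rsample is reparameterized.
    g = Gamma(df / 2.0, df / 2.0).rsample(sample_shape)   # sample_shape + loc.shape
    z = Normal(torch.zeros_like(g), torch.ones_like(g)).rsample()
    return loc + scale * z / g.sqrt()

# Example: gradients flow back to loc and scale through the sample.
loc = torch.zeros(2, requires_grad=True)
scale = torch.ones(2, requires_grad=True)
x = student_t_rsample(torch.tensor(5.0), loc, scale, torch.Size([4]))
x.sum().backward()
print(x.shape, loc.grad.shape)   # torch.Size([4, 2]) torch.Size([2])
```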