diff --git a/src/irt/distributions.py b/src/irt/distributions.py index 075e72e..f4bcc80 100644 --- a/src/irt/distributions.py +++ b/src/irt/distributions.py @@ -501,14 +501,11 @@ def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor: Returns: torch.Tensor: Reparameterized sample, enabling gradient tracking. """ - loc = self.loc.expand(self._extended_shape(sample_shape)) - scale = self.scale.expand(self._extended_shape(sample_shape)) - # Sample from auxiliary Gamma distribution - sigma = self.gamma.rsample(sample_shape) + sigma = self.gamma.rsample() # Sample from Normal distribution (shape must match after broadcasting) - x = loc + scale * Normal(0, sigma).rsample() + x = loc + scale * Normal(0, sigma).rsample(sample_shape) transform = self._transform(x.detach()) # Standardize the sample surrogate_x = -transform / self._d_transform_d_z().detach() # Compute surrogate gradient