Skip to content

Commit 5595085

Browse files
authored
Update distributions.py
1 parent 1ad8325 commit 5595085

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/irt/distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
505505
self.loc = self.loc.expand(self._extended_shape(sample_shape))
506506
self.scale = self.scale.expand(self._extended_shape(sample_shape))
507507
gamma_samples = Gamma(self.df * 0.5, self.df * 0.5).rsample(sample_shape)
508-
normal_samples = Normal(0., 1.).sample(sample_shape)
508+
normal_samples = Normal(torch.zeros(gamma_samples.shape), torch.ones(gamma_samples.shape)).sample()
509509

510510
# Sample from Normal distribution (shape must match after broadcasting)
511511
x = self.loc.detach() + self.scale.detach() * normal_samples * torch.rsqrt(gamma_samples)

0 commit comments

Comments
 (0)