Skip to content

Commit 1ad8325

Browse files
authored
Update distributions.py
fix Student Distribution
1 parent ce069e0 commit 1ad8325

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

src/irt/distributions.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,6 @@ def __init__(
366366
validate_args (Optional[bool]): If True, validates distribution parameters.
367367
"""
368368
self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
369-
self.gamma = Gamma(self.df * 0.5, self.df * 0.5)
370369
batch_shape = self.df.size()
371370
super().__init__(batch_shape, validate_args=validate_args)
372371

@@ -505,11 +504,11 @@ def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
505504
"""
506505
self.loc = self.loc.expand(self._extended_shape(sample_shape))
507506
self.scale = self.scale.expand(self._extended_shape(sample_shape))
508-
509-
sigma = self.gamma.rsample()
510-
507+
gamma_samples = Gamma(self.df * 0.5, self.df * 0.5).rsample(sample_shape)
508+
normal_samples = Normal(0., 1.).sample(sample_shape)
509+
511510
# Sample from Normal distribution (shape must match after broadcasting)
512-
x = self.loc + self.scale * Normal(0, sigma).rsample(sample_shape)
511+
x = self.loc.detach() + self.scale.detach() * normal_samples * torch.rsqrt(gamma_samples)
513512

514513
transform = self._transform(x.detach()) # Standardize the sample
515514
surrogate_x = -transform / self._d_transform_d_z().detach() # Compute surrogate gradient

0 commit comments

Comments
 (0)