Skip to content

Commit a2795f9

Browse files
Add files via upload
1 parent a661563 commit a2795f9

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

src/irt/distributions.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from torch.distributions.utils import broadcast_all
2121
from torch.types import _size
2222

23+
default_size = torch.Size()
24+
2325

2426
class Beta(ExponentialFamily):
2527
r"""
@@ -491,7 +493,7 @@ def _d_transform_d_z(self) -> torch.Tensor:
491493
"""
492494
return 1 / self.scale
493495

494-
def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor:
496+
def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
495497
"""
496498
Generates a reparameterized sample from the Student's t-distribution.
497499
@@ -501,14 +503,13 @@ def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor:
501503
Returns:
502504
torch.Tensor: Reparameterized sample, enabling gradient tracking.
503505
"""
504-
loc = self.loc.expand(self._extended_shape(sample_shape))
505-
scale = self.scale.expand(self._extended_shape(sample_shape))
506+
self.loc = self.loc.expand(self._extended_shape(sample_shape))
507+
self.scale = self.scale.expand(self._extended_shape(sample_shape))
506508

507-
# Sample from auxiliary Gamma distribution
508-
sigma = self.gamma.rsample(sample_shape)
509+
sigma = self.gamma.rsample()
509510

510511
# Sample from Normal distribution (shape must match after broadcasting)
511-
x = loc + scale * Normal(0, sigma).rsample()
512+
x = self.loc + self.scale * Normal(0, sigma).rsample(sample_shape)
512513

513514
transform = self._transform(x.detach()) # Standardize the sample
514515
surrogate_x = -transform / self._d_transform_d_z().detach() # Compute surrogate gradient
@@ -605,7 +606,7 @@ def expand(self, batch_shape: torch.Size, _instance: Optional["Gamma"] = None) -
605606
new._validate_args = self._validate_args
606607
return new
607608

608-
def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor:
609+
def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
609610
"""
610611
Generates a reparameterized sample from the Gamma distribution.
611612
@@ -858,7 +859,7 @@ def _d_transform_d_z(self) -> torch.Tensor:
858859
"""
859860
return 1 / self.scale
860861

861-
def sample(self, sample_shape: torch.Size = torch.Size) -> torch.Tensor:
862+
def sample(self, sample_shape: torch.Size = default_size) -> torch.Tensor:
862863
"""
863864
Generates a sample from the Normal distribution using `torch.normal`.
864865
@@ -872,7 +873,7 @@ def sample(self, sample_shape: torch.Size = torch.Size) -> torch.Tensor:
872873
with torch.no_grad():
873874
return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
874875

875-
def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor:
876+
def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
876877
"""
877878
Generates a reparameterized sample from the Normal distribution, enabling gradient backpropagation.
878879
@@ -919,7 +920,7 @@ def __init__(self, *args, **kwargs) -> None:
919920
RelaxedBernoulli,
920921
]
921922

922-
def rsample(self, sample_shape: torch.Size = torch.Size) -> torch.Tensor:
923+
def rsample(self, sample_shape: torch.Size = default_size) -> torch.Tensor:
923924
"""
924925
Generates a reparameterized sample from the mixture of distributions.
925926

0 commit comments

Comments
 (0)