20
20
from torch .distributions .utils import broadcast_all
21
21
from torch .types import _size
22
22
23
+ default_size = torch .Size ()
24
+
23
25
24
26
class Beta (ExponentialFamily ):
25
27
r"""
@@ -491,7 +493,7 @@ def _d_transform_d_z(self) -> torch.Tensor:
491
493
"""
492
494
return 1 / self .scale
493
495
494
- def rsample (self , sample_shape : _size = torch . Size ) -> torch .Tensor :
496
+ def rsample (self , sample_shape : _size = default_size ) -> torch .Tensor :
495
497
"""
496
498
Generates a reparameterized sample from the Student's t-distribution.
497
499
@@ -501,14 +503,13 @@ def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor:
501
503
Returns:
502
504
torch.Tensor: Reparameterized sample, enabling gradient tracking.
503
505
"""
504
- loc = self .loc .expand (self ._extended_shape (sample_shape ))
505
- scale = self .scale .expand (self ._extended_shape (sample_shape ))
506
+ self . loc = self .loc .expand (self ._extended_shape (sample_shape ))
507
+ self . scale = self .scale .expand (self ._extended_shape (sample_shape ))
506
508
507
- # Sample from auxiliary Gamma distribution
508
- sigma = self .gamma .rsample (sample_shape )
509
+ sigma = self .gamma .rsample ()
509
510
510
511
# Sample from Normal distribution (shape must match after broadcasting)
511
- x = loc + scale * Normal (0 , sigma ).rsample ()
512
+ x = self . loc + self . scale * Normal (0 , sigma ).rsample (sample_shape )
512
513
513
514
transform = self ._transform (x .detach ()) # Standardize the sample
514
515
surrogate_x = - transform / self ._d_transform_d_z ().detach () # Compute surrogate gradient
@@ -605,7 +606,7 @@ def expand(self, batch_shape: torch.Size, _instance: Optional["Gamma"] = None) -
605
606
new ._validate_args = self ._validate_args
606
607
return new
607
608
608
- def rsample (self , sample_shape : _size = torch . Size ) -> torch .Tensor :
609
+ def rsample (self , sample_shape : _size = default_size ) -> torch .Tensor :
609
610
"""
610
611
Generates a reparameterized sample from the Gamma distribution.
611
612
@@ -858,7 +859,7 @@ def _d_transform_d_z(self) -> torch.Tensor:
858
859
"""
859
860
return 1 / self .scale
860
861
861
- def sample (self , sample_shape : torch .Size = torch . Size ) -> torch .Tensor :
862
+ def sample (self , sample_shape : torch .Size = default_size ) -> torch .Tensor :
862
863
"""
863
864
Generates a sample from the Normal distribution using `torch.normal`.
864
865
@@ -872,7 +873,7 @@ def sample(self, sample_shape: torch.Size = torch.Size) -> torch.Tensor:
872
873
with torch .no_grad ():
873
874
return torch .normal (self .loc .expand (shape ), self .scale .expand (shape ))
874
875
875
- def rsample (self , sample_shape : _size = torch . Size ) -> torch .Tensor :
876
+ def rsample (self , sample_shape : _size = default_size ) -> torch .Tensor :
876
877
"""
877
878
Generates a reparameterized sample from the Normal distribution, enabling gradient backpropagation.
878
879
@@ -919,7 +920,7 @@ def __init__(self, *args, **kwargs) -> None:
919
920
RelaxedBernoulli ,
920
921
]
921
922
922
- def rsample (self , sample_shape : torch .Size = torch . Size ) -> torch .Tensor :
923
+ def rsample (self , sample_shape : torch .Size = default_size ) -> torch .Tensor :
923
924
"""
924
925
Generates a reparameterized sample from the mixture of distributions.
925
926
0 commit comments