|
1 | 1 | # mypy: allow-untyped-defs
|
2 | 2 | import math
|
3 | 3 | from numbers import Number, Real
|
4 |
| -from typing import List, Optional, Self, Tuple |
| 4 | +from typing import List, Optional, Tuple |
5 | 5 |
|
6 | 6 | import torch
|
7 | 7 | import torch.nn.functional as F
|
@@ -63,7 +63,7 @@ def __init__(
|
63 | 63 |
|
64 | 64 | super().__init__(self._gamma0._batch_shape, validate_args=validate_args)
|
65 | 65 |
|
66 |
| - def expand(self, batch_shape: torch.Size, _instance: Optional[Self] = None) -> Self: |
| 66 | + def expand(self, batch_shape: torch.Size, _instance: Optional["Beta"] = None) -> "Beta": |
67 | 67 | """
|
68 | 68 | Expands the Beta distribution to a new batch shape.
|
69 | 69 |
|
@@ -293,7 +293,7 @@ def entropy(self) -> torch.Tensor:
|
293 | 293 | - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1)
|
294 | 294 | )
|
295 | 295 |
|
296 |
| - def expand(self, batch_shape: torch.Size, _instance: Optional[Self] = None) -> Self: |
| 296 | + def expand(self, batch_shape: torch.Size, _instance: Optional["Dirichlet"] = None) -> "Dirichlet": |
297 | 297 | """
|
298 | 298 | Expands the distribution parameters to a new batch shape.
|
299 | 299 |
|
@@ -412,7 +412,7 @@ def variance(self) -> torch.Tensor:
|
412 | 412 | m[self.df <= 1] = float("nan")
|
413 | 413 | return m
|
414 | 414 |
|
415 |
| - def expand(self, batch_shape: torch.Size, _instance: Optional[Self] = None) -> Self: |
| 415 | + def expand(self, batch_shape: torch.Size, _instance: Optional["StudentT"] = None) -> "StudentT": |
416 | 416 | """
|
417 | 417 | Expands the distribution parameters to a new batch shape.
|
418 | 418 |
|
@@ -585,7 +585,7 @@ def variance(self) -> torch.Tensor:
|
585 | 585 | """
|
586 | 586 | return self.concentration / self.rate.pow(2)
|
587 | 587 |
|
588 |
| - def expand(self, batch_shape: torch.Size, _instance: Optional[Self] = None) -> Self: |
| 588 | + def expand(self, batch_shape: torch.Size, _instance: Optional["Gamma"] = None) -> "Gamma": |
589 | 589 | """
|
590 | 590 | Expands the distribution parameters to a new batch shape.
|
591 | 591 |
|
@@ -791,7 +791,7 @@ def cdf(self, value: torch.Tensor) -> torch.Tensor:
|
791 | 791 | """
|
792 | 792 | return 0.5 * (1 + torch.erf((value - self.loc) / (self.scale * math.sqrt(2))))
|
793 | 793 |
|
794 |
| - def expand(self, batch_shape: torch.Size, _instance: Optional[Self] = None) -> Self: |
| 794 | + def expand(self, batch_shape: torch.Size, _instance: Optional["Normal"] = None) -> "Normal": |
795 | 795 | """
|
796 | 796 | Expands the distribution parameters to a new batch shape.
|
797 | 797 |
|
|
0 commit comments