Skip to content

Commit 04af12a

Browse files
Fixed codestyle in distributions.py
1 parent 3bb83fa commit 04af12a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/irt/distributions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# mypy: allow-untyped-defs
22
import math
33
from numbers import Number, Real
4-
from typing import List, Optional, Self, Tuple
4+
from typing import List, Optional, Tuple
55

66
import torch
77
import torch.nn.functional as F
@@ -63,7 +63,7 @@ def __init__(
6363

6464
super().__init__(self._gamma0._batch_shape, validate_args=validate_args)
6565

66-
def expand(self, batch_shape: torch.Size, _instance: Optional[Self] = None) -> Self:
66+
def expand(self, batch_shape: torch.Size, _instance: Optional["Beta"] = None) -> "Beta":
6767
"""
6868
Expands the Beta distribution to a new batch shape.
6969
@@ -293,7 +293,7 @@ def entropy(self) -> torch.Tensor:
293293
- ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1)
294294
)
295295

296-
def expand(self, batch_shape: torch.Size, _instance: Optional[Self] = None) -> Self:
296+
def expand(self, batch_shape: torch.Size, _instance: Optional["Dirichlet"] = None) -> "Dirichlet":
297297
"""
298298
Expands the distribution parameters to a new batch shape.
299299
@@ -412,7 +412,7 @@ def variance(self) -> torch.Tensor:
412412
m[self.df <= 1] = float("nan")
413413
return m
414414

415-
def expand(self, batch_shape: torch.Size, _instance: Optional[Self] = None) -> Self:
415+
def expand(self, batch_shape: torch.Size, _instance: Optional["StudentT"] = None) -> "StudentT":
416416
"""
417417
Expands the distribution parameters to a new batch shape.
418418
@@ -585,7 +585,7 @@ def variance(self) -> torch.Tensor:
585585
"""
586586
return self.concentration / self.rate.pow(2)
587587

588-
def expand(self, batch_shape: torch.Size, _instance: Optional[Self] = None) -> Self:
588+
def expand(self, batch_shape: torch.Size, _instance: Optional["Gamma"] = None) -> "Gamma":
589589
"""
590590
Expands the distribution parameters to a new batch shape.
591591
@@ -791,7 +791,7 @@ def cdf(self, value: torch.Tensor) -> torch.Tensor:
791791
"""
792792
return 0.5 * (1 + torch.erf((value - self.loc) / (self.scale * math.sqrt(2))))
793793

794-
def expand(self, batch_shape: torch.Size, _instance: Optional[Self] = None) -> Self:
794+
def expand(self, batch_shape: torch.Size, _instance: Optional["Normal"] = None) -> "Normal":
795795
"""
796796
Expands the distribution parameters to a new batch shape.
797797

0 commit comments

Comments
 (0)