diff --git a/.flake8 b/.flake8 index 6b1f535..5121202 100644 --- a/.flake8 +++ b/.flake8 @@ -9,6 +9,12 @@ extend-ignore = N812, B010, ANN101, + D100, + D212, + ANN002, + ANN003, + D205, + D415 max-line-length = 120 max-complexity = 15 docstring-convention = google @@ -23,6 +29,7 @@ ignore-names = X_train, X_control, X, + Z, X_val, X_valid, X_test, diff --git a/.github/workflows/linters.yml b/.github/workflows/linters.yml index fad77a5..abbb01e 100644 --- a/.github/workflows/linters.yml +++ b/.github/workflows/linters.yml @@ -27,10 +27,10 @@ jobs: pip install isort flake8 black - name: Run isort - run: isort --check-only . + run: isort --check-only ./src/irt - name: Run flake8 - run: flake8 . + run: flake8 ./src/irt - name: Run black - run: black --line-length=120 --check --verbose --diff --color . + run: black --line-length=120 --check --verbose --diff --color ./src/irt diff --git a/src/irt/distributions.py b/src/irt/distributions.py index f53d2e9..075e72e 100644 --- a/src/irt/distributions.py +++ b/src/irt/distributions.py @@ -1,20 +1,20 @@ # mypy: allow-untyped-defs import math from numbers import Number, Real -from typing import Optional, List +from typing import List, Optional, Tuple import torch import torch.nn.functional as F from torch.autograd.functional import jacobian from torch.distributions import ( - constraints, - Distribution, Bernoulli, Binomial, ContinuousBernoulli, + Distribution, Geometric, NegativeBinomial, RelaxedBernoulli, + constraints, ) from torch.distributions.exp_family import ExponentialFamily from torch.distributions.utils import broadcast_all @@ -37,6 +37,7 @@ class Beta(ExponentialFamily): concentration0 (float or Tensor): 2nd concentration parameter of the distribution (often referred to as beta) """ + arg_constraints = { "concentration1": constraints.positive, "concentration0": constraints.positive, @@ -44,14 +45,16 @@ class Beta(ExponentialFamily): support = constraints.unit_interval has_rsample = True - def __init__(self, concentration1, concentration0, validate_args=None): + def __init__( + self, concentration1: torch.Tensor, concentration0: torch.Tensor, validate_args: Optional[bool] = None + ) -> None: """ Initializes the Beta distribution with the given concentration parameters. Args: - concentration1 (Tensor): First concentration parameter (alpha). - concentration0 (Tensor): Second concentration parameter (beta). - validate_args (bool): If True, validates the distribution's parameters. + concentration1: First concentration parameter (alpha). + concentration0: Second concentration parameter (beta). + validate_args: If True, validates the distribution's parameters. """ self.concentration1 = concentration1 self.concentration0 = concentration0 @@ -61,16 +64,16 @@ def __init__(self, concentration1, concentration0, validate_args=None): super().__init__(self._gamma0._batch_shape, validate_args=validate_args) - def expand(self, batch_shape, _instance=None): + def expand(self, batch_shape: torch.Size, _instance: Optional["Beta"] = None) -> "Beta": """ Expands the Beta distribution to a new batch shape. Args: - batch_shape (torch.Size): Desired batch shape. - _instance (Optional): Instance to validate. + batch_shape: Desired batch shape. + _instance: Instance to validate. Returns: - Beta: A new Beta distribution instance with expanded parameters. + A new Beta distribution instance with expanded parameters. 
""" new = self._get_checked_instance(Beta, _instance) batch_shape = torch.Size(batch_shape) @@ -81,7 +84,7 @@ def expand(self, batch_shape, _instance=None): return new @property - def mean(self): + def mean(self) -> torch.Tensor: """ Computes the mean of the Beta distribution. @@ -91,7 +94,7 @@ def mean(self): return self.concentration1 / (self.concentration1 + self.concentration0) @property - def mode(self): + def mode(self) -> torch.Tensor: """ Computes the mode of the Beta distribution. @@ -101,7 +104,7 @@ def mode(self): return (self.concentration1 - 1) / (self.concentration1 + self.concentration0 - 2) @property - def variance(self): + def variance(self) -> torch.Tensor: """ Computes the variance of the Beta distribution. @@ -125,53 +128,54 @@ def rsample(self, sample_shape: _size = ()) -> torch.Tensor: z0 = self._gamma0.rsample(sample_shape) return z1 / (z1 + z0) - def log_prob(self, value): + def log_prob(self, value: torch.Tensor) -> torch.Tensor: """ Computes the log probability density of a value under the Beta distribution. Args: - value (torch.Tensor): Value to evaluate. + value: Value to evaluate. Returns: - torch.Tensor: Log probability of the value. + Log probability of the value. """ if self._validate_args: self._validate_sample(value) heads_tails = torch.stack([value, 1.0 - value], -1) return self._dirichlet.log_prob(heads_tails) - def entropy(self): + def entropy(self) -> torch.Tensor: """ Computes the entropy of the Beta distribution. Returns: - torch.Tensor: Entropy of the distribution. + Entropy of the distribution. """ return self._dirichlet.entropy() @property - def _natural_params(self): + def _natural_params(self) -> Tuple[torch.Tensor, torch.Tensor]: """ Returns the natural parameters of the distribution. Returns: - Tuple[torch.Tensor, torch.Tensor]: Natural parameters. + Natural parameters. """ return self.concentration1, self.concentration0 - def _log_normalizer(self, x, y): + def _log_normalizer(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ Computes the log normalizer for the natural parameters. Args: - x (torch.Tensor): Parameter 1. - y (torch.Tensor): Parameter 2. + x: Parameter 1. + y: Parameter 2. Returns: - torch.Tensor: Log normalizer value. + Log normalizer value. """ return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y) + class Dirichlet(ExponentialFamily): """ Dirichlet distribution parameterized by a concentration vector. @@ -179,24 +183,21 @@ class Dirichlet(ExponentialFamily): The Dirichlet distribution is a multivariate generalization of the Beta distribution. It is commonly used in Bayesian statistics, particularly for modeling proportions. """ - arg_constraints = { - "concentration": constraints.independent(constraints.positive, 1) - } + + arg_constraints = {"concentration": constraints.independent(constraints.positive, 1)} support = constraints.simplex has_rsample = True - def __init__(self, concentration: torch.Tensor, validate_args: Optional[bool] = None): + def __init__(self, concentration: torch.Tensor, validate_args: Optional[bool] = None) -> None: """ Initializes the Dirichlet distribution. Args: - concentration (torch.Tensor): Positive concentration parameter vector (alpha). - validate_args (Optional[bool]): If True, validates the distribution's parameters. + concentration: Positive concentration parameter vector (alpha). + validate_args: If True, validates the distribution's parameters. """ if torch.numel(concentration) < 1: - raise ValueError( - "`concentration` parameter must be at least one-dimensional." 
- ) + raise ValueError("`concentration` parameter must be at least one-dimensional.") self.concentration = concentration self.gamma = Gamma(self.concentration, torch.ones_like(self.concentration)) batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:] @@ -208,7 +209,7 @@ def mean(self) -> torch.Tensor: Computes the mean of the Dirichlet distribution. Returns: - torch.Tensor: Mean vector, calculated as `concentration / concentration.sum(-1, keepdim=True)`. + Mean vector, calculated as `concentration / concentration.sum(-1, keepdim=True)`. """ return self.concentration / self.concentration.sum(-1, keepdim=True) @@ -222,14 +223,12 @@ def mode(self) -> torch.Tensor: - For concentrations ≤ 1, the mode vector is clamped to enforce positivity. Returns: - torch.Tensor: Mode vector. + Mode vector. """ concentration_minus_one = (self.concentration - 1).clamp(min=0.0) mode = concentration_minus_one / concentration_minus_one.sum(-1, keepdim=True) mask = (self.concentration < 1).all(dim=-1) - mode[mask] = F.one_hot( - mode[mask].argmax(dim=-1), concentration_minus_one.shape[-1] - ).to(mode) + mode[mask] = F.one_hot(mode[mask].argmax(dim=-1), concentration_minus_one.shape[-1]).to(mode) return mode @property @@ -238,7 +237,7 @@ def variance(self) -> torch.Tensor: Computes the variance of the Dirichlet distribution. Returns: - torch.Tensor: Variance vector for each component. + Variance vector for each component. """ total_concentration = self.concentration.sum(-1, keepdim=True) return ( @@ -258,8 +257,8 @@ def rsample(self, sample_shape: _size = ()) -> torch.Tensor: torch.Tensor: A reparameterized sample. """ z = self.gamma.rsample(sample_shape) # Sample from underlying Gamma distribution - - return z/torch.sum(z, dim=-1, keepdims=True) + + return z / torch.sum(z, dim=-1, keepdim=True) def log_prob(self, value: torch.Tensor) -> torch.Tensor: """ @@ -295,7 +294,7 @@ def entropy(self) -> torch.Tensor: - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1) ) - def expand(self, batch_shape: torch.Size, _instance=None) -> "Dirichlet": + def expand(self, batch_shape: torch.Size, _instance: Optional["Dirichlet"] = None) -> "Dirichlet": """ Expands the distribution parameters to a new batch shape. @@ -304,14 +303,12 @@ def expand(self, batch_shape: torch.Size, _instance=None) -> "Dirichlet": _instance (Optional): Instance to validate. Returns: - Dirichlet: A new Dirichlet distribution instance with expanded parameters. + A new Dirichlet distribution instance with expanded parameters. """ new = self._get_checked_instance(Dirichlet, _instance) batch_shape = torch.Size(batch_shape) new.concentration = self.concentration.expand(batch_shape + self.event_shape) - super(Dirichlet, new).__init__( - batch_shape, self.event_shape, validate_args=False - ) + super(Dirichlet, new).__init__(batch_shape, self.event_shape, validate_args=False) new._validate_args = self._validate_args return new @@ -345,6 +342,7 @@ class StudentT(Distribution): This distribution is commonly used for robust statistical modeling, particularly when the data may have outliers or heavier tails than a Normal distribution.
""" + arg_constraints = { "df": constraints.positive, "loc": constraints.real, @@ -353,7 +351,9 @@ class StudentT(Distribution): support = constraints.real has_rsample = True - def __init__(self, df: torch.Tensor, loc: float = 0.0, scale: float = 1.0, validate_args: Optional[bool] = None): + def __init__( + self, df: torch.Tensor, loc: float = 0.0, scale: float = 1.0, validate_args: Optional[bool] = None + ) -> None: """ Initializes the Student's t-distribution. @@ -379,7 +379,7 @@ def mean(self) -> torch.Tensor: torch.Tensor: Mean of the distribution, or NaN for undefined cases. """ m = self.loc.clone(memory_format=torch.contiguous_format) - m[self.df <= 1] = float('nan') # Mean is undefined for df <= 1 + m[self.df <= 1] = float("nan") # Mean is undefined for df <= 1 return m @property @@ -406,14 +406,14 @@ def variance(self) -> torch.Tensor: """ m = self.df.clone(memory_format=torch.contiguous_format) # Variance for df > 2 - m[self.df > 2] = (self.scale[self.df > 2].pow(2) * self.df[self.df > 2] / (self.df[self.df > 2] - 2)) + m[self.df > 2] = self.scale[self.df > 2].pow(2) * self.df[self.df > 2] / (self.df[self.df > 2] - 2) # Infinite variance for 1 < df <= 2 - m[(self.df <= 2) & (self.df > 1)] = float('inf') + m[(self.df <= 2) & (self.df > 1)] = float("inf") # Undefined variance for df <= 1 - m[self.df <= 1] = float('nan') + m[self.df <= 1] = float("nan") return m - def expand(self, batch_shape: torch.Size, _instance=None) -> "StudentT": + def expand(self, batch_shape: torch.Size, _instance: Optional["StudentT"] = None) -> "StudentT": """ Expands the distribution parameters to a new batch shape. @@ -462,16 +462,10 @@ def entropy(self) -> torch.Tensor: Returns: torch.Tensor: Entropy of the distribution. """ - lbeta = ( - torch.lgamma(0.5 * self.df) - + math.lgamma(0.5) - - torch.lgamma(0.5 * (self.df + 1)) - ) + lbeta = torch.lgamma(0.5 * self.df) + math.lgamma(0.5) - torch.lgamma(0.5 * (self.df + 1)) return ( self.scale.log() - + 0.5 - * (self.df + 1) - * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) + + 0.5 * (self.df + 1) * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) + 0.5 * self.df.log() + lbeta ) @@ -497,7 +491,7 @@ def _d_transform_d_z(self) -> torch.Tensor: """ return 1 / self.scale - def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor: """ Generates a reparameterized sample from the Student's t-distribution. @@ -505,7 +499,7 @@ def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: sample_shape (_size): Shape of the sample. Returns: - torch.Tensor: Reparameterized sample, enabling gradient tracking. + torch.Tensor: Reparameterized sample, enabling gradient tracking. """ loc = self.loc.expand(self._extended_shape(sample_shape)) scale = self.scale.expand(self._extended_shape(sample_shape)) @@ -525,10 +519,10 @@ def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: class Gamma(ExponentialFamily): """ Gamma distribution parameterized by `concentration` (shape) and `rate` (inverse scale). - The Gamma distribution is often used to model the time until an event occurs, and it is a continuous probability distribution defined for non-negative real values. 
""" + arg_constraints = { "concentration": constraints.positive, "rate": constraints.positive, @@ -542,7 +536,7 @@ def __init__( concentration: torch.Tensor, rate: torch.Tensor, validate_args: Optional[bool] = None, - ): + ) -> None: """ Initializes the Gamma distribution. @@ -592,7 +586,7 @@ def variance(self) -> torch.Tensor: """ return self.concentration / self.rate.pow(2) - def expand(self, batch_shape: torch.Size, _instance=None) -> "Gamma": + def expand(self, batch_shape: torch.Size, _instance: Optional["Gamma"] = None) -> "Gamma": """ Expands the distribution parameters to a new batch shape. @@ -611,7 +605,7 @@ def expand(self, batch_shape: torch.Size, _instance=None) -> "Gamma": new._validate_args = self._validate_args return new - def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor: """ Generates a reparameterized sample from the Gamma distribution. @@ -627,7 +621,7 @@ def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: # Generate a sample using the underlying C++ implementation for efficiency value = torch._standard_gamma(concentration) / rate.detach() - + # Detach u for surrogate computation u = value.detach() * rate.detach() / rate value = value + (u - u.detach()) @@ -713,7 +707,7 @@ class Normal(ExponentialFamily): Represents the Normal (Gaussian) distribution with specified mean (loc) and standard deviation (scale). Inherits from PyTorch's ExponentialFamily distribution class. """ - + has_rsample = True arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.real @@ -741,7 +735,7 @@ def __init__( def mean(self) -> torch.Tensor: """ Returns the mean of the distribution. - + Returns: torch.Tensor: The mean (location) parameter `loc`. """ @@ -751,7 +745,7 @@ def mean(self) -> torch.Tensor: def mode(self) -> torch.Tensor: """ Returns the mode of the distribution. - + Returns: torch.Tensor: The mode (equal to `loc` in a Normal distribution). """ @@ -761,7 +755,7 @@ def mode(self) -> torch.Tensor: def stddev(self) -> torch.Tensor: """ Returns the standard deviation of the distribution. - + Returns: torch.Tensor: The standard deviation (scale) parameter `scale`. """ @@ -771,7 +765,7 @@ def stddev(self) -> torch.Tensor: def variance(self) -> torch.Tensor: """ Returns the variance of the distribution. - + Returns: torch.Tensor: The variance, computed as `scale ** 2`. """ @@ -780,7 +774,7 @@ def variance(self) -> torch.Tensor: def entropy(self) -> torch.Tensor: """ Computes the entropy of the distribution. - + Returns: torch.Tensor: The entropy of the Normal distribution, which is a measure of uncertainty. """ @@ -792,13 +786,13 @@ def cdf(self, value: torch.Tensor) -> torch.Tensor: Args: value (torch.Tensor): The value at which to evaluate the CDF. - + Returns: torch.Tensor: The probability that a random variable from the distribution is less than or equal to `value`. """ return 0.5 * (1 + torch.erf((value - self.loc) / (self.scale * math.sqrt(2)))) - def expand(self, batch_shape: torch.Size, _instance=None) -> "Normal": + def expand(self, batch_shape: torch.Size, _instance: Optional["Normal"] = None) -> "Normal": """ Expands the distribution parameters to a new batch shape. @@ -823,7 +817,7 @@ def icdf(self, value: torch.Tensor) -> torch.Tensor: Args: value (torch.Tensor): The probability value at which to evaluate the inverse CDF. - + Returns: torch.Tensor: The quantile corresponding to `value`. 
""" @@ -835,7 +829,7 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor: Args: value (torch.Tensor): The value at which to evaluate the log probability. - + Returns: torch.Tensor: The log probability density at `value`. """ @@ -849,7 +843,7 @@ def _transform(self, z: torch.Tensor) -> torch.Tensor: Args: z (torch.Tensor): Input tensor to transform. - + Returns: torch.Tensor: The transformed tensor, representing the standardized normal form. """ @@ -864,21 +858,21 @@ def _d_transform_d_z(self) -> torch.Tensor: """ return 1 / self.scale - def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: + def sample(self, sample_shape: torch.Size = torch.Size) -> torch.Tensor: """ Generates a sample from the Normal distribution using `torch.normal`. - + Args: sample_shape (torch.Size): Shape of the sample to generate. - + Returns: torch.Tensor: A tensor with samples from the Normal distribution, detached from the computation graph. """ shape = self._extended_shape(sample_shape) with torch.no_grad(): return torch.normal(self.loc.expand(shape), self.scale.expand(shape)) - - def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: + + def rsample(self, sample_shape: _size = torch.Size) -> torch.Tensor: """ Generates a reparameterized sample from the Normal distribution, enabling gradient backpropagation. @@ -893,20 +887,20 @@ def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor: surrogate_x = -transform / self._d_transform_d_z().detach() # Return the sample with gradient tracking enabled return x + (surrogate_x - surrogate_x.detach()) - + class MixtureSameFamily(torch.distributions.MixtureSameFamily): """ - MixtureSameFamily is a class that represents a mixture of distributions - from the same family, supporting reparameterized sampling for gradient-based optimization. + Represents a mixture of distributions from the same family. + Supporting reparameterized sampling for gradient-based optimization. """ has_rsample = True def __init__(self, *args, **kwargs) -> None: """ - Initializes the MixtureSameFamily distribution and checks if the component distributions - support reparameterized sampling (required for `rsample`). + Initializes the MixtureSameFamily distribution and checks if the component distributions. + Support reparameterized sampling (required for `rsample`). Raises: ValueError: If the component distributions do not support reparameterized sampling. @@ -925,7 +919,7 @@ def __init__(self, *args, **kwargs) -> None: RelaxedBernoulli, ] - def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: + def rsample(self, sample_shape: torch.Size = torch.Size) -> torch.Tensor: """ Generates a reparameterized sample from the mixture of distributions. @@ -1028,7 +1022,7 @@ def _log_cdf(self, x: torch.Tensor) -> torch.Tensor: univariate_components = self._component_distribution.base_dist else: univariate_components = self._component_distribution - + if callable(getattr(univariate_components, "_log_cdf", None)): log_cdf_x = univariate_components._log_cdf(x) else: