Commit fef046e

Fixed code-style bugs

1 parent ca829ba, commit fef046e

File tree: 1 file changed (+78, -85 lines)

src/irt/distributions.py (78 additions, 85 deletions)
```diff
@@ -19,7 +19,6 @@
 from torch.distributions.exp_family import ExponentialFamily
 from torch.distributions.utils import broadcast_all, lazy_property
 from torch.types import _size
-from torch.distributions.distribution import Distribution
 
 default_size = torch.Size()
 
```
```diff
@@ -1041,11 +1040,11 @@ def _log_cdf(self, x: torch.Tensor) -> torch.Tensor:
 def _eval_poly(y: torch.Tensor, coef: torch.Tensor) -> torch.Tensor:
     """
     Evaluate a polynomial at given points.
-
+
     Args:
         y: Input tensor.
         coeffs: Polynomial coefficients.
-
+
     Returns:
         Evaluated polynomial tensor.
     """
```
```diff
@@ -1108,7 +1107,7 @@ def _log_modified_bessel_fn(x: torch.Tensor, order: int = 0) -> torch.Tensor:
     Args:
         x: Input tensor, must be positive.
         order: Order of the Bessel function (0 or 1).
-
+
     Returns:
         Logarithm of the Bessel function.
     """
```
```diff
@@ -1133,20 +1132,17 @@ def _log_modified_bessel_fn(x: torch.Tensor, order: int = 0) -> torch.Tensor:
 
 @torch.jit.script_if_tracing
 def _rejection_sample(
-    loc: torch.Tensor,
-    concentration: torch.Tensor,
-    proposal_r: torch.Tensor,
-    x: torch.Tensor
+    loc: torch.Tensor, concentration: torch.Tensor, proposal_r: torch.Tensor, x: torch.Tensor
 ) -> torch.Tensor:
     """
     Perform rejection sampling for the von Mises distribution.
-
+
     Args:
         loc: Location parameter.
         concentration: Concentration parameter.
         proposal_r: Precomputed proposal parameter.
         x: Tensor to fill with samples.
-
+
     Returns:
         Tensor of samples.
     """
```
```diff
@@ -1165,9 +1161,7 @@ def _rejection_sample(
 
 
 class VonMises(Distribution):
-    """
-    Von Mises distribution class for circular data.
-    """
+    """Von Mises distribution class for circular data."""
 
     arg_constraints = {
         "loc": constraints.real,
```
```diff
@@ -1181,33 +1175,37 @@ def __init__(
         loc: torch.Tensor,
         concentration: torch.Tensor,
         validate_args: bool = None,
-    ):
+    ) -> None:
+        """
+        Args:
+            loc: loc parameter of the distribution.
+            concentration: concentration parameter of the distribution.
+            validate_args: If True, checks the distribution parameters for validity.
+        """
         self.loc, self.concentration = broadcast_all(loc, concentration)
         batch_shape = self.loc.shape
         super().__init__(batch_shape, torch.Size(), validate_args)
-
+
     @lazy_property
     @torch.no_grad()
     def _proposal_r(self) -> torch.Tensor:
-        """
-        Compute the proposal parameter for sampling.
-        """
+        """Compute the proposal parameter for sampling."""
         kappa = self._concentration
         tau = 1 + (1 + 4 * kappa**2).sqrt()
         rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
         _proposal_r = (1 + rho**2) / (2 * rho)
-
+
         # second order Taylor expansion around 0 for small kappa
         _proposal_r_taylor = 1 / kappa + kappa
         return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r)
 
-    def log_prob(self, value):
+    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
         """
         Compute the log probability of the given value.
 
         Args:
             value: Tensor of values.
-
+
         Returns:
             Tensor of log probabilities.
         """
```
```diff
@@ -1218,15 +1216,15 @@ def log_prob(self, value):
         return log_prob
 
     @lazy_property
-    def _loc(self):
+    def _loc(self) -> torch.Tensor:
         return self.loc.to(torch.double)
 
     @lazy_property
-    def _concentration(self):
+    def _concentration(self) -> torch.Tensor:
         return self.concentration.to(torch.double)
-
+
     @torch.no_grad()
-    def sample(self, sample_shape=torch.Size()):
+    def sample(self, sample_shape: _size = default_size) -> torch.Tensor:
         """
         The sampling algorithm for the von Mises distribution is based on the
         following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the
```
```diff
@@ -1238,33 +1236,33 @@ def sample(self, sample_shape=torch.Size()):
         """
         shape = self._extended_shape(sample_shape)
         x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
-        return _rejection_sample(
-            self._loc, self._concentration, self._proposal_r, x
-        ).to(self.loc.dtype)
+        return _rejection_sample(self._loc, self._concentration, self._proposal_r, x).to(self.loc.dtype)
 
-    def rsample(self, sample_shape=torch.Size()):
-        """
-        Generate reparameterized samples from the distribution.
-        """
+    def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
+        """Generate reparameterized samples from the distribution"""
         shape = self._extended_shape(sample_shape)
         samples = _VonMisesSampler.apply(self.concentration, self._proposal_r, shape)
         samples = samples + self.loc
-
+
         # Map the samples to [-pi, pi].
-        return samples - 2. * torch.pi * torch.round(samples / (2. * torch.pi))
+        return samples - 2.0 * torch.pi * torch.round(samples / (2.0 * torch.pi))
 
     @property
-    def mean(self):
+    def mean(self) -> torch.Tensor:
         """Mean of the distribution."""
         return self.loc
 
     @property
-    def variance(self):
+    def variance(self) -> torch.Tensor:
         """Variance of the distribution."""
-        return 1 - (
-            _log_modified_bessel_fn(self.concentration, order=1)
-            - _log_modified_bessel_fn(self.concentration, order=0)
-        ).exp()
+        return (
+            1
+            - (
+                _log_modified_bessel_fn(self.concentration, order=1)
+                - _log_modified_bessel_fn(self.concentration, order=0)
+            ).exp()
+        )
+
 
 @torch.jit.script_if_tracing
 @torch.no_grad()
```
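The `variance` property is the circular variance 1 - I1(k)/I0(k), expressed above through a difference of log-Bessel values. It can be sanity-checked directly against the scaled Bessel ratio (the exp(-k) scaling of `i1e`/`i0e` cancels) and a Monte Carlo estimate:

```python
import torch
from torch.distributions import VonMises

kappa = torch.tensor(2.0)
analytic = 1 - torch.special.i1e(kappa) / torch.special.i0e(kappa)
empirical = 1 - torch.cos(VonMises(torch.tensor(0.0), kappa).sample((200_000,))).mean()
print(analytic.item(), empirical.item())  # both approximately 0.302
```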
```diff
@@ -1282,7 +1280,7 @@ def _rejection_rsample(concentration: torch.Tensor, proposal_r: torch.Tensor, sh
     """
     x = torch.empty(shape, dtype=concentration.dtype, device=concentration.device)
     done = torch.zeros(x.shape, dtype=torch.bool, device=concentration.device)
-
+
     while not done.all():
         u = torch.rand((3,) + x.shape, dtype=concentration.dtype, device=concentration.device)
         u1, u2, u3 = u.unbind()
```
```diff
@@ -1295,6 +1293,7 @@ def _rejection_rsample(
         done = done | accept
     return x
 
+
 def cosxm1(x: torch.Tensor) -> torch.Tensor:
     """
     Compute cos(x) - 1 using a numerically stable formula.
```
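The point of `cosxm1` is to avoid the catastrophic cancellation in `cos(x) - 1` for small `x`; in float32 the naive form loses every significant digit:

```python
import torch

x = torch.tensor([1e-4], dtype=torch.float32)
naive = torch.cos(x) - 1.0                        # rounds to exactly 0.0 in float32
stable = -2.0 * torch.square(torch.sin(x / 2.0))  # keeps the true value, about -5e-9
print(naive.item(), stable.item())
```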
```diff
@@ -1307,6 +1306,7 @@ def cosxm1(x: torch.Tensor) -> torch.Tensor:
     """
     return -2 * torch.square(torch.sin(x / 2.0))
 
+
 class _VonMisesSampler(torch.autograd.Function):
     @staticmethod
     def forward(
```
```diff
@@ -1329,7 +1329,7 @@ def forward(
         """
         samples = _rejection_rsample(concentration, proposal_r, shape)
         ctx.save_for_backward(concentration, proposal_r, samples)
-
+
         return samples
 
     @staticmethod
```
```diff
@@ -1348,29 +1348,27 @@ def backward(
             Tuple[torch.Tensor, None, None]: Gradients with respect to the input tensors.
         """
         concentration, proposal_r, samples = ctx.saved_tensors
-
-        num_periods = torch.round(samples / (2. * torch.pi))
-        x_mapped = samples - (2. * torch.pi) * num_periods
-
-        ## Parameters from the paper
+
+        num_periods = torch.round(samples / (2.0 * torch.pi))
+        x_mapped = samples - (2.0 * torch.pi) * num_periods
+
+        # Parameters from the paper
         ck = 10.5
         num_terms = 20
-
-        ## Compute series and normal approximation
+
+        # Compute series and normal approximation
         cdf_series, dcdf_dconcentration_series = von_mises_cdf_series(x_mapped, concentration, num_terms)
         cdf_normal, dcdf_dconcentration_normal = von_mises_cdf_normal(x_mapped, concentration)
         use_series = concentration < ck
-        cdf = torch.where(use_series, cdf_series, cdf_normal) + num_periods
+        # cdf = torch.where(use_series, cdf_series, cdf_normal) + num_periods
         dcdf_dconcentration = torch.where(use_series, dcdf_dconcentration_series, dcdf_dconcentration_normal)
-
-        ## Compute CDF gradient terms
-        inv_prob = torch.exp(concentration * cosxm1(samples)) / (
-            2 * math.pi * torch.special.i0e(concentration)
-        )
-        grad_concentration = grad_output*(-dcdf_dconcentration / inv_prob)
-
+
+        # Compute CDF gradient terms
+        inv_prob = torch.exp(concentration * cosxm1(samples)) / (2 * math.pi * torch.special.i0e(concentration))
+        grad_concentration = grad_output * (-dcdf_dconcentration / inv_prob)
+
         return grad_concentration, None, None
-
+
 
 def von_mises_cdf_series(
     x: torch.Tensor, concentration: torch.Tensor, num_terms: int
```
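`backward` is an instance of implicit reparameterization (Figurnov et al., 2018): for x ~ p(x; theta) with CDF F, the sample gradient is dx/dtheta = -(dF/dtheta) / p(x), which is what the `-dcdf_dconcentration / inv_prob` term computes. The identity is easy to verify on a distribution where everything is closed-form, for example an exponential:

```python
import torch

# For Exponential(rate): F(x) = 1 - exp(-rate*x) and p(x) = rate*exp(-rate*x).
rate = torch.tensor(2.0, requires_grad=True)
u = torch.rand(100_000)

x = -torch.log1p(-u) / rate                  # explicit reparameterization: x = F^{-1}(u)
explicit_grad, = torch.autograd.grad(x.sum(), rate)

with torch.no_grad():
    dF_drate = x * torch.exp(-rate * x)      # dF/drate
    pdf = rate * torch.exp(-rate * x)
    implicit_grad = (-dF_drate / pdf).sum()  # implicit: dx/drate = -(dF/drate) / p(x)

print(explicit_grad.item(), implicit_grad.item())  # identical up to rounding
```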
```diff
@@ -1394,25 +1392,26 @@ def von_mises_cdf_series(
     drn_dconcentration = torch.zeros_like(x)
 
     while n > 0:
-        denominator = 2. * n / concentration + rn
-        ddenominator_dk = -2. * n / concentration ** 2 + drn_dconcentration
-        rn = 1. / denominator
-        drn_dconcentration = -ddenominator_dk / denominator ** 2
+        denominator = 2.0 * n / concentration + rn
+        ddenominator_dk = -2.0 * n / concentration**2 + drn_dconcentration
+        rn = 1.0 / denominator
+        drn_dconcentration = -ddenominator_dk / denominator**2
 
         multiplier = torch.sin(n * x) / n + vn
         vn = rn * multiplier
-        dvn_dconcentration = (drn_dconcentration * multiplier + rn * dvn_dconcentration)
-
+        dvn_dconcentration = drn_dconcentration * multiplier + rn * dvn_dconcentration
+
         n -= 1
 
-    cdf = 0.5 + x / (2. * torch.pi) + vn / torch.pi
+    cdf = 0.5 + x / (2.0 * torch.pi) + vn / torch.pi
     dcdf_dconcentration = dvn_dconcentration / torch.pi
 
-    cdf_clipped = torch.clamp(cdf, 0., 1.)
-    dcdf_dconcentration *= (cdf >= 0.) & (cdf <= 1.)
+    cdf_clipped = torch.clamp(cdf, 0.0, 1.0)
+    dcdf_dconcentration *= (cdf >= 0.0) & (cdf <= 1.0)
 
     return cdf_clipped, dcdf_dconcentration
-
+
+
 def cdf_func(concentration: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     """
     Approximate the CDF of the von Mises distribution.
```
```diff
@@ -1424,32 +1423,26 @@ def cdf_func(concentration: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: Approximate CDF values.
     """
-
     # Calculate the z value based on the approximation
-    z = (torch.sqrt(torch.tensor(2. / torch.pi)) / torch.special.i0e(concentration)) * torch.sin(0.5 * x)
+    z = (torch.sqrt(torch.tensor(2.0 / torch.pi)) / torch.special.i0e(concentration)) * torch.sin(0.5 * x)
     # Apply corrections to z to improve the approximation
-    z2 = z ** 2
+    z2 = z**2
     z3 = z2 * z
-    z4 = z2 ** 2
-    c = 24. * concentration
-    c1 = 56.
+    z4 = z2**2
+    c = 24.0 * concentration
+    c1 = 56.0
 
-    xi = z - z3 / (
-        ((c - 2. * z2 - 16.) / 3.) -
-        (z4 + (7. / 4.) * z2 + 167. / 2.) / (c - c1 - z2 + 3.)
-    ) ** 2
+    xi = z - z3 / (((c - 2.0 * z2 - 16.0) / 3.0) - (z4 + (7.0 / 4.0) * z2 + 167.0 / 2.0) / (c - c1 - z2 + 3.0)) ** 2
 
     # Use the standard normal distribution for the approximation
     distrib = torch.distributions.Normal(
-        torch.tensor(0., dtype=x.dtype, device=x.device),
-        torch.tensor(1., dtype=x.dtype, device=x.device)
+        torch.tensor(0.0, dtype=x.dtype, device=x.device), torch.tensor(1.0, dtype=x.dtype, device=x.device)
     )
-
+
     return distrib.cdf(xi)
 
-def von_mises_cdf_normal(
-    x: torch.Tensor, concentration: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
+
+def von_mises_cdf_normal(x: torch.Tensor, concentration: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Compute the CDF of the von Mises distribution using a normal approximation.
 
```
```diff
@@ -1467,4 +1460,4 @@ def von_mises_cdf_normal(
     dcdf_dconcentration = concentration_.grad.clone()  # Copy the gradient
     # Detach gradients to prevent further autograd tracking
     concentration_.grad = None
-    return cdf, dcdf_dconcentration
+    return cdf, dcdf_dconcentration
```
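`von_mises_cdf_normal` obtains dCDF/dconcentration by running autograd through `cdf_func`. The same pattern in isolation, assuming the `cdf_func` defined in this file is in scope:

```python
import torch

concentration = torch.tensor(12.0, requires_grad=True)
x = torch.tensor(1.0)

cdf = cdf_func(concentration, x)  # normal-approximation CDF, built from differentiable ops
cdf.backward()
print(cdf.item(), concentration.grad.item())  # CDF value and dCDF/dconcentration
```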
