from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all, lazy_property
from torch.types import _size
-from torch.distributions.distribution import Distribution

default_size = torch.Size()
@@ -1041,11 +1040,11 @@ def _log_cdf(self, x: torch.Tensor) -> torch.Tensor:
def _eval_poly(y: torch.Tensor, coef: torch.Tensor) -> torch.Tensor:
    """
    Evaluate a polynomial at given points.
-
+
    Args:
        y: Input tensor.
        coef: Polynomial coefficients.
-
+
    Returns:
        Evaluated polynomial tensor.
    """
@@ -1108,7 +1107,7 @@ def _log_modified_bessel_fn(x: torch.Tensor, order: int = 0) -> torch.Tensor:
    Args:
        x: Input tensor, must be positive.
        order: Order of the Bessel function (0 or 1).
-
+
    Returns:
        Logarithm of the Bessel function.
    """
@@ -1133,20 +1132,17 @@ def _log_modified_bessel_fn(x: torch.Tensor, order: int = 0) -> torch.Tensor:
@torch.jit.script_if_tracing
def _rejection_sample(
-    loc: torch.Tensor,
-    concentration: torch.Tensor,
-    proposal_r: torch.Tensor,
-    x: torch.Tensor
+    loc: torch.Tensor, concentration: torch.Tensor, proposal_r: torch.Tensor, x: torch.Tensor
) -> torch.Tensor:
    """
    Perform rejection sampling for the von Mises distribution.
-
+
    Args:
        loc: Location parameter.
        concentration: Concentration parameter.
        proposal_r: Precomputed proposal parameter.
        x: Tensor to fill with samples.
-
+
    Returns:
        Tensor of samples.
    """
@@ -1165,9 +1161,7 @@ def _rejection_sample(
class VonMises(Distribution):
-    """
-    Von Mises distribution class for circular data.
-    """
+    """Von Mises distribution class for circular data."""

    arg_constraints = {
        "loc": constraints.real,
@@ -1181,33 +1175,37 @@ def __init__(
        loc: torch.Tensor,
        concentration: torch.Tensor,
        validate_args: bool = None,
-    ):
+    ) -> None:
+        """
+        Args:
+            loc: loc parameter of the distribution.
+            concentration: concentration parameter of the distribution.
+            validate_args: If True, checks the distribution parameters for validity.
+        """
        self.loc, self.concentration = broadcast_all(loc, concentration)
        batch_shape = self.loc.shape
        super().__init__(batch_shape, torch.Size(), validate_args)
-
+
    @lazy_property
    @torch.no_grad()
    def _proposal_r(self) -> torch.Tensor:
-        """
-        Compute the proposal parameter for sampling.
-        """
+        """Compute the proposal parameter for sampling."""
        kappa = self._concentration
        tau = 1 + (1 + 4 * kappa**2).sqrt()
        rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
        _proposal_r = (1 + rho**2) / (2 * rho)
-
+
        # second order Taylor expansion around 0 for small kappa
        _proposal_r_taylor = 1 / kappa + kappa
        return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r)
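The `torch.where` guard exists because `tau - (2 * tau).sqrt()` cancels catastrophically as kappa goes to 0, so the closed form degrades exactly where the Taylor expansion `1 / kappa + kappa` is accurate. A quick double-precision check of the switchover (a sketch, not part of the patch):

```python
import torch

kappa = torch.tensor([1e-7, 1e-5, 1e-3], dtype=torch.double)
tau = 1 + (1 + 4 * kappa**2).sqrt()
rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
closed_form = (1 + rho**2) / (2 * rho)
taylor = 1 / kappa + kappa
# Relative gap is tiny at 1e-3 but grows below the 1e-5 cutoff,
# where cancellation in tau - sqrt(2 * tau) dominates.
print((closed_form - taylor) / taylor)
```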
-    def log_prob(self, value):
+    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Compute the log probability of the given value.

        Args:
            value: Tensor of values.
-
+
        Returns:
            Tensor of log probabilities.
        """
@@ -1218,15 +1216,15 @@ def log_prob(self, value):
        return log_prob

    @lazy_property
-    def _loc(self):
+    def _loc(self) -> torch.Tensor:
        return self.loc.to(torch.double)

    @lazy_property
-    def _concentration(self):
+    def _concentration(self) -> torch.Tensor:
        return self.concentration.to(torch.double)
-
+
    @torch.no_grad()
-    def sample(self, sample_shape=torch.Size()):
+    def sample(self, sample_shape: _size = default_size) -> torch.Tensor:
        """
        The sampling algorithm for the von Mises distribution is based on the
        following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the
@@ -1238,33 +1236,33 @@ def sample(self, sample_shape=torch.Size()):
        """
        shape = self._extended_shape(sample_shape)
        x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
-        return _rejection_sample(
-            self._loc, self._concentration, self._proposal_r, x
-        ).to(self.loc.dtype)
+        return _rejection_sample(self._loc, self._concentration, self._proposal_r, x).to(self.loc.dtype)
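`sample` draws in double precision via the rejection sampler and casts back to the parameter dtype. A minimal usage sketch, assuming this file backs `torch.distributions.VonMises`:

```python
import torch
from torch.distributions import VonMises

d = VonMises(loc=torch.tensor(0.0), concentration=torch.tensor(4.0))
s = d.sample((1000,))   # non-differentiable draws, mapped into [-pi, pi]
lp = d.log_prob(s)      # log density at the sampled angles
print(s.abs().max() <= torch.pi, lp.shape)
```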
-    def rsample(self, sample_shape=torch.Size()):
-        """
-        Generate reparameterized samples from the distribution.
-        """
+    def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
+        """Generate reparameterized samples from the distribution."""
        shape = self._extended_shape(sample_shape)
        samples = _VonMisesSampler.apply(self.concentration, self._proposal_r, shape)
        samples = samples + self.loc
-
+
        # Map the samples to [-pi, pi].
-        return samples - 2. * torch.pi * torch.round(samples / (2. * torch.pi))
+        return samples - 2.0 * torch.pi * torch.round(samples / (2.0 * torch.pi))

    @property
-    def mean(self):
+    def mean(self) -> torch.Tensor:
        """Mean of the distribution."""
        return self.loc

    @property
-    def variance(self):
+    def variance(self) -> torch.Tensor:
        """Variance of the distribution."""
-        return 1 - (
-            _log_modified_bessel_fn(self.concentration, order=1)
-            - _log_modified_bessel_fn(self.concentration, order=0)
-        ).exp()
+        return (
+            1
+            - (
+                _log_modified_bessel_fn(self.concentration, order=1)
+                - _log_modified_bessel_fn(self.concentration, order=0)
+            ).exp()
+        )
+
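For context on `variance`: the circular variance of a von Mises distribution is 1 - I1(kappa) / I0(kappa), and taking the Bessel ratio as a difference of logs keeps it stable for large kappa. The same quantity can be cross-checked with the exponentially scaled Bessel functions, whose exp(-kappa) factors cancel in the ratio (a sketch, not part of the patch):

```python
import torch

kappa = torch.tensor([0.1, 1.0, 10.0, 100.0])
variance = 1 - torch.special.i1e(kappa) / torch.special.i0e(kappa)
print(variance)  # near 1 for small kappa (diffuse), near 0 for large kappa (peaked)
```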
@torch.jit.script_if_tracing
@torch.no_grad()
@@ -1282,7 +1280,7 @@ def _rejection_rsample(concentration: torch.Tensor, proposal_r: torch.Tensor, sh
    """
    x = torch.empty(shape, dtype=concentration.dtype, device=concentration.device)
    done = torch.zeros(x.shape, dtype=torch.bool, device=concentration.device)
-
+
    while not done.all():
        u = torch.rand((3,) + x.shape, dtype=concentration.dtype, device=concentration.device)
        u1, u2, u3 = u.unbind()
@@ -1295,6 +1293,7 @@ def _rejection_rsample(concentration: torch.Tensor, proposal_r: torch.Tensor, sh
        done = done | accept
    return x

+
def cosxm1(x: torch.Tensor) -> torch.Tensor:
    """
    Compute cos(x) - 1 using a numerically stable formula.
@@ -1307,6 +1306,7 @@ def cosxm1(x: torch.Tensor) -> torch.Tensor:
    """
    return -2 * torch.square(torch.sin(x / 2.0))
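Why the rewrite matters: near x = 0, `torch.cos(x) - 1` loses every significant digit to cancellation, while -2 sin^2(x / 2) keeps them. A small float32 demonstration (a sketch, not part of the patch):

```python
import torch

x = torch.tensor(1e-8, dtype=torch.float32)
naive = torch.cos(x) - 1                         # rounds to exactly 0.0
stable = -2 * torch.square(torch.sin(x / 2.0))   # about -5e-17, the true value
print(naive.item(), stable.item())
```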

+
class _VonMisesSampler(torch.autograd.Function):
    @staticmethod
    def forward(
@@ -1329,7 +1329,7 @@ def forward(
        """
        samples = _rejection_rsample(concentration, proposal_r, shape)
        ctx.save_for_backward(concentration, proposal_r, samples)
-
+
        return samples

    @staticmethod
@@ -1348,29 +1348,27 @@ def backward(
        Tuple[torch.Tensor, None, None]: Gradients with respect to the input tensors.
        """
        concentration, proposal_r, samples = ctx.saved_tensors
-
-        num_periods = torch.round(samples / (2. * torch.pi))
-        x_mapped = samples - (2. * torch.pi) * num_periods
-
-        ## Parameters from the paper
+
+        num_periods = torch.round(samples / (2.0 * torch.pi))
+        x_mapped = samples - (2.0 * torch.pi) * num_periods
+
+        # Parameters from the paper
        ck = 10.5
        num_terms = 20
-
-        ## Compute series and normal approximation
+
+        # Compute series and normal approximation
        cdf_series, dcdf_dconcentration_series = von_mises_cdf_series(x_mapped, concentration, num_terms)
        cdf_normal, dcdf_dconcentration_normal = von_mises_cdf_normal(x_mapped, concentration)
        use_series = concentration < ck
-        cdf = torch.where(use_series, cdf_series, cdf_normal) + num_periods
+        # cdf = torch.where(use_series, cdf_series, cdf_normal) + num_periods
        dcdf_dconcentration = torch.where(use_series, dcdf_dconcentration_series, dcdf_dconcentration_normal)
-
-        ## Compute CDF gradient terms
-        inv_prob = torch.exp(concentration * cosxm1(samples)) / (
-            2 * math.pi * torch.special.i0e(concentration)
-        )
-        grad_concentration = grad_output * (-dcdf_dconcentration / inv_prob)
-
+
+        # Compute CDF gradient terms
+        inv_prob = torch.exp(concentration * cosxm1(samples)) / (2 * math.pi * torch.special.i0e(concentration))
+        grad_concentration = grad_output * (-dcdf_dconcentration / inv_prob)
+
        return grad_concentration, None, None
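This backward pass is the implicit reparameterization trick: for a sample x with CDF F(x; kappa), dx/dkappa = -(dF/dkappa) / p(x; kappa). The `inv_prob` expression is the von Mises density rewritten against the exponentially scaled Bessel function, which is why `cosxm1` and `i0e` appear together; both exponentials stay bounded. A standalone sketch of that identity (the helper name is ours):

```python
import math
import torch

def von_mises_pdf(x: torch.Tensor, concentration: torch.Tensor) -> torch.Tensor:
    # exp(k * cos(x)) / (2 * pi * I0(k)) is rewritten as
    # exp(k * (cos(x) - 1)) / (2 * pi * i0e(k)), since i0e(k) = exp(-k) * I0(k).
    return torch.exp(concentration * (torch.cos(x) - 1)) / (
        2 * math.pi * torch.special.i0e(concentration)
    )
```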
-
+

def von_mises_cdf_series(
    x: torch.Tensor, concentration: torch.Tensor, num_terms: int
@@ -1394,25 +1392,26 @@ def von_mises_cdf_series(
    drn_dconcentration = torch.zeros_like(x)

    while n > 0:
-        denominator = 2. * n / concentration + rn
-        ddenominator_dk = -2. * n / concentration ** 2 + drn_dconcentration
-        rn = 1. / denominator
-        drn_dconcentration = -ddenominator_dk / denominator ** 2
+        denominator = 2.0 * n / concentration + rn
+        ddenominator_dk = -2.0 * n / concentration**2 + drn_dconcentration
+        rn = 1.0 / denominator
+        drn_dconcentration = -ddenominator_dk / denominator**2

        multiplier = torch.sin(n * x) / n + vn
        vn = rn * multiplier
-        dvn_dconcentration = (drn_dconcentration * multiplier + rn * dvn_dconcentration)
-
+        dvn_dconcentration = drn_dconcentration * multiplier + rn * dvn_dconcentration
+
        n -= 1

-    cdf = 0.5 + x / (2. * torch.pi) + vn / torch.pi
+    cdf = 0.5 + x / (2.0 * torch.pi) + vn / torch.pi
    dcdf_dconcentration = dvn_dconcentration / torch.pi

-    cdf_clipped = torch.clamp(cdf, 0., 1.)
-    dcdf_dconcentration *= (cdf >= 0.) & (cdf <= 1.)
+    cdf_clipped = torch.clamp(cdf, 0.0, 1.0)
+    dcdf_dconcentration *= (cdf >= 0.0) & (cdf <= 1.0)

    return cdf_clipped, dcdf_dconcentration
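For reference, the loop above evaluates the standard Fourier series for the von Mises CDF backward, with `rn` tracking a continued-fraction estimate of the Bessel ratio I_n(kappa) / I_{n-1}(kappa). In our rendering (an annotation, not in the patch):

```latex
F(x;\kappa) \approx \frac{1}{2} + \frac{x}{2\pi}
  + \frac{1}{\pi} \sum_{n=1}^{N} \frac{I_n(\kappa)}{n\, I_0(\kappa)} \sin(nx)
```

Clamping the result to [0, 1] and zeroing the gradient outside that range guards against truncation error at the ends of the series.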
-
+
+
def cdf_func(concentration: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Approximate the CDF of the von Mises distribution.
@@ -1424,32 +1423,26 @@ def cdf_func(concentration: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    Returns:
        torch.Tensor: Approximate CDF values.
    """
-
    # Calculate the z value based on the approximation
-    z = (torch.sqrt(torch.tensor(2. / torch.pi)) / torch.special.i0e(concentration)) * torch.sin(0.5 * x)
+    z = (torch.sqrt(torch.tensor(2.0 / torch.pi)) / torch.special.i0e(concentration)) * torch.sin(0.5 * x)

    # Apply corrections to z to improve the approximation
-    z2 = z ** 2
+    z2 = z**2
    z3 = z2 * z
-    z4 = z2 ** 2
-    c = 24. * concentration
-    c1 = 56.
+    z4 = z2**2
+    c = 24.0 * concentration
+    c1 = 56.0

-    xi = z - z3 / (
-        ((c - 2. * z2 - 16.) / 3.) -
-        (z4 + (7. / 4.) * z2 + 167. / 2.) / (c - c1 - z2 + 3.)
-    ) ** 2
+    xi = z - z3 / (((c - 2.0 * z2 - 16.0) / 3.0) - (z4 + (7.0 / 4.0) * z2 + 167.0 / 2.0) / (c - c1 - z2 + 3.0)) ** 2

    # Use the standard normal distribution for the approximation
    distrib = torch.distributions.Normal(
-        torch.tensor(0., dtype=x.dtype, device=x.device),
-        torch.tensor(1., dtype=x.dtype, device=x.device)
+        torch.tensor(0.0, dtype=x.dtype, device=x.device), torch.tensor(1.0, dtype=x.dtype, device=x.device)
    )
-
+
    return distrib.cdf(xi)

-def von_mises_cdf_normal(
-    x: torch.Tensor, concentration: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
+
+def von_mises_cdf_normal(x: torch.Tensor, concentration: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the CDF of the von Mises distribution using a normal approximation.
@@ -1467,4 +1460,4 @@ def von_mises_cdf_normal(
    dcdf_dconcentration = concentration_.grad.clone()  # Copy the gradient
    # Detach gradients to prevent further autograd tracking
    concentration_.grad = None
-    return cdf, dcdf_dconcentration
+    return cdf, dcdf_dconcentration
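End to end, these pieces give `VonMises` a pathwise `rsample` whose concentration gradient flows through `_VonMisesSampler.backward`. A minimal sketch of how that could be exercised once this patch is applied:

```python
import torch
from torch.distributions import VonMises

concentration = torch.tensor(4.0, requires_grad=True)
d = VonMises(torch.tensor(0.0), concentration)

samples = d.rsample((10_000,))   # reparameterized draws in [-pi, pi]
loss = samples.cos().mean()      # any smooth statistic of the samples
loss.backward()                  # implicit gradient reaches concentration
print(concentration.grad)
```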