Commit c5f2229

MAINT better fused type names in loss module (scikit-learn#27330)
1 parent 8ed76c7 commit c5f2229

2 files changed: +93 −87 lines

sklearn/_loss/_loss.pxd

Lines changed: 8 additions & 4 deletions
@@ -1,13 +1,17 @@
 # cython: language_level=3

-# Fused types for y_true, y_pred, raw_prediction
-ctypedef fused Y_DTYPE_C:
+# Fused types for input like y_true, raw_prediction, sample_weights.
+ctypedef fused floating_in:
     double
     float


-# Fused types for gradient and hessian
-ctypedef fused G_DTYPE_C:
+# Fused types for output like gradient and hessian
+# We use different fused types for input (floating_in) and output (floating_out), such
+# that input and output can have different dtypes in the same function call. A single
+# fused type can only take on one single value (type) for all arguments in one function
+# call.
+ctypedef fused floating_out:
     double
     float

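To illustrate the comment above: a minimal Cython sketch (not part of the commit; the half_squared_error helper is hypothetical) of why two separate fused types are needed. With a single fused type, all fused arguments in one call must specialize to the same C type; with floating_in and floating_out, a caller can, for example, pass float64 inputs and fill a float32 output buffer.

    # cython: language_level=3
    # Illustrative only, not from the commit.
    ctypedef fused floating_in:
        double
        float

    ctypedef fused floating_out:
        double
        float

    def half_squared_error(
        const floating_in[::1] y_true,   # IN
        const floating_in[::1] raw,      # IN
        floating_out[::1] loss_out,      # OUT
    ):
        # Cython generates all four (input, output) dtype combinations here.
        # Had y_true, raw and loss_out shared one fused type, they would all
        # be forced to the same dtype in any given call.
        cdef int i
        for i in range(y_true.shape[0]):
            loss_out[i] = 0.5 * (y_true[i] - raw[i]) ** 2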

sklearn/_loss/_loss.pyx.tp

Lines changed: 85 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,8 @@ cdef inline double log1pexp(double x) noexcept nogil:

 cdef inline void sum_exp_minus_max(
     const int i,
-    const Y_DTYPE_C[:, :] raw_prediction,  # IN
-    Y_DTYPE_C *p                           # OUT
+    const floating_in[:, :] raw_prediction,  # IN
+    floating_in *p                            # OUT
 ) noexcept nogil:
     # Thread local buffers are used to store results of this function via p.
     # The results are stored as follows:
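As context for this helper: a sketch of the stabilized computation it performs, with the buffer layout inferred from the "(n_classes + 2)" allocations further down (illustrative; sum_exp_minus_max_sketch is a hypothetical stand-in, and details beyond the shown hunk are assumptions):

    from libc.math cimport exp

    cdef inline void sum_exp_minus_max_sketch(
        const int i,
        const floating_in[:, :] raw_prediction,  # IN
        floating_in *p                            # OUT
    ) noexcept nogil:
        # Subtract the row maximum before exponentiating so exp() cannot
        # overflow (the standard log-sum-exp stabilization).
        cdef int k
        cdef int n_classes = raw_prediction.shape[1]
        cdef floating_in max_value = raw_prediction[i, 0]
        cdef floating_in sum_exps = 0
        for k in range(1, n_classes):
            if raw_prediction[i, k] > max_value:
                max_value = raw_prediction[i, k]
        for k in range(n_classes):
            p[k] = exp(raw_prediction[i, k] - max_value)
            sum_exps = sum_exps + p[k]
        p[n_classes] = max_value      # the two extra slots explain the
        p[n_classes + 1] = sum_exps   # "n_classes + 2" buffer size below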
@@ -744,7 +744,7 @@ cdef inline double_pair cgrad_hess_half_binomial(
     double raw_prediction
 ) noexcept nogil:
     # with y_pred = expit(raw)
-    # hessian = y_pred * (1 - y_pred) = exp(raw) / (1 + exp(raw))**2
+    # hessian = y_pred * (1 - y_pred) = exp( raw) / (1 + exp( raw))**2
     #                                 = exp(-raw) / (1 + exp(-raw))**2
     cdef double_pair gh
     gh.val2 = exp(-raw_prediction)  # used as temporary
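For reference, the identity in this comment follows directly from y_pred = expit(raw) = 1 / (1 + exp(-raw)): then 1 - y_pred = exp(-raw) / (1 + exp(-raw)), so y_pred * (1 - y_pred) = exp(-raw) / (1 + exp(-raw))**2, and multiplying numerator and denominator by exp(2*raw) gives the equivalent form exp(raw) / (1 + exp(raw))**2. Evaluating the exp(-raw) form, as gh.val2 does above, merely underflows to a hessian of 0 for large positive raw, whereas the exp(raw) form would overflow to inf/inf.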
@@ -835,7 +835,9 @@ cdef class CyLossFunction:
         """
         pass

-    cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) noexcept nogil:
+    cdef double_pair cy_grad_hess(
+        self, double y_true, double raw_prediction
+    ) noexcept nogil:
         """Compute gradient and hessian.

         Gradient and hessian of loss w.r.t. raw_prediction for a single sample.
@@ -862,10 +864,10 @@ cdef class CyLossFunction:

     def loss(
         self,
-        const Y_DTYPE_C[::1] y_true,          # IN
-        const Y_DTYPE_C[::1] raw_prediction,  # IN
-        const Y_DTYPE_C[::1] sample_weight,   # IN
-        G_DTYPE_C[::1] loss_out,              # OUT
+        const floating_in[::1] y_true,          # IN
+        const floating_in[::1] raw_prediction,  # IN
+        const floating_in[::1] sample_weight,   # IN
+        floating_out[::1] loss_out,             # OUT
         int n_threads=1
     ):
         """Compute the pointwise loss value for each input.
@@ -892,10 +894,10 @@ cdef class CyLossFunction:

     def gradient(
         self,
-        const Y_DTYPE_C[::1] y_true,          # IN
-        const Y_DTYPE_C[::1] raw_prediction,  # IN
-        const Y_DTYPE_C[::1] sample_weight,   # IN
-        G_DTYPE_C[::1] gradient_out,          # OUT
+        const floating_in[::1] y_true,          # IN
+        const floating_in[::1] raw_prediction,  # IN
+        const floating_in[::1] sample_weight,   # IN
+        floating_out[::1] gradient_out,         # OUT
         int n_threads=1
     ):
         """Compute gradient of loss w.r.t raw_prediction for each input.
@@ -922,11 +924,11 @@ cdef class CyLossFunction:

     def loss_gradient(
         self,
-        const Y_DTYPE_C[::1] y_true,          # IN
-        const Y_DTYPE_C[::1] raw_prediction,  # IN
-        const Y_DTYPE_C[::1] sample_weight,   # IN
-        G_DTYPE_C[::1] loss_out,              # OUT
-        G_DTYPE_C[::1] gradient_out,          # OUT
+        const floating_in[::1] y_true,          # IN
+        const floating_in[::1] raw_prediction,  # IN
+        const floating_in[::1] sample_weight,   # IN
+        floating_out[::1] loss_out,             # OUT
+        floating_out[::1] gradient_out,         # OUT
         int n_threads=1
     ):
         """Compute loss and gradient of loss w.r.t raw_prediction.
@@ -960,11 +962,11 @@ cdef class CyLossFunction:

     def gradient_hessian(
         self,
-        const Y_DTYPE_C[::1] y_true,          # IN
-        const Y_DTYPE_C[::1] raw_prediction,  # IN
-        const Y_DTYPE_C[::1] sample_weight,   # IN
-        G_DTYPE_C[::1] gradient_out,          # OUT
-        G_DTYPE_C[::1] hessian_out,           # OUT
+        const floating_in[::1] y_true,          # IN
+        const floating_in[::1] raw_prediction,  # IN
+        const floating_in[::1] sample_weight,   # IN
+        floating_out[::1] gradient_out,         # OUT
+        floating_out[::1] hessian_out,          # OUT
         int n_threads=1
     ):
         """Compute gradient and hessian of loss w.r.t raw_prediction.
@@ -1022,10 +1024,10 @@ cdef class {{name}}(CyLossFunction):

     def loss(
         self,
-        const Y_DTYPE_C[::1] y_true,          # IN
-        const Y_DTYPE_C[::1] raw_prediction,  # IN
-        const Y_DTYPE_C[::1] sample_weight,   # IN
-        G_DTYPE_C[::1] loss_out,              # OUT
+        const floating_in[::1] y_true,          # IN
+        const floating_in[::1] raw_prediction,  # IN
+        const floating_in[::1] sample_weight,   # IN
+        floating_out[::1] loss_out,             # OUT
         int n_threads=1
     ):
         cdef:
@@ -1048,11 +1050,11 @@ cdef class {{name}}(CyLossFunction):
     {{if closs_grad is not None}}
     def loss_gradient(
         self,
-        const Y_DTYPE_C[::1] y_true,          # IN
-        const Y_DTYPE_C[::1] raw_prediction,  # IN
-        const Y_DTYPE_C[::1] sample_weight,   # IN
-        G_DTYPE_C[::1] loss_out,              # OUT
-        G_DTYPE_C[::1] gradient_out,          # OUT
+        const floating_in[::1] y_true,          # IN
+        const floating_in[::1] raw_prediction,  # IN
+        const floating_in[::1] sample_weight,   # IN
+        floating_out[::1] loss_out,             # OUT
+        floating_out[::1] gradient_out,         # OUT
         int n_threads=1
     ):
         cdef:
@@ -1080,10 +1082,10 @@ cdef class {{name}}(CyLossFunction):

     def gradient(
         self,
-        const Y_DTYPE_C[::1] y_true,          # IN
-        const Y_DTYPE_C[::1] raw_prediction,  # IN
-        const Y_DTYPE_C[::1] sample_weight,   # IN
-        G_DTYPE_C[::1] gradient_out,          # OUT
+        const floating_in[::1] y_true,          # IN
+        const floating_in[::1] raw_prediction,  # IN
+        const floating_in[::1] sample_weight,   # IN
+        floating_out[::1] gradient_out,         # OUT
         int n_threads=1
     ):
         cdef:
@@ -1105,11 +1107,11 @@ cdef class {{name}}(CyLossFunction):

     def gradient_hessian(
         self,
-        const Y_DTYPE_C[::1] y_true,          # IN
-        const Y_DTYPE_C[::1] raw_prediction,  # IN
-        const Y_DTYPE_C[::1] sample_weight,   # IN
-        G_DTYPE_C[::1] gradient_out,          # OUT
-        G_DTYPE_C[::1] hessian_out,           # OUT
+        const floating_in[::1] y_true,          # IN
+        const floating_in[::1] raw_prediction,  # IN
+        const floating_in[::1] sample_weight,   # IN
+        floating_out[::1] gradient_out,         # OUT
+        floating_out[::1] hessian_out,          # OUT
         int n_threads=1
     ):
         cdef:
@@ -1158,18 +1160,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
     # opposite are welcome.
     def loss(
         self,
-        const Y_DTYPE_C[::1] y_true,           # IN
-        const Y_DTYPE_C[:, :] raw_prediction,  # IN
-        const Y_DTYPE_C[::1] sample_weight,    # IN
-        G_DTYPE_C[::1] loss_out,               # OUT
+        const floating_in[::1] y_true,           # IN
+        const floating_in[:, :] raw_prediction,  # IN
+        const floating_in[::1] sample_weight,    # IN
+        floating_out[::1] loss_out,              # OUT
         int n_threads=1
     ):
         cdef:
             int i, k
             int n_samples = y_true.shape[0]
             int n_classes = raw_prediction.shape[1]
-            Y_DTYPE_C max_value, sum_exps
-            Y_DTYPE_C* p  # temporary buffer
+            floating_in max_value, sum_exps
+            floating_in* p  # temporary buffer

         # We assume n_samples > n_classes. In this case having the inner loop
         # over n_classes is a good default.
@@ -1181,7 +1183,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
             with nogil, parallel(num_threads=n_threads):
                 # Define private buffer variables as each thread might use its
                 # own.
-                p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

                 for i in prange(n_samples, schedule='static'):
                     sum_exp_minus_max(i, raw_prediction, p)
@@ -1197,7 +1199,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                 free(p)
         else:
             with nogil, parallel(num_threads=n_threads):
-                p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

                 for i in prange(n_samples, schedule='static'):
                     sum_exp_minus_max(i, raw_prediction, p)
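As context for the malloc/free pattern in these hunks: a variable first assigned inside a cython.parallel.parallel() block is thread-private, so each OpenMP thread allocates and frees its own scratch buffer with no sharing to synchronize. A standalone sketch of the pattern (illustrative; row_max is a hypothetical function, not from scikit-learn):

    # cython: language_level=3
    from cython.parallel import parallel, prange
    from libc.stdlib cimport malloc, free

    def row_max(const double[:, :] x, double[::1] out, int n_threads=1):
        cdef int i, k
        cdef int n_samples = x.shape[0]
        cdef int n_features = x.shape[1]
        cdef double* p
        with nogil, parallel(num_threads=n_threads):
            # p is first assigned inside the parallel block, so every
            # OpenMP thread gets its own private buffer.
            p = <double *> malloc(sizeof(double) * n_features)
            for i in prange(n_samples, schedule='static'):
                for k in range(n_features):
                    p[k] = x[i, k]
                out[i] = p[0]
                for k in range(1, n_features):
                    if p[k] > out[i]:
                        out[i] = p[k]
            # Executed once per thread, after its share of the loop.
            free(p)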
@@ -1218,26 +1220,26 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):

     def loss_gradient(
         self,
-        const Y_DTYPE_C[::1] y_true,           # IN
-        const Y_DTYPE_C[:, :] raw_prediction,  # IN
-        const Y_DTYPE_C[::1] sample_weight,    # IN
-        G_DTYPE_C[::1] loss_out,               # OUT
-        G_DTYPE_C[:, :] gradient_out,          # OUT
+        const floating_in[::1] y_true,           # IN
+        const floating_in[:, :] raw_prediction,  # IN
+        const floating_in[::1] sample_weight,    # IN
+        floating_out[::1] loss_out,              # OUT
+        floating_out[:, :] gradient_out,         # OUT
         int n_threads=1
     ):
         cdef:
             int i, k
             int n_samples = y_true.shape[0]
             int n_classes = raw_prediction.shape[1]
-            Y_DTYPE_C max_value, sum_exps
-            Y_DTYPE_C* p  # temporary buffer
+            floating_in max_value, sum_exps
+            floating_in* p  # temporary buffer

         if sample_weight is None:
             # inner loop over n_classes
             with nogil, parallel(num_threads=n_threads):
                 # Define private buffer variables as each thread might use its
                 # own.
-                p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

                 for i in prange(n_samples, schedule='static'):
                     sum_exp_minus_max(i, raw_prediction, p)
@@ -1256,7 +1258,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                 free(p)
         else:
             with nogil, parallel(num_threads=n_threads):
-                p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

                 for i in prange(n_samples, schedule='static'):
                     sum_exp_minus_max(i, raw_prediction, p)
@@ -1280,25 +1282,25 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):

     def gradient(
         self,
-        const Y_DTYPE_C[::1] y_true,           # IN
-        const Y_DTYPE_C[:, :] raw_prediction,  # IN
-        const Y_DTYPE_C[::1] sample_weight,    # IN
-        G_DTYPE_C[:, :] gradient_out,          # OUT
+        const floating_in[::1] y_true,           # IN
+        const floating_in[:, :] raw_prediction,  # IN
+        const floating_in[::1] sample_weight,    # IN
+        floating_out[:, :] gradient_out,         # OUT
         int n_threads=1
     ):
         cdef:
             int i, k
             int n_samples = y_true.shape[0]
             int n_classes = raw_prediction.shape[1]
-            Y_DTYPE_C sum_exps
-            Y_DTYPE_C* p  # temporary buffer
+            floating_in sum_exps
+            floating_in* p  # temporary buffer

         if sample_weight is None:
             # inner loop over n_classes
             with nogil, parallel(num_threads=n_threads):
                 # Define private buffer variables as each thread might use its
                 # own.
-                p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

                 for i in prange(n_samples, schedule='static'):
                     sum_exp_minus_max(i, raw_prediction, p)
@@ -1312,7 +1314,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                 free(p)
         else:
             with nogil, parallel(num_threads=n_threads):
-                p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

                 for i in prange(n_samples, schedule='static'):
                     sum_exp_minus_max(i, raw_prediction, p)
@@ -1329,26 +1331,26 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):

     def gradient_hessian(
         self,
-        const Y_DTYPE_C[::1] y_true,           # IN
-        const Y_DTYPE_C[:, :] raw_prediction,  # IN
-        const Y_DTYPE_C[::1] sample_weight,    # IN
-        G_DTYPE_C[:, :] gradient_out,          # OUT
-        G_DTYPE_C[:, :] hessian_out,           # OUT
+        const floating_in[::1] y_true,           # IN
+        const floating_in[:, :] raw_prediction,  # IN
+        const floating_in[::1] sample_weight,    # IN
+        floating_out[:, :] gradient_out,         # OUT
+        floating_out[:, :] hessian_out,          # OUT
         int n_threads=1
     ):
         cdef:
             int i, k
             int n_samples = y_true.shape[0]
             int n_classes = raw_prediction.shape[1]
-            Y_DTYPE_C sum_exps
-            Y_DTYPE_C* p  # temporary buffer
+            floating_in sum_exps
+            floating_in* p  # temporary buffer

         if sample_weight is None:
             # inner loop over n_classes
             with nogil, parallel(num_threads=n_threads):
                 # Define private buffer variables as each thread might use its
                 # own.
-                p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

                 for i in prange(n_samples, schedule='static'):
                     sum_exp_minus_max(i, raw_prediction, p)
@@ -1364,7 +1366,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                 free(p)
         else:
             with nogil, parallel(num_threads=n_threads):
-                p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

                 for i in prange(n_samples, schedule='static'):
                     sum_exp_minus_max(i, raw_prediction, p)
@@ -1387,26 +1389,26 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
     # diagonal (in the classes) approximation as implemented above.
     def gradient_proba(
         self,
-        const Y_DTYPE_C[::1] y_true,           # IN
-        const Y_DTYPE_C[:, :] raw_prediction,  # IN
-        const Y_DTYPE_C[::1] sample_weight,    # IN
-        G_DTYPE_C[:, :] gradient_out,          # OUT
-        G_DTYPE_C[:, :] proba_out,             # OUT
+        const floating_in[::1] y_true,           # IN
+        const floating_in[:, :] raw_prediction,  # IN
+        const floating_in[::1] sample_weight,    # IN
+        floating_out[:, :] gradient_out,         # OUT
+        floating_out[:, :] proba_out,            # OUT
         int n_threads=1
     ):
         cdef:
             int i, k
             int n_samples = y_true.shape[0]
             int n_classes = raw_prediction.shape[1]
-            Y_DTYPE_C sum_exps
-            Y_DTYPE_C* p  # temporary buffer
+            floating_in sum_exps
+            floating_in* p  # temporary buffer

         if sample_weight is None:
             # inner loop over n_classes
             with nogil, parallel(num_threads=n_threads):
                 # Define private buffer variables as each thread might use its
                 # own.
-                p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

                 for i in prange(n_samples, schedule='static'):
                     sum_exp_minus_max(i, raw_prediction, p)
@@ -1420,7 +1422,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
                 free(p)
         else:
             with nogil, parallel(num_threads=n_threads):
-                p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
+                p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

                 for i in prange(n_samples, schedule='static'):
                     sum_exp_minus_max(i, raw_prediction, p)
