Skip to content

Commit 7a5d1fa

Browse files
committed
rename match_sizes
1 parent 3525a87 commit 7a5d1fa

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

src/losses/functions.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ julia> Flux.mae(y_model, 1:3)
1919
```
2020
"""
2121
function mae(ŷ, y; agg = mean)
22-
match_sizes(ŷ, y)
22+
_check_sizes(ŷ, y)
2323
agg(abs.(ŷ .- y))
2424
end
2525

@@ -43,7 +43,7 @@ julia> Flux.mse(y_model, y_true)
4343
```
4444
"""
4545
function mse(ŷ, y; agg = mean)
46-
match_sizes(ŷ, y)
46+
_check_sizes(ŷ, y)
4747
agg((ŷ .- y) .^ 2)
4848
end
4949

@@ -67,7 +67,7 @@ julia> Flux.msle(Float32[0.9, 1.8, 2.7], 1:3)
6767
```
6868
"""
6969
function msle(ŷ, y; agg = mean, ϵ = epseltype(ŷ))
70-
match_sizes(ŷ, y)
70+
_check_sizes(ŷ, y)
7171
agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^2 )
7272
end
7373

@@ -82,7 +82,7 @@ given the prediction `ŷ` and true values `y`.
8282
| δ * (|ŷ - y| - 0.5 * δ), otherwise
8383
"""
8484
function huber_loss(ŷ, y; agg = mean, δ = ofeltype(ŷ, 1))
85-
match_sizes(ŷ, y)
85+
_check_sizes(ŷ, y)
8686
abs_error = abs.(ŷ .- y)
8787
#TODO: remove dropgrad when Zygote can handle this function with CuArrays
8888
temp = Zygote.dropgrad(abs_error .< δ)
@@ -212,7 +212,7 @@ julia> Flux.crossentropy(y_model, y_smooth)
212212
```
213213
"""
214214
function crossentropy(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ))
215-
match_sizes(ŷ, y)
215+
_check_sizes(ŷ, y)
216216
agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims = dims))
217217
end
218218

@@ -251,7 +251,7 @@ julia> Flux.crossentropy(softmax(y_model), y_label)
251251
```
252252
"""
253253
function logitcrossentropy(ŷ, y; dims = 1, agg = mean)
254-
match_sizes(ŷ, y)
254+
_check_sizes(ŷ, y)
255255
agg(.-sum(y .* logsoftmax(ŷ; dims = dims); dims = dims))
256256
end
257257

@@ -300,7 +300,7 @@ julia> Flux.crossentropy(y_prob, y_hot)
300300
```
301301
"""
302302
function binarycrossentropy(ŷ, y; agg = mean, ϵ = epseltype(ŷ))
303-
match_sizes(ŷ, y)
303+
_check_sizes(ŷ, y)
304304
agg(@.(-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 -+ ϵ)))
305305
end
306306

@@ -330,7 +330,7 @@ julia> Flux.binarycrossentropy(sigmoid.(y_model), y_bin)
330330
```
331331
"""
332332
function logitbinarycrossentropy(ŷ, y; agg = mean)
333-
match_sizes(ŷ, y)
333+
_check_sizes(ŷ, y)
334334
agg(@.((1 - y) *- logσ(ŷ)))
335335
end
336336

@@ -370,7 +370,7 @@ Inf
370370
```
371371
"""
372372
function kldivergence(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ))
373-
match_sizes(ŷ, y)
373+
_check_sizes(ŷ, y)
374374
entropy = agg(sum(xlogx.(y), dims = dims))
375375
cross_entropy = crossentropy(ŷ, y; dims = dims, agg = agg, ϵ = ϵ)
376376
return entropy + cross_entropy
@@ -385,7 +385,7 @@ end
385385
[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
386386
"""
387387
function poisson_loss(ŷ, y; agg = mean)
388-
match_sizes(ŷ, y)
388+
_check_sizes(ŷ, y)
389389
agg(ŷ .- xlogy.(y, ŷ))
390390
end
391391

@@ -399,7 +399,7 @@ prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as
399399
See also: [`squared_hinge_loss`](@ref)
400400
"""
401401
function hinge_loss(ŷ, y; agg = mean)
402-
match_sizes(ŷ, y)
402+
_check_sizes(ŷ, y)
403403
agg(max.(0, 1 .-.* y))
404404
end
405405

@@ -412,7 +412,7 @@ Return the squared hinge_loss loss given the prediction `ŷ` and true labels `y
412412
See also: [`hinge_loss`](@ref)
413413
"""
414414
function squared_hinge_loss(ŷ, y; agg = mean)
415-
match_sizes(ŷ, y)
415+
_check_sizes(ŷ, y)
416416
agg((max.(0, 1 .-.* y)) .^ 2)
417417
end
418418

@@ -427,7 +427,7 @@ Similar to the F1_score. Calculated as:
427427
1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)
428428
"""
429429
function dice_coeff_loss(ŷ, y; smooth = ofeltype(ŷ, 1.0))
430-
match_sizes(ŷ, y)
430+
_check_sizes(ŷ, y)
431431
1 - (2 * sum(y .* ŷ) + smooth) / (sum(y .^ 2) + sum(ŷ .^ 2) + smooth) #TODO agg
432432
end
433433

@@ -441,7 +441,7 @@ Calculated as:
441441
1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
442442
"""
443443
function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
444-
match_sizes(ŷ, y)
444+
_check_sizes(ŷ, y)
445445
#TODO add agg
446446
num = sum(y .* ŷ) + 1
447447
den = sum(y .*+ β * (1 .- y) .*+ (1 - β) * y .* (1 .- ŷ)) + 1
@@ -478,7 +478,7 @@ See also: [`Losses.focal_loss`](@ref) for multi-class setting
478478
479479
"""
480480
function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
481-
match_sizes(ŷ, y)
481+
_check_sizes(ŷ, y)
482482
=.+ ϵ
483483
p_t = y .*+ (1 .- y) .* (1 .- ŷ)
484484
ce = -log.(p_t)
@@ -522,7 +522,7 @@ See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels
522522
523523
"""
524524
function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
525-
match_sizes(ŷ, y)
525+
_check_sizes(ŷ, y)
526526
=.+ ϵ
527527
agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
528528
end

src/losses/utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ end
2424
end
2525

2626
# This can be made an error in Flux v0.13, for now just a warning
27-
function match_sizes(ŷ::AbstractArray, y::AbstractArray)
27+
function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
2828
for d in 1:max(ndims(ŷ), ndims(y))
2929
if size(ŷ,d) != size(y,d)
3030
@error "size mismatch in loss function! In future this will be an error; in Flux 0.12 broadcasting acceps some mismatches" summary(ŷ) summary(y) maxlog=3 _id=hash(size(y))
3131
end
3232
end
3333
end
34-
match_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1
34+
_check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1
3535

36-
Zygote.@nograd match_sizes
36+
Zygote.@nograd _check_sizes

0 commit comments

Comments
 (0)