
Commit 8162b8a

bors[bot] and mcabbott authored
Merge #1636
1636: Add warnings for mismatched sizes in losses r=mcabbott a=mcabbott

Closes #1599, I think, by making loss functions give a warning if the sizes don't match:

```julia
julia> mse([1,0], [1 0 0])
┌ Error: size mismatch in loss function! In future this will be an error; in Flux 0.12 broadcasting accepts some mismatches
│   summary(ŷ) = "2-element Vector{Int64}"
│   summary(y) = "1×3 Matrix{Int64}"
└ @ Flux.Losses ~/.julia/dev/Flux/src/losses/utils.jl:29
0.5

julia> @btime gradient(sum∘mse, $(rand(10,100)), $(rand(10,100)));
  19.709 μs (130 allocations: 51.25 KiB)
  19.625 μs (130 allocations: 51.25 KiB)
```

Appears to have no effect on speed, although Zygote is weird and maybe someone has a better test of that.

Edit -- closes #1522, too.

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
2 parents 46b73a8 + 6e57ae8 commit 8162b8a

File tree: 3 files changed (+55, -15 lines)

src/losses/functions.jl

Lines changed: 38 additions & 11 deletions
````diff
@@ -18,7 +18,10 @@ julia> Flux.mae(y_model, 1:3)
 0.10000000000000009
 ```
 """
-mae(ŷ, y; agg = mean) = agg(abs.(ŷ .- y))
+function mae(ŷ, y; agg = mean)
+    _check_sizes(ŷ, y)
+    agg(abs.(ŷ .- y))
+end
 
 """
     mse(ŷ, y; agg = mean)
@@ -39,7 +42,10 @@ julia> Flux.mse(y_model, y_true)
 0.010000000000000018
 ```
 """
-mse(ŷ, y; agg = mean) = agg((ŷ .- y) .^ 2)
+function mse(ŷ, y; agg = mean)
+    _check_sizes(ŷ, y)
+    agg((ŷ .- y) .^ 2)
+end
 
 """
     msle(ŷ, y; agg = mean, ϵ = eps(ŷ))
@@ -60,8 +66,10 @@ julia> Flux.msle(Float32[0.9, 1.8, 2.7], 1:3)
 0.011100831f0
 ```
 """
-msle(ŷ, y; agg = mean, ϵ = epseltype(ŷ)) =
+function msle(ŷ, y; agg = mean, ϵ = epseltype(ŷ))
+    _check_sizes(ŷ, y)
     agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^2 )
+end
 
 """
     huber_loss(ŷ, y; δ = 1, agg = mean)
@@ -74,6 +82,7 @@ given the prediction `ŷ` and true values `y`.
                  | δ * (|ŷ - y| - 0.5 * δ), otherwise
 """
 function huber_loss(ŷ, y; agg = mean, δ = ofeltype(ŷ, 1))
+    _check_sizes(ŷ, y)
     abs_error = abs.(ŷ .- y)
     #TODO: remove dropgrad when Zygote can handle this function with CuArrays
     temp = Zygote.dropgrad(abs_error .< δ)
@@ -203,7 +212,8 @@ julia> Flux.crossentropy(y_model, y_smooth)
 ```
 """
 function crossentropy(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ))
-    agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims = dims))
+    _check_sizes(ŷ, y)
+    agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims = dims))
 end
 
 """
@@ -241,7 +251,8 @@ julia> Flux.crossentropy(softmax(y_model), y_label)
 ```
 """
 function logitcrossentropy(ŷ, y; dims = 1, agg = mean)
-    agg(.-sum(y .* logsoftmax(ŷ; dims = dims); dims = dims))
+    _check_sizes(ŷ, y)
+    agg(.-sum(y .* logsoftmax(ŷ; dims = dims); dims = dims))
 end
 
 """
@@ -289,7 +300,8 @@ julia> Flux.crossentropy(y_prob, y_hot)
 ```
 """
 function binarycrossentropy(ŷ, y; agg = mean, ϵ = epseltype(ŷ))
-    agg(@.(-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ)))
+    _check_sizes(ŷ, y)
+    agg(@.(-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ)))
 end
 
 """
@@ -318,7 +330,8 @@ julia> Flux.binarycrossentropy(sigmoid.(y_model), y_bin)
 ```
 """
 function logitbinarycrossentropy(ŷ, y; agg = mean)
-    agg(@.((1 - y) * ŷ - logσ(ŷ)))
+    _check_sizes(ŷ, y)
+    agg(@.((1 - y) * ŷ - logσ(ŷ)))
 end
 
 """
@@ -357,6 +370,7 @@ Inf
 ```
 """
 function kldivergence(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ))
+    _check_sizes(ŷ, y)
     entropy = agg(sum(xlogx.(y), dims = dims))
     cross_entropy = crossentropy(ŷ, y; dims = dims, agg = agg, ϵ = ϵ)
     return entropy + cross_entropy
@@ -370,7 +384,10 @@ end
 
 [More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
 """
-poisson_loss(ŷ, y; agg = mean) = agg(ŷ .- xlogy.(y, ŷ))
+function poisson_loss(ŷ, y; agg = mean)
+    _check_sizes(ŷ, y)
+    agg(ŷ .- xlogy.(y, ŷ))
+end
 
 """
     hinge_loss(ŷ, y; agg = mean)
@@ -381,8 +398,10 @@ prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as
 
 See also: [`squared_hinge_loss`](@ref)
 """
-hinge_loss(ŷ, y; agg = mean) =
+function hinge_loss(ŷ, y; agg = mean)
+    _check_sizes(ŷ, y)
     agg(max.(0, 1 .- ŷ .* y))
+end
 
 """
     squared_hinge_loss(ŷ, y)
@@ -392,8 +411,10 @@ Return the squared hinge_loss loss given the prediction `ŷ` and true labels `y
 
 See also: [`hinge_loss`](@ref)
 """
-squared_hinge_loss(ŷ, y; agg = mean) =
+function squared_hinge_loss(ŷ, y; agg = mean)
+    _check_sizes(ŷ, y)
     agg((max.(0, 1 .- ŷ .* y)) .^ 2)
+end
 
 """
     dice_coeff_loss(ŷ, y; smooth = 1)
@@ -405,8 +426,10 @@ Similar to the F1_score. Calculated as:
 
     1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)
 """
-dice_coeff_loss(ŷ, y; smooth = ofeltype(ŷ, 1.0)) =
+function dice_coeff_loss(ŷ, y; smooth = ofeltype(ŷ, 1.0))
+    _check_sizes(ŷ, y)
     1 - (2 * sum(y .* ŷ) + smooth) / (sum(y .^ 2) + sum(ŷ .^ 2) + smooth) #TODO agg
+end
 
 """
     tversky_loss(ŷ, y; β = 0.7)
@@ -418,6 +441,7 @@ Calculated as:
     1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
 """
 function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
+    _check_sizes(ŷ, y)
     #TODO add agg
     num = sum(y .* ŷ) + 1
     den = sum(y .* ŷ + β * (1 .- y) .* ŷ + (1 - β) * y .* (1 .- ŷ)) + 1
@@ -454,6 +478,7 @@ See also: [`Losses.focal_loss`](@ref) for multi-class setting
 
 """
 function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
+    _check_sizes(ŷ, y)
     ŷ = ŷ .+ ϵ
     p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ)
     ce = -log.(p_t)
@@ -497,9 +522,11 @@ See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels
 
 """
 function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
+    _check_sizes(ŷ, y)
     ŷ = ŷ .+ ϵ
     agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
 end
+
 ```@meta
 DocTestFilters = nothing
 ```
````
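Every hunk above applies the same mechanical change: a one-line loss definition becomes a `function` block whose first statement is `_check_sizes(ŷ, y)`. Since that check is declared `Zygote.@nograd` in `src/losses/utils.jl` below, it should leave gradients untouched, consistent with the unchanged `@btime` figures in the commit message. A quick sanity check along those lines (a sketch, not part of this commit):

```julia
using Flux

ŷ, y = rand(10, 100), rand(10, 100)

# Matching sizes: no warning, and the size check contributes nothing to the pullback.
# The analytic gradient of mean((ŷ .- y).^2) with respect to ŷ is 2(ŷ - y)/length(ŷ).
g, = Flux.gradient(x -> Flux.mse(x, y), ŷ)
@assert g ≈ 2 .* (ŷ .- y) ./ length(ŷ)
```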

src/losses/utils.jl

Lines changed: 16 additions & 4 deletions
```diff
@@ -1,17 +1,17 @@
 """
-  xlogx(x)
+    xlogx(x)
 
-Return `x * log(x)` for `x ≥ 0`, handling `x = 0` by taking the downward limit.
+Return `x * log(x)` for `x ≥ 0`, handling `x == 0` by taking the limit from above, to get zero.
 """
 function xlogx(x)
   result = x * log(x)
   ifelse(iszero(x), zero(result), result)
 end
 
 """
-  xlogy(x, y)
+    xlogy(x, y)
 
-Return `x * log(y)` for `y > 0` with correct limit at `x = 0`.
+Return `x * log(y)` for `y > 0`, and zero when `x == 0`.
 """
 function xlogy(x, y)
   result = x * log(y)
@@ -22,3 +22,15 @@ end
   res = xlogy.(x, y)
   res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
 end
+
+# This can be made an error in Flux v0.13, for now just a warning
+function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
+    for d in 1:max(ndims(ŷ), ndims(y))
+        if size(ŷ,d) != size(y,d)
+            @warn "Size mismatch in loss function! In future this will be an error. In Flux <= 0.12 broadcasting accepts this, but may not give sensible results" summary(ŷ) summary(y) maxlog=3 _id=hash(size(y))
+        end
+    end
+end
+_check_sizes(ŷ, y) = nothing  # pass-through, for constant label e.g. y = 1
+
+Zygote.@nograd _check_sizes
```
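To make the new behaviour concrete, a sketch of both methods in action (not part of this commit). The warning is rate-limited to three occurrences per distinct label size by `maxlog=3` together with `_id=hash(size(y))`:

```julia
using Flux

# 2-element Vector vs 1×3 Matrix: the sizes disagree in dimensions 1 and 2,
# so _check_sizes logs the warning, but broadcasting still yields 0.5.
Flux.mse([1, 0], [1 0 0])

# A plain number as the label dispatches to the fallback method
# _check_sizes(ŷ, y) = nothing, so no warning is issued.
Flux.mse([0.9, 1.1], 1.0)
```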

test/losses.jl

Lines changed: 1 addition & 0 deletions
```diff
@@ -104,6 +104,7 @@ yls = y.*(1-2sf).+sf
   @test binarycrossentropy.(σ.(logŷ), label_smoothing(y, 2sf; dims=0); ϵ=0) ≈ -yls.*log.(σ.(logŷ)) - (1 .- yls).*log.(1 .- σ.(logŷ))
   @test binarycrossentropy(σ.(logŷ), y; ϵ=0) ≈ mean(-y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ)))
   @test binarycrossentropy(σ.(logŷ), y) ≈ mean(-y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ))))
+  @test binarycrossentropy([0.1,0.2,0.9], 1) ≈ -mean(log, [0.1,0.2,0.9])  # constant label
 end
 
 @testset "logitbinarycrossentropy" begin
```
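The added test exercises both halves of the change: the scalar label `1` reaches the `_check_sizes(ŷ, y) = nothing` fallback, and the loss simplifies because the `(1 - y)` term vanishes. As a worked note (treating the default `ϵ` as negligible):

```julia
using Statistics: mean

ŷ = [0.1, 0.2, 0.9]
# With y = 1, each term -xlogy(1, ŷᵢ + ϵ) - xlogy(0, 1 - ŷᵢ + ϵ) reduces to -log(ŷᵢ + ϵ),
# so binarycrossentropy(ŷ, 1) is just the mean negative log, ≈ 1.339 for this ŷ.
-mean(log, ŷ)
```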
