Skip to content

Commit d341500

Browse files
Merge #1489
1489: Implementation of Focal loss r=darsnack a=shikhargoswami Focal loss was introduced in the RetinaNet paper (https://arxiv.org/pdf/1708.02002.pdf). Focal loss is useful for classification when you have highly imbalanced classes. It down-weights well-classified examples and focuses on hard examples. The loss value is much higher for a sample which is misclassified by the classifier as compared to the loss value corresponding to a well-classified example. Used in single-shot object detection where the imbalance between the background class and other classes is extremely high. Here's its TensorFlow implementation (https://github.com/tensorflow/addons/blob/v0.12.0/tensorflow_addons/losses/focal_loss.py#L26-L81) ### PR Checklist - [x] Tests are added - [x] Entry in NEWS.md - [x] Documentation, if applicable - [ ] Final review from `@dhairyagandhi96` (for API changes). Co-authored-by: Shikhar Goswami <shikhargoswami2308@gmail.com> Co-authored-by: Shikhar Goswami <44720861+shikhargoswami@users.noreply.github.com>
2 parents 7e9a180 + 284425b commit d341500

File tree

6 files changed

+125
-3
lines changed

6 files changed

+125
-3
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## v0.12.0
44

5+
* Added [Focal Loss function](https://github.com/FluxML/Flux.jl/pull/1489) to Losses module
56
* The Dense layer now supports inputs with [multiple batch dimensions](https://github.com/FluxML/Flux.jl/pull/1405).
67
* Dense and Conv layers no longer perform [implicit type conversion](https://github.com/FluxML/Flux.jl/pull/1394).
78
* Excise datasets in favour of other providers in the julia ecosystem.

docs/src/models/losses.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,6 @@ Flux.Losses.hinge_loss
3939
Flux.Losses.squared_hinge_loss
4040
Flux.Losses.dice_coeff_loss
4141
Flux.Losses.tversky_loss
42+
Flux.Losses.binary_focal_loss
43+
Flux.Losses.focal_loss
4244
```

src/losses/Losses.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ export mse, mae, msle,
1818
dice_coeff_loss,
1919
poisson_loss,
2020
hinge_loss, squared_hinge_loss,
21-
ctc_loss
21+
ctc_loss,
22+
binary_focal_loss, focal_loss
2223

2324
include("utils.jl")
2425
include("functions.jl")

src/losses/functions.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,82 @@ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
429429
1 - num / den
430430
end
431431

432+
"""
    binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))

Return the [binary_focal_loss](https://arxiv.org/pdf/1708.02002.pdf),
computed as `agg((1 .- p_t) .^ γ .* -log.(p_t))` where
`p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ)` is the estimated probability of
the true class.
The input, `ŷ`, is expected to be normalized (i.e. [`softmax`](@ref) output).

For `γ == 0`, the loss is mathematically equivalent to
[`Losses.binarycrossentropy`](@ref).

# Example
```jldoctest
julia> y = [0  1  0
            1  0  1]
2×3 Array{Int64,2}:
 0  1  0
 1  0  1

julia> ŷ = [0.268941  0.5  0.268941
            0.731059  0.5  0.731059]
2×3 Array{Float64,2}:
 0.268941  0.5  0.268941
 0.731059  0.5  0.731059

julia> Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
true
```

See also: [`Losses.focal_loss`](@ref) for multi-class setting

"""
function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
    # Nudge predictions away from exactly 0/1 so the log below stays finite.
    ŷ = ŷ .+ ϵ
    # p_t: model's estimated probability of the true class for each element.
    p_t = y .* ŷ + (1 .- y) .* (1 .- ŷ)
    # Cross-entropy of the true class, down-weighted by (1 - p_t)^γ so
    # well-classified examples (p_t → 1) contribute little to the loss.
    ce = -log.(p_t)
    weight = (1 .- p_t) .^ γ
    loss = weight .* ce
    agg(loss)
end
469+
470+
"""
    focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))

Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf)
which can be used in classification tasks with highly imbalanced classes.
It down-weights well-classified examples and focuses on hard examples.
The input, `ŷ`, is expected to be normalized (i.e. [`softmax`](@ref) output).

The modulating factor, `γ`, controls the down-weighting strength.
For `γ == 0`, the loss is mathematically equivalent to
[`Losses.crossentropy`](@ref).

# Example
```jldoctest
julia> y = [1  0  0  0  1
            0  1  0  1  0
            0  0  1  0  0]
3×5 Array{Int64,2}:
 1  0  0  0  1
 0  1  0  1  0
 0  0  1  0  0

julia> ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)
3×5 Array{Float32,2}:
 0.0900306  0.0900306  0.0900306  0.0900306  0.0900306
 0.244728   0.244728   0.244728   0.244728   0.244728
 0.665241   0.665241   0.665241   0.665241   0.665241

julia> Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628
true
```

See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels

"""
function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
    # Nudge predictions away from zero so log(ŷ) stays finite.
    ŷ = ŷ .+ ϵ
    # Focal term -y (1 - ŷ)^γ log(ŷ), summed over the class dimension
    # (`dims`), then aggregated across samples with `agg`.
    agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
end
433508
```@meta
434509
DocTestFilters = nothing
435510
```

test/cuda/losses.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy
1+
using Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy, binary_focal_loss, focal_loss
22

33

44
@testset "Losses" begin
@@ -14,6 +14,17 @@ y = [1, 1, 0.]
1414
@test binarycrossentropy(σ.(x), y) binarycrossentropy(gpu(σ.(x)), gpu(y))
1515
@test logitbinarycrossentropy(x, y) logitbinarycrossentropy(gpu(x), gpu(y))
1616

17+
x = [0.268941 0.5 0.268941
18+
0.731059 0.5 0.731059]
19+
y = [0 1 0
20+
1 0 1]
21+
@test binary_focal_loss(x, y) binary_focal_loss(gpu(x), gpu(y))
22+
23+
x = softmax(reshape(-7:7, 3, 5) .* 1f0)
24+
y = [1 0 0 0 1
25+
0 1 0 1 0
26+
0 0 1 0 0]
27+
@test focal_loss(x, y) focal_loss(gpu(x), gpu(y))
1728

1829
@testset "GPU grad tests" begin
1930
x = rand(Float32, 3,3)

test/losses.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
1313
Flux.Losses.tversky_loss,
1414
Flux.Losses.dice_coeff_loss,
1515
Flux.Losses.poisson_loss,
16-
Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss]
16+
Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss,
17+
Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss]
1718

1819

1920
@testset "xlogx & xlogy" begin
@@ -174,3 +175,34 @@ end
174175
end
175176
end
176177
end
178+
179+
@testset "binary_focal_loss" begin
    # Inputs match the binary_focal_loss docstring example.
    y = [0  1  0
         1  0  1]
    ŷ = [0.268941  0.5  0.268941
         0.731059  0.5  0.731059]

    y1 = [1  0
          0  1]
    ŷ1 = [0.6  0.3
          0.4  0.7]
    @test Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
    @test Flux.binary_focal_loss(ŷ1, y1) ≈ 0.05691642237852222
    # With γ = 0 the modulating factor is 1, so the loss reduces to
    # plain binary cross-entropy.
    @test Flux.binary_focal_loss(ŷ, y; γ=0.0) ≈ Flux.binarycrossentropy(ŷ, y)
end
193+
194+
@testset "focal_loss" begin
    # Inputs match the focal_loss docstring example (one-hot labels,
    # softmax-normalized predictions).
    y = [1  0  0  0  1
         0  1  0  1  0
         0  0  1  0  0]
    ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)
    y1 = [1  0
          0  0
          0  1]
    ŷ1 = [0.4  0.2
          0.5  0.5
          0.1  0.3]
    @test Flux.focal_loss(ŷ, y) ≈ 1.1277571935622628
    @test Flux.focal_loss(ŷ1, y1) ≈ 0.45990566879720157
    # With γ = 0 the modulating factor is 1, so the loss reduces to
    # plain cross-entropy.
    @test Flux.focal_loss(ŷ, y; γ=0.0) ≈ Flux.crossentropy(ŷ, y)
end

0 commit comments

Comments
 (0)