Skip to content

Commit dcbdf07

Browse files
committed
Fix the implementation and docstring of tversky_loss
1 parent 0c7bdb9 commit dcbdf07

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/losses/functions.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -511,31 +511,31 @@ end
511511
512512
Return the [Tversky loss](https://arxiv.org/abs/1706.05721).
513513
Used with imbalanced data to give more weight to false negatives.
514-
Larger β weigh recall more than precision (by placing more emphasis on false negatives)
514+
Larger β weigh recall more than precision (by placing more emphasis on false negatives).
515515
Calculated as:
516516
517-
1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
517+
1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + (1 - β)*(1 .- y) .* ŷ + β*y .* (1 .- ŷ)) + 1)
518518
519519
# Example
520520
```jldoctest
521-
julia> = [1, 0, 1, 1, 0];
521+
julia> y = [0, 1, 0, 1, 1, 1];
522522
523-
julia> y = [1, 0, 0, 1, 0]; # one false negative data point
523+
julia> ŷ_fp = [1, 1, 1, 1, 1, 1]; # 2 false positive -> 2 wrong predictions
524524
525-
julia> Flux.tversky_loss(ŷ, y)
526-
0.18918918918918926
525+
julia> ŷ_fnp = [1, 1, 0, 1, 1, 0]; # 1 false negative, 1 false positive -> 2 wrong predictions
527526
528-
julia> y = [1, 1, 1, 1, 0]; # No false negatives, but a false positive
527+
julia> Flux.tversky_loss(ŷ_fnp, y)
528+
0.19999999999999996
529529
530-
julia> Flux.tversky_loss(, y) # loss is smaller as more weight given to the false negatives
531-
0.06976744186046513
530+
julia> Flux.tversky_loss(ŷ_fp, y) # should be smaller than tversky_loss(ŷ_fnp, y), as FN is given more weight
531+
0.1071428571428571
532532
```
533533
"""
534534
function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
535535
_check_sizes(ŷ, y)
536536
#TODO add agg
537537
num = sum(y .* ŷ) + 1
538-
den = sum(y .*+ β * (1 .- y) .*+ (1 - β) * y .* (1 .- ŷ)) + 1
538+
den = sum(y .*+ (1 - β) * (1 .- y) .*+ β * y .* (1 .- ŷ)) + 1
539539
1 - num / den
540540
end
541541

0 commit comments

Comments
 (0)