@@ -511,31 +511,31 @@ end

Return the [Tversky loss](https://arxiv.org/abs/1706.05721).
Used with imbalanced data to give more weight to false negatives.
- Larger β weigh recall more than precision (by placing more emphasis on false negatives)
+ Larger β weighs recall more than precision (by placing more emphasis on false negatives).
Calculated as:

-     1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
+     1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + (1 - β)*(1 .- y) .* ŷ + β*y .* (1 .- ŷ)) + 1)

# Example
```jldoctest
- julia> ŷ = [1, 0, 1, 1, 0];
+ julia> y = [0, 1, 0, 1, 1, 1];

- julia> y = [1, 0, 0, 1, 0]; # one false negative data point
+ julia> ŷ_fp = [1, 1, 1, 1, 1, 1]; # 2 false positives -> 2 wrong predictions

- julia> Flux.tversky_loss(ŷ, y)
- 0.18918918918918926
+ julia> ŷ_fnp = [1, 1, 0, 1, 1, 0]; # 1 false negative, 1 false positive -> 2 wrong predictions

- julia> y = [1, 1, 1, 1, 0]; # No false negatives, but a false positive
+ julia> Flux.tversky_loss(ŷ_fnp, y)
+ 0.19999999999999996

- julia> Flux.tversky_loss(ŷ, y) # loss is smaller as more weight given to the false negatives
- 0.06976744186046513
+ julia> Flux.tversky_loss(ŷ_fp, y) # should be smaller than tversky_loss(ŷ_fnp, y), as false negatives are given more weight
+ 0.1071428571428571
```
"""
function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
    _check_sizes(ŷ, y)
    # TODO add agg
    num = sum(y .* ŷ) + 1
-     den = sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1
+     den = sum(y .* ŷ + (1 - β)*(1 .- y) .* ŷ + β*y .* (1 .- ŷ)) + 1
    1 - num/den
end

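For context on the fix: with the corrected weights the denominator decomposes as TP + (1 - β)·FP + β·FN, so the default β = 0.7 penalises false negatives more heavily than false positives, matching the docstring's claim. The snippet below is a standalone sanity check using a hypothetical `tversky` helper (a re-implementation of the fixed formula, not Flux's exported function) to reproduce the two doctest values.

```julia
# Standalone sketch of the corrected Tversky loss (hypothetical `tversky`, not Flux's API),
# mirroring the fixed `den` line: false positives weighted by (1 - β), false negatives by β.
tversky(ŷ, y; β = 0.7) =
    1 - (sum(y .* ŷ) + 1) /
        (sum(y .* ŷ + (1 - β)*(1 .- y) .* ŷ + β*y .* (1 .- ŷ)) + 1)

y     = [0, 1, 0, 1, 1, 1]
ŷ_fp  = [1, 1, 1, 1, 1, 1]  # 2 false positives, 0 false negatives
ŷ_fnp = [1, 1, 0, 1, 1, 0]  # 1 false positive, 1 false negative

tversky(ŷ_fp, y)   # ≈ 0.107  (TP = 4, FP = 2, FN = 0; FP down-weighted by 1 - β = 0.3)
tversky(ŷ_fnp, y)  # ≈ 0.2    (TP = 3, FP = 1, FN = 1; the FN is up-weighted by β = 0.7)
```

With the pre-fix weighting (β on the false-positive term), the all-positive prediction would have scored worse (≈ 0.219) than the one containing a false negative (0.2), the opposite of what the docstring promises.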