Skip to content

Commit 63e4d98

Browse files
Applied the suggestions
Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
1 parent 1987693 commit 63e4d98

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

src/losses/functions.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ true
458458
See also: [`Losses.focal_loss`](@ref) for multi-class setting
459459
460460
"""
461-
function binary_focal_loss(ŷ, y; agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ))
461+
function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
462462
=.+ ϵ
463463
p_t = y .*+ (1 .- y) .* (1 .- ŷ)
464464
ce = -log.(p_t)
@@ -501,11 +501,10 @@ true
501501
See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels
502502
503503
"""
504-
function focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ))
504+
function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
505505
=.+ ϵ
506506
agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
507507
end
508508
```@meta
509509
DocTestFilters = nothing
510510
```
511-

0 commit comments

Comments
 (0)