Skip to content

Commit 5ce5481

Browse files
shikhargoswamidarsnack
authored andcommitted
Done!
1 parent 63747a6 commit 5ce5481

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/losses/functions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ true
458458
See also: [`Losses.focal_loss`](@ref) for multi-class setting
459459
460460
"""
461-
function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
461+
function binary_focal_loss(ŷ, y; agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ))
462462
=.+ ϵ
463463
p_t = y .*+ (1 .- y) .* (1 .- ŷ)
464464
ce = -log.(p_t)
@@ -501,7 +501,7 @@ true
501501
See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels
502502
503503
"""
504-
function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
504+
function focal_loss(ŷ, y; dims=1, agg=mean, γ=ofeltype(ŷ, 2.0), ϵ=epseltype(ŷ))
505505
=.+ ϵ
506506
agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
507507
end

0 commit comments

Comments
 (0)