Skip to content

Commit 9e5851d

Browse files
Avoid broadcast-related type instabilities with huber_loss (#2306)
* Avoid broadcast-related type instabilities with huber_loss * Add test case * Fix function name * Fix test
1 parent bf9da7f commit 9e5851d

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

src/losses/functions.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
7272
agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^2 )
7373
end
7474

75+
function _huber_metric(abs_error, δ)
76+
#TODO: remove ignore_derivatives when Zygote can handle this function with CuArrays
77+
temp = Zygote.ignore_derivatives(abs_error .< δ)
78+
x = ofeltype(abs_error, 0.5)
79+
((abs_error * abs_error) * temp) * x + δ * (abs_error - x * δ) * (1 - temp)
80+
end
81+
7582
"""
7683
huber_loss(ŷ, y; delta = 1, agg = mean)
7784
@@ -94,17 +101,14 @@ julia> Flux.huber_loss(ŷ, 1:3, delta=0.05) # changes behaviour as |ŷ - y| >
94101
0.003750000000000005
95102
```
96103
"""
97-
function huber_loss(ŷ, y; agg = mean, delta::Real = 1, δ = nothing)
98-
delta_tmp = _greek_ascii_depwarn=> delta, :huber_loss, "δ" => "delta")
99-
δ = ofeltype(ŷ, delta_tmp)
100-
_check_sizes(ŷ, y)
101-
abs_error = abs.(ŷ .- y)
102-
#TODO: remove ignore_derivatives when Zygote can handle this function with CuArrays
103-
temp = Zygote.ignore_derivatives(abs_error .< δ)
104-
x = ofeltype(ŷ, 0.5)
105-
agg(((abs_error .^ 2) .* temp) .* x .+ δ * (abs_error .- x * δ) .* (1 .- temp))
106-
end
104+
function huber_loss(ŷ, y; agg = mean, delta::Real = 1, δ = nothing)
105+
delta_tmp = _greek_ascii_depwarn=> delta, :huber_loss, "δ" => "delta")
106+
δ = ofeltype(ŷ, delta_tmp)
107+
_check_sizes(ŷ, y)
108+
abs_error = abs.(ŷ .- y)
107109

110+
agg(_huber_metric.(abs_error, δ))
111+
end
108112
"""
109113
label_smoothing(y::Union{Number, AbstractArray}, α; dims::Int=1)
110114

test/ext_metal/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,15 @@ include("test_utils.jl")
1111
@testset "Basic" begin
1212
include("basic.jl")
1313
end
14+
15+
@testset "Huber Loss test" begin
16+
X = Flux.gpu(Float32[0,1])
17+
Y = Flux.gpu(Float32[1,0])
18+
19+
grad = Flux.gradient(X, Y) do a,b
20+
Flux.Losses.huber_loss(a,b)
21+
end
22+
23+
@test Flux.cpu(grad[1]) == [-0.5, 0.5]
24+
@test Flux.cpu(grad[2]) == [0.5, -0.5]
25+
end

0 commit comments

Comments
 (0)