@@ -72,6 +72,13 @@ function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing)
72
72
agg ((log .((ŷ .+ ϵ) ./ (y .+ ϵ))) .^ 2 )
73
73
end
74
74
75
+ function _huber_metric (abs_error, δ)
76
+ # TODO : remove ignore_derivatives when Zygote can handle this function with CuArrays
77
+ temp = Zygote. ignore_derivatives (abs_error .< δ)
78
+ x = ofeltype (abs_error, 0.5 )
79
+ ((abs_error * abs_error) * temp) * x + δ * (abs_error - x * δ) * (1 - temp)
80
+ end
81
+
75
82
"""
76
83
huber_loss(ŷ, y; delta = 1, agg = mean)
77
84
@@ -94,17 +101,14 @@ julia> Flux.huber_loss(ŷ, 1:3, delta=0.05) # changes behaviour as |ŷ - y| >
94
101
0.003750000000000005
95
102
```
96
103
"""
97
- function huber_loss (ŷ, y; agg = mean, delta:: Real = 1 , δ = nothing )
98
- delta_tmp = _greek_ascii_depwarn (δ => delta, :huber_loss , " δ" => " delta" )
99
- δ = ofeltype (ŷ, delta_tmp)
100
- _check_sizes (ŷ, y)
101
- abs_error = abs .(ŷ .- y)
102
- # TODO : remove ignore_derivatives when Zygote can handle this function with CuArrays
103
- temp = Zygote. ignore_derivatives (abs_error .< δ)
104
- x = ofeltype (ŷ, 0.5 )
105
- agg (((abs_error .^ 2 ) .* temp) .* x .+ δ * (abs_error .- x * δ) .* (1 .- temp))
106
- end
104
+ function huber_loss (ŷ, y; agg = mean, delta:: Real = 1 , δ = nothing )
105
+ delta_tmp = _greek_ascii_depwarn (δ => delta, :huber_loss , " δ" => " delta" )
106
+ δ = ofeltype (ŷ, delta_tmp)
107
+ _check_sizes (ŷ, y)
108
+ abs_error = abs .(ŷ .- y)
107
109
110
+ agg (_huber_metric .(abs_error, δ))
111
+ end
108
112
"""
109
113
label_smoothing(y::Union{Number, AbstractArray}, α; dims::Int=1)
110
114
0 commit comments