
Commit 10be1fb

bors[bot] and Michael Abbott authored
Merge #1463
1463: Improve docs for `crossentropy` & friends r=DhairyaLGandhi a=mcabbott

Also restores export of binarycrossentropy, logitbinarycrossentropy from Losses module, which was commented out for some long-lost deprecation reason, I think?

Co-authored-by: Michael Abbott <me@escbook>
2 parents 9f09b78 + 741948a commit 10be1fb

2 files changed: +208 -39 lines


src/losses/Losses.jl

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ import Base.Broadcast: broadcasted
 export mse, mae, msle,
        label_smoothing,
        crossentropy, logitcrossentropy,
-       # binarycrossentropy, logitbinarycrossentropy # export only after end deprecation
+       binarycrossentropy, logitbinarycrossentropy,
        kldivergence,
        huber_loss,
        tversky_loss,
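
Since this hunk restores the exports, here is a minimal usage sketch (illustrative values only, not part of the diff; assumes Flux at this commit):

```julia
using Flux.Losses  # binarycrossentropy and logitbinarycrossentropy are exported again

y = [1, 0, 1]                        # binary targets
p = [0.9, 0.2, 0.7]                  # predicted probabilities, e.g. after sigmoid
logits = [2.2f0, -1.4f0, 0.8f0]      # raw model scores, before sigmoid

binarycrossentropy(p, y)             # callable without the Flux.Losses prefix
logitbinarycrossentropy(logits, y)   # numerically stabler variant, applied to raw logits
```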

src/losses/functions.jl

Lines changed: 207 additions & 38 deletions
@@ -1,9 +1,22 @@
+# In this file, doctests which differ in the printed Float32 values won't fail
+```@meta
+DocTestFilters = r"[0-9\.]+f0"
+```
+
 """
     mae(ŷ, y; agg=mean)

 Return the loss corresponding to mean absolute error:

     agg(abs.(ŷ .- y))
+
+# Example
+```jldoctest
+julia> y_model = [1.1, 1.9, 3.1];
+
+julia> Flux.mae(y_model, 1:3)
+0.10000000000000009
+```
 """
 mae(ŷ, y; agg=mean) = agg(abs.(ŷ .- y))
@@ -13,6 +26,18 @@ mae(ŷ, y; agg=mean) = agg(abs.(ŷ .- y))
 Return the loss corresponding to mean square error:

     agg((ŷ .- y).^2)
+
+See also: [`mae`](@ref), [`msle`](@ref), [`crossentropy`](@ref).
+
+# Example
+```jldoctest
+julia> y_model = [1.1, 1.9, 3.1];
+
+julia> y_true = 1:3;
+
+julia> Flux.mse(y_model, y_true)
+0.010000000000000018
+```
 """
 mse(ŷ, y; agg=mean) = agg((ŷ .- y).^2)
@@ -25,6 +50,15 @@ The loss corresponding to mean squared logarithmic errors, calculated as

 The `ϵ` term provides numerical stability.
 Penalizes an under-estimation more than an over-estimation.
+
+# Example
+```jldoctest
+julia> Flux.msle(Float32[1.1, 2.2, 3.3], 1:3)
+0.009084041f0
+
+julia> Flux.msle(Float32[0.9, 1.8, 2.7], 1:3)
+0.011100831f0
+```
 """
 msle(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))).^2)
@@ -68,16 +102,34 @@ value of α larger the smoothing of `y`.
 `dims` denotes the one-hot dimension, unless `dims=0` which denotes the application
 of label smoothing to binary distributions encoded in a single number.

-Usage example:
-
-    sf = 0.1
-    y = onehotbatch([1, 1, 1, 0, 0], 0:1)
-    y_smoothed = label_smoothing(ya, 2sf)
-    y_sim = y .* (1-2sf) .+ sf
-    y_dis = copy(y_sim)
-    y_dis[1,:], y_dis[2,:] = y_dis[2,:], y_dis[1,:]
-    @assert crossentropy(y_sim, y) < crossentropy(y_sim, y_smoothed)
-    @assert crossentropy(y_dis, y) > crossentropy(y_dis, y_smoothed)
+# Example
+```jldoctest
+julia> y = Flux.onehotbatch([1, 1, 1, 0, 1, 0], 0:1)
+2×6 Flux.OneHotArray{UInt32,2,1,2,Array{UInt32,1}}:
+ 0  0  0  1  0  1
+ 1  1  1  0  1  0
+
+julia> y_smoothed = Flux.label_smoothing(y, 0.2f0)
+2×6 Array{Float32,2}:
+ 0.1  0.1  0.1  0.9  0.1  0.9
+ 0.9  0.9  0.9  0.1  0.9  0.1
+
+julia> y_sim = softmax(y .* log(2f0))
+2×6 Array{Float32,2}:
+ 0.333333  0.333333  0.333333  0.666667  0.333333  0.666667
+ 0.666667  0.666667  0.666667  0.333333  0.666667  0.333333
+
+julia> y_dis = vcat(y_sim[2,:]', y_sim[1,:]')
+2×6 Array{Float32,2}:
+ 0.666667  0.666667  0.666667  0.333333  0.666667  0.333333
+ 0.333333  0.333333  0.333333  0.666667  0.333333  0.666667
+
+julia> Flux.crossentropy(y_sim, y) < Flux.crossentropy(y_sim, y_smoothed)
+true
+
+julia> Flux.crossentropy(y_dis, y) > Flux.crossentropy(y_dis, y_smoothed)
+true
+```
 """
 function label_smoothing(y::Union{AbstractArray,Number}, α::Number; dims::Int=1)
     if !(0 < α < 1)
@@ -99,22 +151,55 @@ end
 Return the cross entropy between the given probability distributions;
 calculated as

-    agg(-sum(y .* log.(ŷ .+ ϵ); dims=dims))
+    agg(-sum(y .* log.(ŷ .+ ϵ); dims))

 Cross entropy is typically used as a loss in multi-class classification,
 in which case the labels `y` are given in a one-hot format.
 `dims` specifies the dimension (or the dimensions) containing the class probabilities.
 The prediction `ŷ` is supposed to sum to one across `dims`,
 as would be the case with the output of a [`softmax`](@ref) operation.

+For numerical stability, it is recommended to use [`logitcrossentropy`](@ref)
+rather than `softmax` followed by `crossentropy`.
+
 Use [`label_smoothing`](@ref) to smooth the true labels as preprocessing before
 computing the loss.

-Use of [`logitcrossentropy`](@ref) is recomended over `crossentropy` for
-numerical stability.
-
-See also: [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref),
-[`label_smoothing`](@ref)
+See also: [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref).
+
+# Example
+```jldoctest
+julia> y_label = Flux.onehotbatch([0, 1, 2, 1, 0], 0:2)
+3×5 Flux.OneHotArray{UInt32,3,1,2,Array{UInt32,1}}:
+ 1  0  0  0  1
+ 0  1  0  1  0
+ 0  0  1  0  0
+
+julia> y_model = softmax(reshape(-7:7, 3, 5) .* 1f0)
+3×5 Array{Float32,2}:
+ 0.0900306  0.0900306  0.0900306  0.0900306  0.0900306
+ 0.244728   0.244728   0.244728   0.244728   0.244728
+ 0.665241   0.665241   0.665241   0.665241   0.665241
+
+julia> sum(y_model; dims=1)
+1×5 Array{Float32,2}:
+ 1.0  1.0  1.0  1.0  1.0
+
+julia> Flux.crossentropy(y_model, y_label)
+1.6076053f0
+
+julia> 5 * ans ≈ Flux.crossentropy(y_model, y_label; agg=sum)
+true
+
+julia> y_smooth = Flux.label_smoothing(y_label, 0.15f0)
+3×5 Array{Float32,2}:
+ 0.9   0.05  0.05  0.05  0.9
+ 0.05  0.9   0.05  0.9   0.05
+ 0.05  0.05  0.9   0.05  0.05
+
+julia> Flux.crossentropy(y_model, y_smooth)
+1.5776052f0
+```
 """
 function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
     agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims=dims))
@@ -123,19 +208,36 @@ end
 """
     logitcrossentropy(ŷ, y; dims=1, agg=mean)

-Return the crossentropy computed after a [`logsoftmax`](@ref) operation;
-calculated as
+Return the cross entropy calculated by

-    agg(.-sum(y .* logsoftmax(ŷ; dims=dims); dims=dims))
+    agg(-sum(y .* logsoftmax(ŷ; dims); dims))

-Use [`label_smoothing`](@ref) to smooth the true labels as preprocessing before
-computing the loss.
+This is mathematically equivalent to `crossentropy(softmax(ŷ), y)`,
+but is more numerically stable than using functions [`crossentropy`](@ref)
+and [`softmax`](@ref) separately.
+
+See also: [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref), [`label_smoothing`](@ref).

-`logitcrossentropy(ŷ, y)` is mathematically equivalent to
-[`crossentropy(softmax(ŷ), y)`](@ref) but it is more numerically stable.
+# Example
+```jldoctest
+julia> y_label = Flux.onehotbatch(collect("abcabaa"), 'a':'c')
+3×7 Flux.OneHotArray{UInt32,3,1,2,Array{UInt32,1}}:
+ 1  0  0  1  0  1  1
+ 0  1  0  0  1  0  0
+ 0  0  1  0  0  0  0

+julia> y_model = reshape(vcat(-9:0, 0:9, 7.5f0), 3, 7)
+3×7 Array{Float32,2}:
+ -9.0  -6.0  -3.0  0.0  2.0  5.0  8.0
+ -8.0  -5.0  -2.0  0.0  3.0  6.0  9.0
+ -7.0  -4.0  -1.0  1.0  4.0  7.0  7.5

-See also: [`crossentropy`](@ref), [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref), [`label_smoothing`](@ref)
+julia> Flux.logitcrossentropy(y_model, y_label)
+1.5791205f0
+
+julia> Flux.crossentropy(softmax(y_model), y_label)
+1.5791197f0
+```
 """
 function logitcrossentropy(ŷ, y; dims=1, agg=mean)
     agg(.-sum(y .* logsoftmax(ŷ; dims=dims); dims=dims))
@@ -148,17 +250,42 @@ Return the binary cross-entropy loss, computed as

     agg(@.(-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)))

-The `ϵ` term provides numerical stability.
-
-Typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
+Where typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
+The `ϵ` term is included to avoid infinity. Using [`logitbinarycrossentropy`](@ref) is recommended
+over `binarycrossentropy` for numerical stability.

 Use [`label_smoothing`](@ref) to smooth the `y` value as preprocessing before
 computing the loss.

-Use of `logitbinarycrossentropy` is recomended over `binarycrossentropy` for numerical stability.
+See also: [`crossentropy`](@ref), [`logitcrossentropy`](@ref).
+
+# Examples
+```jldoctest
+julia> y_bin = Bool[1,0,1]
+3-element Array{Bool,1}:
+ 1
+ 0
+ 1
+
+julia> y_prob = softmax(reshape(vcat(1:3, 3:5), 2, 3) .* 1f0)
+2×3 Array{Float32,2}:
+ 0.268941  0.5  0.268941
+ 0.731059  0.5  0.731059
+
+julia> Flux.binarycrossentropy(y_prob[2,:], y_bin)
+0.43989f0
+
+julia> all(p -> 0<p<1, y_prob[2,:]) # else DomainError
+true

-See also: [`crossentropy`](@ref), [`logitcrossentropy`](@ref), [`logitbinarycrossentropy`](@ref),
-[`label_smoothing`](@ref)
+julia> y_hot = Flux.onehotbatch(y_bin, 0:1)
+2×3 Flux.OneHotArray{UInt32,2,1,2,Array{UInt32,1}}:
+ 0  1  0
+ 1  0  1
+
+julia> Flux.crossentropy(y_prob, y_hot)
+0.43989f0
+```
 """
 function binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ))
     agg(@.(-xlogy(y, ŷ+ϵ) - xlogy(1-y, 1-ŷ+ϵ)))
@@ -172,10 +299,24 @@ end

 Mathematically equivalent to
 [`binarycrossentropy(σ(ŷ), y)`](@ref) but is more numerically stable.
-Use [`label_smoothing`](@ref) to smooth the `y` value as preprocessing before
-computing the loss.
+See also: [`crossentropy`](@ref), [`logitcrossentropy`](@ref).
+
+# Examples
+```jldoctest
+julia> y_bin = Bool[1,0,1];
+
+julia> y_model = Float32[2, -1, pi]
+3-element Array{Float32,1}:
+  2.0
+ -1.0
+  3.1415927

-See also: [`crossentropy`](@ref), [`logitcrossentropy`](@ref), [`binarycrossentropy`](@ref), [`label_smoothing`](@ref)
+julia> Flux.logitbinarycrossentropy(y_model, y_bin)
+0.160832f0
+
+julia> Flux.binarycrossentropy(sigmoid.(y_model), y_bin)
+0.16083185f0
+```
 """
 function logitbinarycrossentropy(ŷ, y; agg=mean)
     agg(@.((1-y)*ŷ - logσ(ŷ)))
@@ -185,16 +326,39 @@


 """
-    kldivergence(ŷ, y; agg=mean)
+    kldivergence(ŷ, y; agg=mean, ϵ=eps(ŷ))

 Return the
 [Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
 between the given probability distributions.

-KL divergence is a measure of how much one probability distribution is different
-from the other.
-It is always non-negative and zero only when both the distributions are equal
-everywhere.
+The KL divergence is a measure of how much one probability distribution is different
+from the other. It is always non-negative, and zero only when both the distributions are equal.
+
+# Example
+```jldoctest
+julia> p1 = [1 0; 0 1]
+2×2 Array{Int64,2}:
+ 1  0
+ 0  1
+
+julia> p2 = fill(0.5, 2, 2)
+2×2 Array{Float64,2}:
+ 0.5  0.5
+ 0.5  0.5
+
+julia> Flux.kldivergence(p2, p1) ≈ log(2)
+true
+
+julia> Flux.kldivergence(p2, p1; agg=sum) ≈ 2log(2)
+true
+
+julia> Flux.kldivergence(p2, p2; ϵ=0) # about -2e-16 with the regulator
+0.0
+
+julia> Flux.kldivergence(p1, p2; ϵ=0) # about 17.3 with the regulator
+Inf
+```
 """
 function kldivergence(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
     entropy = agg(sum(xlogx.(y), dims=dims))
@@ -260,3 +424,8 @@ function tversky_loss(ŷ, y; β=ofeltype(ŷ, 0.7))
     den = sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1
     1 - num / den
 end
+
+
+```@meta
+DocTestFilters = nothing
+```
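
The `@meta` blocks added at the top and bottom of this file switch a Documenter doctest filter on and off. A minimal sketch of running these doctests locally with Documenter.jl (the exact invocation below is an assumption, not part of this diff):

```julia
# Sketch only: run Flux's doctests with Documenter.jl.
# The filter r"[0-9\.]+f0" set in the @meta block strips printed Float32 values
# from both expected and actual output before comparison, so small differences
# such as 0.160832f0 vs 0.16083185f0 do not fail the doctest.
using Documenter, Flux

DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
doctest(Flux)
```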
