
Commit 544c547

cleanup

Author: Dhairya Gandhi
1 parent 02ea511

File tree: 1 file changed, +47 −43 lines


src/losses/functions.jl

Lines changed: 47 additions & 43 deletions
@@ -4,7 +4,7 @@ DocTestFilters = r"[0-9\.]+f0"
```

"""
-    mae(ŷ, y; agg=mean)
+    mae(ŷ, y; agg = mean)

Return the loss corresponding to mean absolute error:

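The formula this docstring describes is `agg(abs.(ŷ .- y))`, as the definition in the next hunk confirms. A minimal standalone sketch (plain Julia plus `Statistics.mean`, not the Flux export):

```julia
using Statistics: mean

# Mean absolute error as described above: aggregate the elementwise |ŷ - y|.
mae_sketch(ŷ, y; agg = mean) = agg(abs.(ŷ .- y))

mae_sketch([0.9, 1.8, 2.7], 1:3)  # mean of (0.1, 0.2, 0.3) ≈ 0.2
```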
@@ -18,14 +18,14 @@ julia> Flux.mae(y_model, 1:3)
0.10000000000000009
```
"""
-mae(ŷ, y; agg=mean) = agg(abs.(ŷ .- y))
+mae(ŷ, y; agg = mean) = agg(abs.(ŷ .- y))

"""
-    mse(ŷ, y; agg=mean)
+    mse(ŷ, y; agg = mean)

Return the loss corresponding to mean square error:

-    agg((ŷ .- y).^2)
+    agg((ŷ .- y) .^ 2)

See also: [`mae`](@ref), [`msle`](@ref), [`crossentropy`](@ref).
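Throughout this file `agg` is just a reduction applied to the elementwise losses. A quick illustration of swapping it, using a hypothetical `y_model` (its definition is not shown in this hunk; `[1.1, 1.9, 3.1]` is a guess consistent with the doctest value):

```julia
using Statistics: mean

ŷ, y = [1.1, 1.9, 3.1], 1:3   # assumed y_model; not shown in the diff
mean(abs.(ŷ .- y))            # ≈ 0.1, the default agg = mean
sum(abs.(ŷ .- y))             # ≈ 0.3, total error instead of average
```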
@@ -39,14 +39,14 @@ julia> Flux.mse(y_model, y_true)
0.010000000000000018
```
"""
-mse(ŷ, y; agg=mean) = agg((ŷ .- y).^2)
+mse(ŷ, y; agg = mean) = agg((ŷ .- y) .^ 2)

"""
-    msle(ŷ, y; agg=mean, ϵ=eps(ŷ))
+    msle(ŷ, y; agg = mean, ϵ = eps(ŷ))

The loss corresponding to mean squared logarithmic errors, calculated as

-    agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)).^2)
+    agg((log.(ŷ .+ ϵ) .- log.(y .+ ϵ)) .^ 2)

The `ϵ` term provides numerical stability.
Penalizes an under-estimation more than an over-estimation.
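The asymmetry claim is easy to see numerically: the log compresses large values, so predicting low costs more than predicting high by the same factor. A sketch (ϵ omitted since the inputs are safely positive):

```julia
using Statistics: mean

# msle per the formula above, without the stabilising ϵ term.
msle_sketch(ŷ, y; agg = mean) = agg((log.(ŷ) .- log.(y)) .^ 2)

msle_sketch([0.5], [1.0])  # under-estimate: ≈ 0.480
msle_sketch([1.5], [1.0])  # over-estimate:  ≈ 0.164
```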
@@ -60,10 +60,11 @@ julia> Flux.msle(Float32[0.9, 1.8, 2.7], 1:3)
0.011100831f0
```
"""
-msle(ŷ, y; agg=mean, ϵ=epseltype(ŷ)) = agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))).^2)
+msle(ŷ, y; agg = mean, ϵ = epseltype(ŷ)) =
+    agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^ 2)

"""
-    huber_loss(ŷ, y; δ=1, agg=mean)
+    huber_loss(ŷ, y; δ = 1, agg = mean)

Return the mean of the [Huber loss](https://en.wikipedia.org/wiki/Huber_loss)
given the prediction `ŷ` and true values `y`.
@@ -72,12 +72,12 @@ given the prediction `ŷ` and true values `y`.
Huber loss = |
             | δ * (|ŷ - y| - 0.5 * δ), otherwise
"""
-function huber_loss(ŷ, y; agg=mean, δ=ofeltype(ŷ, 1))
+function huber_loss(ŷ, y; agg = mean, δ = ofeltype(ŷ, 1))
    abs_error = abs.(ŷ .- y)
    #TODO: remove dropgrad when Zygote can handle this function with CuArrays
    temp = Zygote.dropgrad(abs_error .< δ)
    x = ofeltype(ŷ, 0.5)
-    agg(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp))
+    agg(((abs_error .^ 2) .* temp) .* x .+ δ * (abs_error .- x * δ) .* (1 .- temp))
end

"""
@@ -131,7 +132,7 @@ julia> Flux.crossentropy(y_dis, y) > Flux.crossentropy(y_dis, y_smoothed)
true
```
"""
-function label_smoothing(y::Union{AbstractArray,Number}, α::Number; dims::Int=1)
+function label_smoothing(y::Union{AbstractArray,Number}, α::Number; dims::Int = 1)
    if !(0 < α < 1)
        throw(ArgumentError("α must be between 0 and 1"))
    end
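Only the α guard is visible in this hunk. Assuming the usual smoothing rule y·(1 − α) + α/K (the body is not shown here), the effect on a one-hot column looks like:

```julia
α = 0.2
y = [1.0, 0.0, 0.0]                       # one-hot, K = 3 classes
y_smooth = y .* (1 - α) .+ α / length(y)  # assumed rule, not from this hunk
# → [0.8667, 0.0667, 0.0667]; entries still sum to 1
```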
@@ -146,7 +147,7 @@ function label_smoothing(y::Union{AbstractArray,Number}, α::Number; dims::Int=1
end

"""
-    crossentropy(ŷ, y; dims=1, ϵ=eps(ŷ), agg=mean)
+    crossentropy(ŷ, y; dims = 1, ϵ = eps(ŷ), agg = mean)

Return the cross entropy between the given probability distributions;
calculated as
@@ -201,12 +202,12 @@ julia> Flux.crossentropy(y_model, y_smooth)
1.5776052f0
```
"""
-function crossentropy(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
-    agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims=dims))
+function crossentropy(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ))
+    agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims = dims))
end

"""
-    logitcrossentropy(ŷ, y; dims=1, agg=mean)
+    logitcrossentropy(ŷ, y; dims = 1, agg = mean)

Return the cross entropy calculated by

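Since `xlogy(y, ŷ)` is y·log(ŷ) with the y = 0 case defined as 0, the sum in `crossentropy` reduces to the familiar −Σ y log ŷ per column. For a single one-hot column:

```julia
ŷ = [0.1, 0.2, 0.7]
y = [0.0, 0.0, 1.0]
-sum(y .* log.(ŷ))  # ≈ 0.357 = -log(0.7): only the true class contributes
```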
@@ -239,16 +240,16 @@ julia> Flux.crossentropy(softmax(y_model), y_label)
1.5791197f0
```
"""
-function logitcrossentropy(ŷ, y; dims=1, agg=mean)
-    agg(.-sum(y .* logsoftmax(ŷ; dims=dims); dims=dims))
+function logitcrossentropy(ŷ, y; dims = 1, agg = mean)
+    agg(.-sum(y .* logsoftmax(ŷ; dims = dims); dims = dims))
end

"""
-    binarycrossentropy(ŷ, y; agg=mean, ϵ=eps(ŷ))
+    binarycrossentropy(ŷ, y; agg = mean, ϵ = eps(ŷ))

Return the binary cross-entropy loss, computed as

-    agg(@.(-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)))
+    agg(@.(-y * log(ŷ + ϵ) - (1 - y) * log(1 - ŷ + ϵ)))

Where typically, the prediction `ŷ` is given by the output of a [`sigmoid`](@ref) activation.
The `ϵ` term is included to avoid infinity. Using [`logitbinarycrossentropy`](@ref) is recommended
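Why the ϵ term "is included to avoid infinity": a hard 0 prediction against a positive label makes the log diverge. A one-element sketch:

```julia
y, ŷ, ϵ = 1.0, 0.0, eps(Float64)
-y * log(ŷ)      # Inf — the unregularised loss blows up
-y * log(ŷ + ϵ)  # ≈ 36.0 — large but finite
```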
@@ -287,14 +288,14 @@ julia> Flux.crossentropy(y_prob, y_hot)
0.43989f0
```
"""
-function binarycrossentropy(ŷ, y; agg=mean, ϵ=epseltype(ŷ))
-    agg(@.(-xlogy(y, ŷ+ϵ) - xlogy(1-y, 1-ŷ+ϵ)))
+function binarycrossentropy(ŷ, y; agg = mean, ϵ = epseltype(ŷ))
+    agg(@.(-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ)))
end
# Re-definition to fix interaction with CuArrays.
-# CUDA.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
+# CUDA.@cufunc binarycrossentropy(ŷ, y; ϵ = eps(ŷ)) = -y * log(ŷ + ϵ) - (1 - y) * log(1 - ŷ + ϵ)

"""
-    logitbinarycrossentropy(ŷ, y; agg=mean)
+    logitbinarycrossentropy(ŷ, y; agg = mean)

Mathematically equivalent to
[`binarycrossentropy(σ(ŷ), y)`](@ref) but is more numerically stable.
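The `@.`-broadcast body relies on `xlogy`, whose definition lives elsewhere in Flux. A minimal sketch of its assumed contract:

```julia
# xlogy(x, y) = x * log(y), except the x == 0 case is defined to be zero,
# so 0 * log(0) does not poison the sum with NaN.
xlogy_sketch(x, y) = iszero(x) ? zero(float(x)) : x * log(y)

xlogy_sketch(0.0, 0.0)  # 0.0 rather than NaN
xlogy_sketch(1.0, 0.5)  # ≈ -0.693
```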
@@ -318,15 +319,15 @@ julia> Flux.binarycrossentropy(sigmoid.(y_model), y_bin)
0.16083185f0
```
"""
-function logitbinarycrossentropy(ŷ, y; agg=mean)
-    agg(@.((1-y)*ŷ - logσ(ŷ)))
+function logitbinarycrossentropy(ŷ, y; agg = mean)
+    agg(@.((1 - y) * ŷ - logσ(ŷ)))
end
# Re-definition to fix interaction with CuArrays.
-# CUDA.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y)*ŷ - logσ(ŷ)
+# CUDA.@cufunc logitbinarycrossentropy(ŷ, y) = (1 - y) * ŷ - logσ(ŷ)


"""
-    kldivergence(ŷ, y; agg=mean, ϵ=eps(ŷ))
+    kldivergence(ŷ, y; agg = mean, ϵ = eps(ŷ))

Return the
[Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence)
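The equivalence with `binarycrossentropy(σ(ŷ), y)` follows from log σ(x) = −log(1 + e⁻ˣ). A hedged standalone check, with σ and logσ defined inline rather than taken from Flux:

```julia
σ(x) = 1 / (1 + exp(-x))
logσ(x) = -log1p(exp(-x))  # stable log-sigmoid for moderate x

bce(p, y)      = -y * log(p) - (1 - y) * log(1 - p)
logitbce(x, y) = (1 - y) * x - logσ(x)   # the form used above

x, y = 2.0, 1.0
bce(σ(x), y) ≈ logitbce(x, y)  # true; both ≈ 0.1269
```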
@@ -350,19 +351,19 @@ julia> p2 = fill(0.5, 2, 2)
julia> Flux.kldivergence(p2, p1) ≈ log(2)
true

-julia> Flux.kldivergence(p2, p1; agg=sum) ≈ 2log(2)
+julia> Flux.kldivergence(p2, p1; agg = sum) ≈ 2log(2)
true

-julia> Flux.kldivergence(p2, p2; ϵ=0) # about -2e-16 with the regulator
+julia> Flux.kldivergence(p2, p2; ϵ = 0) # about -2e-16 with the regulator
0.0

-julia> Flux.kldivergence(p1, p2; ϵ=0) # about 17.3 with the regulator
+julia> Flux.kldivergence(p1, p2; ϵ = 0) # about 17.3 with the regulator
Inf
```
"""
-function kldivergence(ŷ, y; dims=1, agg=mean, ϵ=epseltype(ŷ))
-    entropy = agg(sum(xlogx.(y), dims=dims))
-    cross_entropy = crossentropy(ŷ, y; dims=dims, agg=agg, ϵ=ϵ)
+function kldivergence(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ))
+    entropy = agg(sum(xlogx.(y), dims = dims))
+    cross_entropy = crossentropy(ŷ, y; dims = dims, agg = agg, ϵ = ϵ)
    return entropy + cross_entropy
end

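The implementation's `entropy + cross_entropy` split is just KL(y‖ŷ) = Σ y log y − Σ y log ŷ. Checking one column by hand:

```julia
y = [0.5, 0.5]
ŷ = [0.9, 0.1]
sum(y .* log.(y)) - sum(y .* log.(ŷ))  # ≈ 0.511, entropy + cross-entropy split
sum(y .* log.(y ./ ŷ))                 # same value, the textbook KL form
```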
@@ -374,18 +375,19 @@ end

[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
"""
-poisson_loss(ŷ, y; agg=mean) = agg(ŷ .- xlogy.(y, ŷ))
+poisson_loss(ŷ, y; agg = mean) = agg(ŷ .- xlogy.(y, ŷ))

"""
-    hinge_loss(ŷ, y; agg=mean)
+    hinge_loss(ŷ, y; agg = mean)

Return the [hinge loss](https://en.wikipedia.org/wiki/Hinge_loss) given the
prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as
`sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)`.

See also: [`squared_hinge_loss`](@ref)
"""
-hinge_loss(ŷ, y; agg=mean) = agg(max.(0, 1 .- ŷ .* y))
+hinge_loss(ŷ, y; agg = mean) =
+    agg(max.(0, 1 .- ŷ .* y))

"""
    squared_hinge_loss(ŷ, y)
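Pointwise, the hinge loss is zero once the margin ŷ·y reaches 1, which is why it only penalises unconfident or wrong predictions:

```julia
# Hinge loss for a single prediction against a ±1 label.
hinge_point(ŷ, y) = max(0, 1 - ŷ * y)

hinge_point(2.0, 1)   # 0.0 — correct and confident
hinge_point(0.3, 1)   # 0.7 — correct but inside the margin
hinge_point(-0.5, 1)  # 1.5 — wrong side of the boundary
```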
@@ -395,10 +397,11 @@ Return the squared hinge loss given the prediction `ŷ` and true labels `y`

See also: [`hinge_loss`](@ref)
"""
-squared_hinge_loss(ŷ, y; agg=mean) = agg((max.(0, 1 .- ŷ .* y)).^2)
+squared_hinge_loss(ŷ, y; agg = mean) =
+    agg((max.(0, 1 .- ŷ .* y)) .^ 2)

"""
-    dice_coeff_loss(ŷ, y; smooth=1)
+    dice_coeff_loss(ŷ, y; smooth = 1)

Return a loss based on the dice coefficient.
Used in the [V-Net](https://arxiv.org/abs/1606.04797) image segmentation
@@ -407,21 +410,22 @@ Similar to the F1_score. Calculated as:

    1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)
"""
-dice_coeff_loss(ŷ, y; smooth=ofeltype(ŷ, 1.0)) = 1 - (2*sum(y .* ŷ) + smooth) / (sum(y.^2) + sum(ŷ.^2) + smooth) #TODO agg
+dice_coeff_loss(ŷ, y; smooth = ofeltype(ŷ, 1.0)) =
+    1 - (2 * sum(y .* ŷ) + smooth) / (sum(y .^ 2) + sum(ŷ .^ 2) + smooth) #TODO agg

"""
-    tversky_loss(ŷ, y; β=0.7)
+    tversky_loss(ŷ, y; β = 0.7)

Return the [Tversky loss](https://arxiv.org/abs/1706.05721).
Used with imbalanced data to give more weight to false negatives.
Larger β weighs recall more than precision (by placing more emphasis on false negatives).
Calculated as:
    1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
"""
-function tversky_loss(ŷ, y; β=ofeltype(ŷ, 0.7))
+function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
    #TODO add agg
    num = sum(y .* ŷ) + 1
-    den = sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1
+    den = sum(y .* ŷ + β * (1 .- y) .* ŷ + (1 - β) * y .* (1 .- ŷ)) + 1
    1 - num / den
end

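A standalone sketch of the dice loss above on a tiny binary mask, showing why `smooth` matters: perfect overlap drives the loss to 0, and the smoothing term keeps the ratio finite even for empty masks.

```julia
# Dice coefficient loss as defined above, rewritten outside Flux.
dice(ŷ, y; smooth = 1.0) =
    1 - (2 * sum(y .* ŷ) + smooth) / (sum(y .^ 2) + sum(ŷ .^ 2) + smooth)

y = [1.0, 1.0, 0.0, 0.0]
dice(y, y)                     # 0.0  — perfect prediction
dice([1.0, 0.0, 0.0, 0.0], y)  # 0.25 — half the mask missed
```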