Skip to content

Commit 9326702

Browse files
committed
Teach Flux.Losses.mse() to use conjugates
This improves the calculation of error for complex-valued targets
1 parent 7af4f4c commit 9326702

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/losses/functions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ julia> Flux.mse(y_model, y_true)
4444
"""
4545
function mse(ŷ, y; agg = mean)
4646
_check_sizes(ŷ, y)
47-
agg((ŷ .- y) .^ 2)
47+
error =.- y
48+
real(agg(error .* conj(error)))
4849
end
4950

5051
"""

0 commit comments

Comments
 (0)