You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
1636: Add warnings for mismatched sizes in losses r=mcabbott a=mcabbott
Closes#1599, I think, by making loss functions give a warning if the sizes don't match:
```julia
julia> mse([1,0], [1 0 0])
┌ Error: size mismatch in loss function! In future this will be an error; in Flux 0.12 broadcasting acceps some mismatches
│ summary(ŷ) = "2-element Vector{Int64}"
│ summary(y) = "1×3 Matrix{Int64}"
└ @ Flux.Losses ~/.julia/dev/Flux/src/losses/utils.jl:29
0.5
julia> @Btime gradient(sum∘mse, $(rand(10,100)), $(rand(10,100)));
19.709 μs (130 allocations: 51.25 KiB)
19.625 μs (130 allocations: 51.25 KiB)
```
Appears to have no effect on speed, although Zygote is weird and maybe someone has a better test of that.
Edit -- closes#1522, too.
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
@warn"Size mismatch in loss function! In future this will be an error. In Flux <= 0.12 broadcasting accepts this, but may not give sensible results"summary(ŷ) summary(y) maxlog=3 _id=hash(size(y))
31
+
end
32
+
end
33
+
end
34
+
_check_sizes(ŷ, y) =nothing# pass-through, for constant label e.g. y = 1
0 commit comments