Skip to content

Commit 6d03787

Browse files
fix warning
1 parent d7b088a commit 6d03787

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/layers/basic.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,8 @@ end
119119

120120
@functor Dense
121121

122-
function (a::Dense)(x::AbstractArray)
123-
if eltype(a.W) != eltype(x)
124-
@warn "Element types of input and weights differ." W=eltype(a.W) x=eltype(x) maxlog=1
125-
end
122+
function (a::Dense)(x::AbstractVecOrMat)
123+
eltype(a.W) == eltype(x) || _dense_typewarn(a, x)
126124
W, b, σ = a.W, a.b, a.σ
127125
# reshape to handle dims > 1 as batch dimensions
128126
sz = size(x)
@@ -131,6 +129,9 @@ function (a::Dense)(x::AbstractArray)
131129
return reshape(x, :, sz[2:end]...)
132130
end
133131

132+
_dense_typewarn(d, x) = @warn "Element types don't match for layer $d, this will be slow." typeof(d.W) typeof(x) maxlog=1
133+
Zygote.@nograd _dense_typewarn
134+
134135
function Base.show(io::IO, l::Dense)
135136
print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
136137
l.σ == identity || print(io, ", ", l.σ)

0 commit comments

Comments
 (0)