Commit 8d7b27f (parent 7997174)

_match_eltype

9 files changed: +136, -13 lines changed
src/layers/basic.jl
Lines changed: 2 additions & 1 deletion

@@ -169,7 +169,8 @@ end
 
 function (a::Dense)(x::AbstractVecOrMat)
   σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
-  return σ.(a.weight * x .+ a.bias)
+  xT = _match_eltype(a, x)  # fixes Float64 input, etc.
+  return σ.(a.weight * xT .+ a.bias)
 end
 
 (a::Dense)(x::AbstractArray) =
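With this change a `Dense` layer whose weights are Float32 converts stray Float64 input before the matrix multiply, so the forward pass stays in Float32. A minimal illustrative sketch of the intended behaviour (not part of the diff):

    using Flux

    d = Dense(2 => 3)            # weights are Float32 by default
    x64 = randn(Float64, 2, 4)   # accidental Float64 input

    y = d(x64)     # warns once via _match_eltype, then multiplies in Float32
    eltype(y)      # Float32; before this commit the whole call ran in Float64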

src/layers/conv.jl
Lines changed: 6 additions & 3 deletions

@@ -197,7 +197,8 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
 function (c::Conv)(x::AbstractArray)
   σ = NNlib.fast_act(c.σ, x)
   cdims = conv_dims(c, x)
-  σ.(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
+  xT = _match_eltype(c, x)
+  σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 _channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups

@@ -330,7 +331,8 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
 function (c::ConvTranspose)(x::AbstractArray)
   σ = NNlib.fast_act(c.σ, x)
   cdims = conv_transpose_dims(c, x)
-  σ.(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
+  xT = _match_eltype(c, x)
+  σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::ConvTranspose)

@@ -468,7 +470,8 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
 function (c::CrossCor)(x::AbstractArray)
   σ = NNlib.fast_act(c.σ, x)
   cdims = crosscor_dims(c, x)
-  σ.(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
+  xT = _match_eltype(c, x)
+  σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
 end
 
 function Base.show(io::IO, l::CrossCor)
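The same guard is added to `Conv`, `ConvTranspose` and `CrossCor`, so Float64 or integer-valued arrays no longer force a wide-precision convolution. An illustrative sketch, assuming the default Float32 initialisation:

    using Flux

    c  = Conv((3,), 2 => 4, relu)      # Float32 weights
    x  = rand(Float64, 10, 2, 5)       # (width, channels, batch), Float64
    xi = rand(-3:3, 10, 2, 5)          # integer input

    eltype(c(x))    # Float32: converted before conv, with a one-time warning
    eltype(c(xi))   # Float32: integers are converted silently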

src/layers/recurrent.jl
Lines changed: 12 additions & 8 deletions

@@ -200,10 +200,11 @@ end
 RNNCell((in, out)::Pair, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
   RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))
 
-function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,I,H,V,T}
+function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {F,I,H,V,T}
   Wi, Wh, b = m.Wi, m.Wh, m.b
   σ = NNlib.fast_act(m.σ, x)
-  h = σ.(Wi*x .+ Wh*h .+ b)
+  xT = _match_eltype(m, T, x)
+  h = σ.(Wi*xT .+ Wh*h .+ b)
   return h, reshape_cell_output(h, x)
 end
 
@@ -305,9 +306,10 @@ function LSTMCell((in, out)::Pair;
   return cell
 end
 
-function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T}
+function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T}
   b, o = m.b, size(h, 1)
-  g = muladd(m.Wi, x, muladd(m.Wh, h, b))
+  xT = _match_eltype(m, T, x)
+  g = muladd(m.Wi, xT, muladd(m.Wh, h, b))
   input, forget, cell, output = multigate(g, o, Val(4))
   c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
   h′ = @. sigmoid_fast(output) * tanh_fast(c′)

@@ -376,9 +378,10 @@ end
 GRUCell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
   GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))
 
-function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T}
+function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,T}
   Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
-  gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
+  xT = _match_eltype(m, T, x)
+  gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
   r, z = _gru_output(gxs, ghs, bs)
   h̃ = @. tanh_fast(gxs[3] + r * ghs[3] + bs[3])
   h′ = @. (1 - z) * h̃ + z * h

@@ -444,9 +447,10 @@ GRUv3Cell((in, out)::Pair; init = glorot_uniform, initb = zeros32, init_state =
   GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
             init(out, out), init_state(out,1))
 
-function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,HH,T}
+function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{<:AbstractFloat},OneHotArray}) where {I,H,V,HH,T}
   Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
-  gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
+  xT = _match_eltype(m, T, x)
+  gxs, ghs, bs = multigate(Wi*xT, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
   r, z = _gru_output(gxs, ghs, bs)
   h̃ = tanh_fast.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
   h′ = @. (1 - z) * h̃ + z * h
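Besides the conversion, relaxing the signatures to `AbstractVecOrMat{<:AbstractFloat}` means a Float64 sequence now reaches these fast cell methods instead of a fallback. A small sketch of the new behaviour (mirrors the added tests):

    using Flux

    m = RNN(2 => 3)               # Float32 parameters and hidden state
    x64 = rand(Float64, 2, 4)

    m(x64) isa Matrix{Float32}    # true: the cell matches x64 to eltype(m.cell.Wi)
    m.state isa Matrix{Float32}   # true: the state is not promoted to Float64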

src/layers/stateless.jl
Lines changed: 44 additions & 0 deletions

@@ -57,3 +57,47 @@ true
   σ = std(x, dims=dims, mean=μ, corrected=false)
   return @. (x - μ) / (σ + ϵ)
 end
+
+"""
+    _match_eltype(layer, ::Type{T}, x)
+    _match_eltype(layer, x)
+
+This internal function corrects most layer input to match the type of the weights.
+The second method uses `T = eltype(layer.weight)`.
+
+It solves a common performance bug: Before, accidentally supplying `Float64` input,
+or an activation function which produces `Float64`, would silently run the
+entire forward pass in this precision.
+"""
+_match_eltype(layer, ::Type{T}, x::AbstractArray{T}) where {T} = x
+
+# A common mistake, print a friendly warning, and fix it:
+function _match_eltype(layer, ::Type{Float32}, x::AbstractArray{Float64})
+  # This warning is the only reason this needs to take the layer.
+  @warn "Layer with Float32 parameters got Float64 input.
+  The input will be converted, but any earlier layers may be very slow." layer summary(x) maxlog=1
+  convert(AbstractArray{Float32}, x)
+end
+
+# Allow OneHot to reach specialisation of * etc:
+_match_eltype(layer, ::Type, x::OneHotLike) = x
+
+# Other floats, and integers, silently fix.
+function _match_eltype(layer, ::Type{T}, x::AbstractArray{<:Union{AbstractFloat, Integer}}) where {T}
+  convert(AbstractArray{T}, x)
+end
+
+# Weird types like Nil, Dual, etc, we allow through:
+_match_eltype(layer, ::Type, x::AbstractArray) = x
+
+# 2-arg method, for common layers with layer.weight
+_match_eltype(layer, x) = _match_eltype(layer, eltype(layer.weight), x)
+
+# Trivial rule:
+function ChainRulesCore.rrule(::typeof(_match_eltype), layer, ::Type{T}, x::AbstractArray) where {T}
+  _match_eltype(layer, T, x), dx -> (NoTangent(), ZeroTangent(), NoTangent(), dx) # does not un-thunk dx
+end
+function ChainRulesCore.rrule(::typeof(_match_eltype), layer, x::AbstractArray)
+  _match_eltype(layer, x), dx -> (ZeroTangent(), NoTangent(), dx) # does not un-thunk dx
+end
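Read as a dispatch table: matching element types pass through, Float64 into a Float32 layer warns and converts, other floats and integers convert silently, and one-hot or exotic element types are left alone. A quick illustrative check of those branches, calling the internal function directly:

    using Flux

    d = Dense(2 => 3)   # d.weight is a Float32 matrix

    eltype(Flux._match_eltype(d, rand(Float32, 2, 4)))  # Float32, returned as-is
    eltype(Flux._match_eltype(d, rand(Float64, 2, 4)))  # Float32, one-time warning
    eltype(Flux._match_eltype(d, rand(-5:5, 2, 4)))     # Float32, silent conversion
    Flux._match_eltype(d, Flux.onehotbatch(rand(Bool, 5), (true, false)))  # one-hot passes through unchanged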

src/outputsize.jl
Lines changed: 10 additions & 1 deletion

@@ -173,6 +173,16 @@ for (fn, Dims) in ((:conv, DenseConvDims),)
   end
 end
 
+# Recurrent layers: just convert to the type they like & convert back.
+
+for Cell in [:RNNCell, :LSTMCell, :GRUCell, :GRUv3Cell]
+  @eval function (m::Recur{<:$Cell})(x::AbstractArray{Nil})
+    xT = fill!(similar(m.cell.Wi, size(x)), 0)
+    _, y = m.cell(m.state, xT) # discard the new state
+    return similar(x, size(y))
+  end
+end
+
 """
     @autosize (size...,) Chain(Layer(_ => 2), Layer(_), ...)

@@ -229,7 +239,6 @@ Limitations:
 * While `@autosize (5, 32) Flux.Bilinear(_ => 7)` is OK, something like `Bilinear((_, _) => 7)` will fail.
 * While `Scale(_)` and `LayerNorm(_)` are fine (and use the first dimension), `Scale(_,_)` and `LayerNorm(_,_)`
   will fail if `size(x,1) != size(x,2)`.
-* RNNs won't work: `@autosize (7, 11) LSTM(_ => 5)` fails, because `outputsize(RNN(3=>7), (3,))` also fails, a known issue.
 """
 macro autosize(size, model)
   Meta.isexpr(size, :tuple) || error("@autosize's first argument must be a tuple, the size of the input")
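With the `Nil` path above, `outputsize` runs each recurrent cell on a zero array of the weights' element type and keeps only the result's size, which is why the RNN limitation is dropped from the `@autosize` docstring. A hedged sketch of what should now work:

    using Flux

    Flux.outputsize(LSTM(2 => 5), (2, 16))   # (5, 16), computed without real data
    @autosize (7, 11) LSTM(_ => 5)           # previously listed as failing; builds LSTM(7 => 5)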

test/layers/basic.jl
Lines changed: 17 additions & 0 deletions

@@ -89,6 +89,23 @@ import Flux: activations
     @test Dense(10, 2, identity, init = ones)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
     @test Dense(10, 2, identity, init = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
   end
+  @testset "type matching" begin
+    d1 = Dense(2 => 3)
+    d2 = Dense(d1.weight, false)
+    x1 = randn(Float32, 2, 4)
+    @test d1(x1) ≈ d2(x1) ≈ d1.weight * x1
+    x2 = Float64.(x1)
+    @test d1(x2) ≈ d2(x2) ≈ d1.weight * x2
+    @test d1(x2) isa Array{Float32}  # tests _match_eltype, will print a warning
+    @test d2(x2) isa Array{Float32}
+
+    x3 = rand(-5:5, 2, 4)
+    @test d1(x3) ≈ d2(x3) ≈ d1.weight * x3
+    x4 = rand(Bool, 2, 4)
+    @test d1(x4) ≈ d2(x4) ≈ d1.weight * x4
+    x5 = Flux.onehotbatch(rand(Bool, 5), (true, false))
+    @test d1(x5) ≈ d2(x5) ≈ d1.weight * x5
+  end
 end
 
 @testset "Scale" begin

test/layers/conv.jl
Lines changed: 14 additions & 0 deletions

@@ -286,3 +286,17 @@ end
   end
   @test_throws DimensionMismatch fun(rand(2,3,4), rand(6))
 end
+
+@testset "type matching" begin
+  x = rand(Float64, 10,2,5)
+  xi = rand(-3:3, 10,2,5)
+  c1 = Conv((3,), 2=>4, relu)
+  @test @inferred(c1(x)) isa Array{Float32, 3}
+  @test c1(xi) isa Array{Float32, 3}
+
+  c2 = CrossCor((3,), 2=>1, relu)
+  @test @inferred(c2(x)) isa Array{Float32, 3}
+
+  c3 = ConvTranspose((3,), 2=>4, relu)
+  @test @inferred(c3(x)) isa Array{Float32, 3}
+end

test/layers/recurrent.jl
Lines changed: 24 additions & 0 deletions

@@ -169,3 +169,27 @@ end
     @test size(m(x3)) == (5, 1, 2)
   end
 end
+
+@testset "type matching" begin
+  x = rand(Float64, 2, 4)
+  m1 = RNN(2=>3)
+  @test m1(x) isa Matrix{Float32}  # uses _match_eltype, may print a warning
+  @test m1.state isa Matrix{Float32}
+  @test (@inferred m1(x); true)
+  @test Flux.outputsize(m1, size(x)) == size(m1(x))
+
+  m2 = LSTM(2=>3)
+  @test m2(x) isa Matrix{Float32}
+  @test (@inferred m2(x); true)
+  @test Flux.outputsize(m2, size(x)) == size(m2(x))
+
+  m3 = GRU(2=>3)
+  @test m3(x) isa Matrix{Float32}
+  @test (@inferred m3(x); true)
+  @test Flux.outputsize(m3, size(x)) == size(m3(x))
+
+  m4 = GRUv3(2=>3)
+  @test m4(x) isa Matrix{Float32}
+  @test (@inferred m4(x); true)
+  @test Flux.outputsize(m4, size(x)) == size(m4(x))
+end

test/outputsize.jl
Lines changed: 7 additions & 0 deletions

@@ -257,3 +257,10 @@ end
   # Can't let |> gpu act before the arrays are materialized... so it's an error:
   @test_throws ErrorException @eval @autosize (1,2,3) Dense(_=>2) |> f64
 end
+
+@testset "type matching" begin
+  # Check that _match_eltype doesn't replace this with an array of Float32:
+  @test Flux._match_eltype(Dense(2=>3), fill(Flux.Nil(),2,4)) isa Matrix{Flux.Nil}
+  # For RNN etc there's a special path:
+  @test RNN(2=>3)(fill(Flux.Nil(),2,4)) isa Matrix{Flux.Nil}
+end
