Skip to content

Commit 75e3771

Browse files
bors[bot] and ToucheSir authored
Merge #1761
1761: Use view for RNN gate slice extraction r=mcabbott a=ToucheSir

This was originally passed over in #907, but I don't find the argument in that PR particularly compelling as the return value is only ever used once. Any negative impact on caching is going to happen anyhow during the slice materialization, so we might as well just let the subsequent fused broadcasts handle said materialization for us while reducing allocations.

Pinging `@jeremiedb`, `@sdobber` and `@mkschleg` if they have any interesting benchmarks to run this on. Otherwise I'll try to get something working with https://github.com/FluxML/Flux.jl/blob/master/perf/bench_utils.jl locally.

### PR Checklist
- [x] Tests are added
- [ ] Entry in NEWS.md
- [N/A] Documentation, if applicable
- [N/A] API changes require approval from a committer (different from the author, if applicable)

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
2 parents 080114d + 2e0bb1d commit 75e3771

File tree

2 files changed

+51
-33
lines changed

2 files changed

+51
-33
lines changed

src/layers/recurrent.jl

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,23 @@
11

22
gate(h, n) = (1:h) .+ h*(n-1)
33
gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
4-
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
4+
gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)
5+
6+
# AD-friendly helper for dividing monolithic RNN params into equally sized gates
7+
multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)
8+
9+
@adjoint function multigate(x::AbstractArray, h, c)
10+
function multigate_pullback(dy)
11+
dx = Zygote._zero(x, eltype(x))
12+
map(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
13+
dyᵢ !== nothing && (dxᵢ.= Zygote.accum.(dxᵢ, dyᵢ));
14+
end
15+
return (dx, nothing, nothing)
16+
end
17+
return multigate(x, h, c), multigate_pullback
18+
end
19+
20+
reshape_cell_output(h, x) = reshape(h, :, size(x)[2:end]...)
521

622
# Stateful recurrence
723

@@ -97,14 +113,13 @@ struct RNNCell{F,A,V,S}
97113
state0::S
98114
end
99115

100-
RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
116+
RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
101117
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))
102118

103119
function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
104120
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
105121
h = σ.(Wi*x .+ Wh*h .+ b)
106-
sz = size(x)
107-
return h, reshape(h, :, sz[2:end]...)
122+
return h, reshape_cell_output(h, x)
108123
end
109124

110125
@functor RNNCell
@@ -208,14 +223,10 @@ end
208223
function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
209224
b, o = m.b, size(h, 1)
210225
g = m.Wi*x .+ m.Wh*h .+ b
211-
input = σ.(gate(g, o, 1))
212-
forget = σ.(gate(g, o, 2))
213-
cell = tanh.(gate(g, o, 3))
214-
output = σ.(gate(g, o, 4))
215-
c = forget .* c .+ input .* cell
216-
h′ = output .* tanh.(c)
217-
sz = size(x)
218-
return (h′, c), reshape(h′, :, sz[2:end]...)
226+
input, forget, cell, output = multigate(g, o, Val(4))
227+
c′ = @. σ(forget) * c + σ(input) * tanh(cell)
228+
h′ = @. σ(output) * tanh(c′)
229+
return (h′, c′), reshape_cell_output(h′, x)
219230
end
220231

221232
@functor LSTMCell
@@ -269,7 +280,7 @@ function Base.getproperty(m::LSTMCell, sym::Symbol)
269280
elseif sym === :c
270281
Zygote.ignore() do
271282
@warn "LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead."
272-
end
283+
end
273284
return getfield(m, :state0)[2]
274285
else
275286
return getfield(m, sym)
@@ -278,13 +289,10 @@ end
278289

279290
# GRU
280291

281-
function _gru_output(Wi, Wh, b, x, h)
282-
o = size(h, 1)
283-
gx, gh = Wi*x, Wh*h
284-
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
285-
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
286-
287-
return gx, gh, r, z
292+
function _gru_output(gxs, ghs, bs)
293+
r = @. σ(gxs[1] + ghs[1] + bs[1])
294+
z = @. σ(gxs[2] + ghs[2] + bs[2])
295+
return r, z
288296
end
289297

290298
struct GRUCell{A,V,S}
@@ -298,12 +306,12 @@ GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
298306
GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))
299307

300308
function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
301-
b, o = m.b, size(h, 1)
302-
gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h)
303-
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
304-
h′ = (1 .- z) .* h̃ .+ z .* h
305-
sz = size(x)
306-
return h′, reshape(h′, :, sz[2:end]...)
309+
Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
310+
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
311+
r, z = _gru_output(gxs, ghs, bs)
312+
h̃ = @. tanh(gxs[3] + r * ghs[3] + bs[3])
313+
h′ = @. (1 - z) * h̃ + z * h
314+
return h′, reshape_cell_output(h′, x)
307315
end
308316

309317
@functor GRUCell
@@ -372,16 +380,16 @@ struct GRUv3Cell{A,V,S}
372380
end
373381

374382
GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
375-
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
383+
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
376384
init(out, out), init_state(out,1))
377385

378386
function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
379-
b, o = m.b, size(h, 1)
380-
gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h)
381-
h̃ = tanh.(gate(gx, o, 3) .+ (m.Wh_h̃ * (r .* h)) .+ gate(b, o, 3))
382-
h′ = (1 .- z) .* h̃ .+ z .* h
383-
sz = size(x)
384-
return h′, reshape(h′, :, sz[2:end]...)
387+
Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
388+
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
389+
r, z = _gru_output(gxs, ghs, bs)
390+
h̃ = tanh.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
391+
h′ = @. (1 - z) * h̃ + z * h
392+
return h′, reshape_cell_output(h′, x)
385393
end
386394

387395
@functor GRUv3Cell

test/layers/recurrent.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,13 @@ end
9696
@test_throws MethodError m(x)
9797
end
9898
end
99+
100+
@testset "multigate" begin
101+
x = rand(6, 5)
102+
res, (dx,) = Flux.withgradient(x) do x
103+
x1, _, x3 = Flux.multigate(x, 2, Val(3))
104+
sum(x1) + sum(x3 .* 2)
105+
end
106+
@test res == sum(x[1:2, :]) + 2sum(x[5:6, :])
107+
@test dx == [ones(2, 5); zeros(2, 5); fill(2, 2, 5)]
108+
end

0 commit comments

Comments
 (0)