Skip to content

Commit a6c9c11

Browse files
committed
add multigate
1 parent c9627c5 commit a6c9c11

File tree

2 files changed

+40
-21
lines changed

2 files changed

+40
-21
lines changed

src/layers/recurrent.jl

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,19 @@ gate(h, n) = (1:h) .+ h*(n-1)
33
gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
44
gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)
55

6+
multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)
7+
8+
@adjoint function multigate(x::AbstractArray, h, c)
9+
function multigate_pullback(dy)
10+
dx = Zygote._zero(x, eltype(x))
11+
map(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
12+
dyᵢ !== nothing && (dxᵢ.= Zygote.accum.(dxᵢ, dyᵢ));
13+
end
14+
return (dx, nothing, nothing)
15+
end
16+
return multigate(x, h, c), multigate_pullback
17+
end
18+
619
# Stateful recurrence
720

821
"""
@@ -157,12 +170,9 @@ end
157170
function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
158171
b, o = m.b, size(h, 1)
159172
g = m.Wi*x .+ m.Wh*h .+ b
160-
input = σ.(gate(g, o, 1))
161-
forget = σ.(gate(g, o, 2))
162-
cell = tanh.(gate(g, o, 3))
163-
output = σ.(gate(g, o, 4))
164-
c = forget .* c .+ input .* cell
165-
h′ = output .* tanh.(c)
173+
input, forget, cell, output = multigate(g, o, Val(4))
174+
c = @. σ(forget) * c + σ(input) * tanh(cell)
175+
h′ = @. σ(output) * tanh(c)
166176
sz = size(x)
167177
return (h′, c), reshape(h′, :, sz[2:end]...)
168178
end
@@ -203,13 +213,10 @@ end
203213

204214
# GRU
205215

206-
function _gru_output(Wi, Wh, b, x, h)
207-
o = size(h, 1)
208-
gx, gh = Wi*x, Wh*h
209-
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
210-
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
211-
212-
return gx, gh, r, z
216+
function _gru_output(gxs, ghs, bs)
217+
r = @. σ(gxs[1] + ghs[1] + bs[1])
218+
z = @. σ(gxs[2] + ghs[2] + bs[2])
219+
return r, z
213220
end
214221

215222
struct GRUCell{A,V,S}
@@ -223,10 +230,11 @@ GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
223230
GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))
224231

225232
function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
226-
b, o = m.b, size(h, 1)
227-
gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h)
228-
= tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
229-
h′ = (1 .- z) .*.+ z .* h
233+
Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
234+
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
235+
r, z = _gru_output(gxs, ghs, bs)
236+
= @. tanh(gxs[3] + r * ghs[3] + bs[3])
237+
h′ = @. (1 - z) *+ z * h
230238
sz = size(x)
231239
return h′, reshape(h′, :, sz[2:end]...)
232240
end
@@ -277,10 +285,11 @@ GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32)
277285
init(out, out), init_state(out,1))
278286

279287
function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
280-
b, o = m.b, size(h, 1)
281-
gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h)
282-
= tanh.(gate(gx, o, 3) .+ (m.Wh_h̃ * (r .* h)) .+ gate(b, o, 3))
283-
h′ = (1 .- z) .*.+ z .* h
288+
Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
289+
gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
290+
r, z = _gru_output(gxs, ghs, bs)
291+
= tanh.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
292+
h′ = @. (1 - z) *+ z * h
284293
sz = size(x)
285294
return h′, reshape(h′, :, sz[2:end]...)
286295
end

test/layers/recurrent.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,13 @@ end
9696
@test_throws MethodError m(x)
9797
end
9898
end
99+
100+
@testset "multigate" begin
101+
x = rand(6, 5)
102+
res, (dx,) = Flux.withgradient(x) do x
103+
x1, _, x3 = Flux.multigate(x, 2, Val(3))
104+
sum(x1) + sum(x3 .* 2)
105+
end
106+
@test res == sum(x[1:2, :]) + 2sum(x[5:6, :])
107+
@test dx == [ones(2, 5); zeros(2, 5); fill(2, 2, 5)]
108+
end

0 commit comments

Comments
 (0)