Skip to content

Commit c9627c5

Browse files
committed
Use view for RNN gate slice extraction
This was originally passed over in #907, but I don't find that argument particularly compelling as the return value is only ever used once. Any negative impact on caching is going to happen anyhow during the slice materialization, so we might as well just let the subsequent fused broadcasts handle said materialization for us while reducing allocations.
1 parent ea26f45 commit c9627c5

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/layers/recurrent.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
gate(h, n) = (1:h) .+ h*(n-1)
33
gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
4-
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
4+
gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)
55

66
# Stateful recurrence
77

@@ -97,7 +97,7 @@ struct RNNCell{F,A,V,S}
9797
state0::S
9898
end
9999

100-
RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
100+
RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
101101
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))
102102

103103
function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
@@ -194,7 +194,7 @@ function Base.getproperty(m::LSTMCell, sym::Symbol)
194194
elseif sym === :c
195195
Zygote.ignore() do
196196
@warn "LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead."
197-
end
197+
end
198198
return getfield(m, :state0)[2]
199199
else
200200
return getfield(m, sym)
@@ -273,7 +273,7 @@ struct GRUv3Cell{A,V,S}
273273
end
274274

275275
GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
276-
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
276+
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
277277
init(out, out), init_state(out,1))
278278

279279
function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}

0 commit comments

Comments
 (0)