|
1 | 1 |
|
2 | 2 | gate(h, n) = (1:h) .+ h*(n-1)
|
3 | 3 | gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
|
4 |
| -gate(x::AbstractMatrix, h, n) = x[gate(h,n),:] |
| 4 | +gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :) |
5 | 5 |
|
6 | 6 | # Stateful recurrence
|
7 | 7 |
|
@@ -97,7 +97,7 @@ struct RNNCell{F,A,V,S}
|
97 | 97 | state0::S
|
98 | 98 | end
|
99 | 99 |
|
100 |
| -RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) = |
| 100 | +RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) = |
101 | 101 | RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))
|
102 | 102 |
|
103 | 103 | function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
|
@@ -194,7 +194,7 @@ function Base.getproperty(m::LSTMCell, sym::Symbol)
|
194 | 194 | elseif sym === :c
|
195 | 195 | Zygote.ignore() do
|
196 | 196 | @warn "LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead."
|
197 |
| - end |
| 197 | + end |
198 | 198 | return getfield(m, :state0)[2]
|
199 | 199 | else
|
200 | 200 | return getfield(m, sym)
|
@@ -273,7 +273,7 @@ struct GRUv3Cell{A,V,S}
|
273 | 273 | end
|
274 | 274 |
|
275 | 275 | GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
|
276 |
| - GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3), |
| 276 | + GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3), |
277 | 277 | init(out, out), init_state(out,1))
|
278 | 278 |
|
279 | 279 | function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
|
|
0 commit comments