Commit 017acdf

extend to generic arrays; add cuda tests

1 parent 75e90c0 commit 017acdf

3 files changed: +34 -11 lines changed

src/layers/basic.jl

Lines changed: 8 additions & 10 deletions

@@ -120,14 +120,13 @@ end
 
 @functor Dense
 
-function (a::Dense)(x::AbstractArray)
+function (a::Dense)(x::Union{AbstractVector, AbstractMatrix})
   W, b, σ = a.W, a.b, a.σ
-  sz = size(x)
-  x = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions
-  x = σ.(W*x .+ b)
-  return reshape(x, :, sz[2:end]...)
+  return σ.(W*x .+ b)
 end
 
+(a::Dense)(x::AbstractArray) = reshape(a(mat(x)), :, size(x)[2:end]...)
+
 function Base.show(io::IO, l::Dense)
   print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
   l.σ == identity || print(io, ", ", l.σ)
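With this change, the vector/matrix method of `Dense` is the primitive, and any higher-rank input is first collapsed with the new `mat` helper (which keeps `size(x, 1)` as the row dimension), then has its trailing dimensions restored afterwards, so every dimension after the first acts as a batch dimension. A minimal sketch of the resulting behaviour (layer and input sizes chosen arbitrarily):

```julia
using Flux

d = Dense(10, 5)             # weight is 5×10
x = rand(Float32, 10, 7, 3)  # feature dimension first, two batch-like dims

y = d(x)                     # dispatches to reshape(d(mat(x)), :, size(x)[2:end]...)
size(y)                      # (5, 7, 3): trailing dimensions are preserved
```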
@@ -326,7 +325,7 @@ function Base.show(io::IO, l::Bilinear)
 end
 
 """
-    Parallel(connection, layers...)
+    Parallel(connection, layers...)
 
 Create a 'Parallel' layer that passes an input array to each path in
 `layers`, reducing the output with `connection`.
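As the docstring describes, `Parallel` applies each of `layers` to the same input and reduces the outputs with `connection`. A brief usage sketch of that documented behaviour (sizes chosen arbitrarily):

```julia
using Flux

m = Parallel(vcat, Dense(10, 4), Dense(10, 2))
x = rand(Float32, 10)

size(m(x))  # (6,): vcat of a length-4 output and a length-2 output
```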
@@ -416,10 +415,9 @@ function Embedding(in::Integer, out::Integer;
   return Embedding(init(out, in))
 end
 
-(m::Embedding)(x::OneHotMatrix) = m.weight * x # equivalent to m.weight[:, onecold(x)]
-(m::Embedding)(x::OneHotVector) = m.weight * x
-(m::Embedding)(x::AbstractVector) = m.weight[:, x]
-(m::Embedding)(x::Int) = m.weight[:, x]
+(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:, onecold(x)]
+(m::Embedding)(x::Union{Int, AbstractVector}) = m.weight[:, x]
+(m::Embedding)(x::AbstractArray) = reshape(m(mat(x)), :, size(x)[2:end]...)
 
 function Base.show(io::IO, m::Embedding)
   print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
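The four `Embedding` methods collapse to three: one-hot vectors and matrices share the matrix-multiply path (equivalent to indexing with `onecold`), integers and integer vectors share the column-indexing path, and other arrays are flattened with `mat` and reshaped back. A sketch of the two indexing paths (vocabulary and sizes chosen arbitrarily; `Embedding(in, out)` follows the constructor shown above):

```julia
using Flux
using Flux: onehotbatch

m = Embedding(10, 4)     # 10-word vocabulary, length-4 embeddings

m(3)                     # column 3 of m.weight: a length-4 vector
m([3, 1, 7])             # 4×3 matrix, one embedding per index

oh = onehotbatch([3, 1, 7], 1:10)  # 10×3 one-hot matrix
m(oh) ≈ m([3, 1, 7])               # both select the same columns of m.weight
```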

src/layers/stateless.jl

Lines changed: 13 additions & 1 deletion

@@ -4,7 +4,7 @@
 Reshape arbitrarily-shaped input into a matrix-shaped output,
 preserving the size of the last dimension.
 
-See also [`unsqueeze`](@ref).
+See also [`unsqueeze`](@ref) and [`mat`](@ref).
 
 # Examples
 ```jldoctest

@@ -26,6 +26,18 @@ function flatten(x::AbstractArray)
   return reshape(x, :, size(x)[end])
 end
 
+"""
+    mat(x::AbstractArray)
+
+Reshape arbitrarily-shaped input into a matrix-shaped output,
+preserving the size of the first dimension.
+
+See also [`flatten`](@ref) and [`unsqueeze`](@ref).
+"""
+function mat(x::AbstractArray)
+  return reshape(x, size(x, 1), :)
+end
+
 """
     normalise(x; dims=ndims(x), ϵ=1e-5)
 
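`mat` is the first-dimension counterpart of `flatten`: `flatten` keeps the last dimension and collapses everything before it, while `mat` keeps the first dimension and collapses everything after it. A quick comparison (shape chosen arbitrarily; written as `Flux.mat` since the diff doesn't show whether `mat` is exported):

```julia
using Flux

x = rand(Float32, 4, 3, 2)

size(Flux.flatten(x))  # (12, 2): last dimension kept, leading dims collapsed
size(Flux.mat(x))      # (4, 6):  first dimension kept, trailing dims collapsed
```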

test/cuda/layers.jl

Lines changed: 13 additions & 0 deletions

@@ -218,4 +218,17 @@ end
     @test gs_cpu[pcpu] ≈ gs_gpu[pgpu]
   end
 end
+
+@testset "Embedding" begin
+  vocab_size, embed_size = 10, 4
+  m = Embedding(vocab_size, embed_size)
+  x = rand(1:vocab_size, 3)
+  y = m(x)
+  m_g = m |> gpu
+  x_g = x |> gpu
+  y_g = m_g(x_g)
+  @test collect(y_g) == y
+  gs = gradient(() -> sum(tanh.(m(x))), params(m))
+  gs_g = gradient(() -> sum(tanh.(m_g(x_g))), params(m_g))
+  @test collect(gs_g[m_g.weight]) ≈ gs[m.weight]
 end
