Skip to content

Commit 9ce7347

Browse files
extend to generic arrays; add cuda tests
1 parent dcadaf4 commit 9ce7347

File tree

3 files changed

+38
-5
lines changed

3 files changed

+38
-5
lines changed

src/layers/basic.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,22 @@ end
143143

144144
@functor Dense
145145

146+
<<<<<<< HEAD
146147
function (a::Dense)(x::AbstractVecOrMat)
147148
W, b, σ = a.weight, a.bias, a.σ
148149
return σ.(W*x .+ b)
149150
end
150151

151152
(a::Dense)(x::AbstractArray) =
152153
reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
154+
=======
155+
function (a::Dense)(x::Union{AbstractVector, AbstractMatrix})
156+
W, b, σ = a.W, a.b, a.σ
157+
return σ.(W*x .+ b)
158+
end
159+
160+
(a::Dense)(x::AbstractArray) = reshape(a(mat(x)), :, size(x)[2:end]...)
161+
>>>>>>> 017acdf9 (extend to generic arrays; add cuda tests)
153162

154163
function Base.show(io::IO, l::Dense)
155164
print(io, "Dense(", size(l.weight, 2), ", ", size(l.weight, 1))
@@ -467,10 +476,9 @@ function Embedding(in::Integer, out::Integer;
467476
return Embedding(init(out, in))
468477
end
469478

470-
(m::Embedding)(x::OneHotMatrix) = m.weight * x # equivalent to m.weight[:, onecold(x)]
471-
(m::Embedding)(x::OneHotVector) = m.weight * x
472-
(m::Embedding)(x::AbstractVector) = m.weight[:, x]
473-
(m::Embedding)(x::Int) = m.weight[:, x]
479+
(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
480+
(m::Embedding)(x::Union{Int,AbstractVector}) = m.weight[:, x]
481+
(m::Embedding)(x::AbstractArray) = reshape(m(mat(x)), :, size(x)[2:end]...)
474482

475483
function Base.show(io::IO, m::Embedding)
476484
print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")

src/layers/stateless.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Reshape arbitrarily-shaped input into a matrix-shaped output,
55
preserving the size of the last dimension.
66
7-
See also [`unsqueeze`](@ref).
7+
See also [`unsqueeze`](@ref) and [`mat`](@ref).
88
99
# Examples
1010
```jldoctest
@@ -26,6 +26,18 @@ function flatten(x::AbstractArray)
2626
return reshape(x, :, size(x)[end])
2727
end
2828

29+
"""
30+
mat(x::AbstractArray)
31+
32+
Reshape arbitrarily-shaped input into a matrix-shaped output,
33+
preserving the size of the first dimension.
34+
35+
See also [`flatten`](@ref) and [`unsqueeze`](@ref).
36+
"""
37+
function mat(x::AbstractArray)
38+
return reshape(x, size(x,1), :)
39+
end
40+
2941
"""
3042
normalise(x; dims=ndims(x), ϵ=1e-5)
3143

test/cuda/layers.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,4 +258,17 @@ end
258258
@test gs_cpu[pcpu] ≈ gs_gpu[pgpu]
259259
end
260260
end
261+
262+
@testset "Embedding" begin
263+
vocab_size, embed_size = 10, 4
264+
m = Embedding(vocab_size, embed_size)
265+
x = rand(1:vocab_size, 3)
266+
y = m(x)
267+
m_g = m |> gpu
268+
x_g = x |> gpu
269+
y_g = m_g(x_g)
270+
@test collect(y_g) == y
271+
gs = gradient(() -> sum(tanh.(m(x))), params(m))
272+
gs_g = gradient(() -> sum(tanh.(m_g(x_g))), params(m_g))
273+
@test collect(gs_g[m_g.weight]) ≈ gs[m.weight]
261274
end

0 commit comments

Comments
 (0)