
Commit 8e21d79

update Embedding constructor
Updated the `Embedding` constructor to use `=>` and added the `OneHotLikeVector` and `OneHotLikeMatrix` consts.
Parent: 1e2cffd

4 files changed: +13 -10 lines

src/layers/basic.jl

Lines changed: 2 additions & 2 deletions
@@ -476,7 +476,7 @@ function Base.show(io::IO, m::Parallel)
 end
 
 """
-    Embedding(in, out; init=randn)
+    Embedding(in => out; init=randn)
 
 A lookup table that stores embeddings of dimension `out`
 for a vocabulary of size `in`.
@@ -511,7 +511,7 @@ end
 
 @functor Embedding
 
-Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))
+Embedding(dims::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(last(dims), first(dims)))
 
 
 (m::Embedding)(x::Integer) = m.weight[:, x]
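
For reference, a short sketch of how the new constructor reads in practice; the sizes 10 and 4 mirror the test change below, and the vector-index call mirrors what the GPU tests exercise:

using Flux

# A lookup table for a vocabulary of 10 tokens, each embedded in 4 dimensions.
m = Flux.Embedding(10 => 4)

size(m.weight)   # (4, 10), i.e. init(out, in) as in the constructor above

m(3)             # 4-element column, equal to m.weight[:, 3]
m([1, 3, 5])     # 4x3 matrix of embeddings, one column per index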

src/onehot.jl

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,9 @@ const OneHotLike{T, L, N, var"N+1", I} =
   Union{OneHotArray{T, L, N, var"N+1", I},
         Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}
 
+const OneHotLikeVector{T, L} = OneHotLike{T, L, 0, 1, T}
+const OneHotLikeMatrix{T, L, I} = OneHotLike{T, L, 1, 2, I}
+
 _isonehot(x::OneHotArray) = true
 _isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)
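
The two new aliases simply name the vector- and matrix-shaped cases of `OneHotLike`. A minimal sketch of what they match, assuming the aliases stay unexported and are reached via the `Flux.` prefix:

using Flux

v = Flux.OneHotVector(3, 5)          # one-hot index 3 out of 5 labels
M = Flux.OneHotMatrix([1, 2, 2], 5)  # three one-hot columns over 5 labels

v isa Flux.OneHotLikeVector   # true: matches OneHotLike{T, L, 0, 1, T}
M isa Flux.OneHotLikeMatrix   # true: matches OneHotLike{T, L, 1, 2, I}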

test/cuda/layers.jl

Lines changed: 7 additions & 7 deletions
@@ -127,13 +127,13 @@ gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
 gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
 
 embedding = [Flux.Embedding]
-gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
-gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
-gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
-gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
-gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
-gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
-gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)
+gpu_gradtest("Embedding", embedding, [1,3,5], 5 => 2)
+gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5 => 2)
+gpu_gradtest("Embedding integer index", embedding, 1, 5 => 2)
+gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5 => 2)
+gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5 => 2)
+gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5 => 2)
+gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5 => 2)
 
 @testset "function layers" begin
   x = rand(Float32, 3,3)
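
Outside the `gpu_gradtest` helper, the same pair syntax carries over when moving the layer to the GPU. A sketch, assuming a functional CUDA setup (`gpu` falls back to a no-op otherwise):

using Flux

m = Flux.Embedding(5 => 2) |> gpu

m(gpu([1, 3, 5]))                        # 2x3 embedding matrix on the device
m(gpu(Flux.OneHotMatrix([1, 2, 2], 5)))  # one-hot inputs work the same way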

test/layers/basic.jl

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ import Flux: activations
 
 @testset "Embedding" begin
   vocab_size, embed_size = 10, 4
-  m = Flux.Embedding(vocab_size, embed_size)
+  m = Flux.Embedding(vocab_size => embed_size)
   @test size(m.weight) == (embed_size, vocab_size)
 
   x = rand(1:vocab_size, 3)
