Skip to content

Commit fd369e4

Browse files
more embedding tests; keep Embedding unexported
1 parent 062fc09 commit fd369e4

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
name = "Flux"
22
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
3-
version = "0.12.4"
3+
version = "0.12.5"
44

55
[deps]
66
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
10+
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
11+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1012
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
1113
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
1214
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
@@ -37,7 +39,7 @@ Colors = "0.12"
3739
Functors = "0.2.1"
3840
Juno = "0.8"
3941
MacroTools = "0.5"
40-
NNlib = "0.7.14"
42+
NNlib = "0.7.24"
4143
NNlibCUDA = "0.1"
4244
Reexport = "0.2, 1.0"
4345
StatsBase = "0.33"

src/Flux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
1111
export gradient
1212

1313
export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
14-
RNN, LSTM, GRU, Embedding,
14+
RNN, LSTM, GRU,
1515
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
1616
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
1717
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,

test/cuda/layers.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,16 +261,37 @@ end
261261
end
262262

263263
@testset "Embedding" begin
264-
vocab_size, embed_size = 10, 4
264+
vocab_size, embed_size = 5, 2
265265
m = Embedding(vocab_size, embed_size)
266-
x = rand(1:vocab_size, 3)
266+
267+
x = [1, 3, 5]
267268
y = m(x)
268269
m_g = m |> gpu
269270
x_g = x |> gpu
270271
y_g = m_g(x_g)
271272
@test collect(y_g) == y
273+
274+
gs = gradient(() -> sum(m(x)), params(m))
275+
gs_g = gradient(() -> sum(m_g(x_g)), params(m_g))
276+
@test collect(gs_g[m_g.weight]) ≈ gs[m.weight]
277+
272278
gs = gradient(() -> sum(tanh.(m(x))), params(m))
273279
gs_g = gradient(() -> sum(tanh.(m_g(x_g))), params(m_g))
274280
@test collect(gs_g[m_g.weight]) ≈ gs[m.weight]
281+
282+
@testset "repeated indexes" begin
283+
vocab_size, embed_size = 5, 2
284+
m = Embedding(vocab_size, embed_size)
285+
286+
x = [1, 3, 5, 3] # repeated indexes
287+
y = m(x)
288+
m_g = m |> gpu
289+
x_g = x |> gpu
290+
y_g = m_g(x_g)
291+
@test collect(y_g) == y
292+
gs = gradient(() -> sum(m(x)), params(m))
293+
gs_g = gradient(() -> sum(m_g(x_g)), params(m_g))
294+
@test collect(gs_g[m_g.weight]) ≈ gs[m.weight]
295+
end
275296
end
276297

0 commit comments

Comments
 (0)