Skip to content

Commit 383bc02

Browse files
committed
updated Embedding tests
1 parent 9d6bed7 commit 383bc02

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

test/cuda/layers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,12 @@ gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
128128

129129
embedding = [Flux.Embedding]
130130
gpu_gradtest("Embedding", embedding, [1,3,5], 5 => 2)
131-
gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5 => 2)
131+
gpu_gradtest("Embedding repeated indices", embedding, rand(1:50, 10^6), 50 => 2)
132132
gpu_gradtest("Embedding integer index", embedding, 1, 5 => 2)
133133
gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5 => 2)
134134
gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5 => 2)
135135
gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5 => 2)
136-
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5 => 2)
136+
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix(rand(1:50, 10^6), 50), 50 => 2)
137137

138138
@testset "function layers" begin
139139
x = rand(Float32, 3,3)

test/layers/basic.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,9 +284,9 @@ import Flux: activations
284284
@test y isa Matrix{Float32}
285285
@test y m.weight[:,x]
286286
x2 = OneHotMatrix(x, vocab_size)
287-
y2 = m(x2)
288-
@test y2 isa Matrix{Float32}
289-
@test y2 y
287+
@test m(x2) isa Matrix{Float32}
288+
@test m(x2) y
289+
@test m(collect(x2)) y
290290
@test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
291291

292292
x = rand(1:vocab_size, 3, 4)
@@ -297,6 +297,9 @@ import Flux: activations
297297
@test m(2) m.weight[:,2]
298298
@test m(OneHotVector(3, vocab_size)) m.weight[:,3]
299299
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
300+
301+
x = onehotbatch(rand(1:vocab_size, 4, 3, 4, 5), 1:vocab_size)
302+
@test m(x) m(onecold(x))
300303
end
301304
end
302305

0 commit comments

Comments
 (0)