Skip to content

Commit fe100d5

Browse files
committed
updated Embedding tests
1 parent b84c56d commit fe100d5

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

test/cuda/layers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,12 @@ gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
124124

125125
embedding = [Flux.Embedding]
126126
gpu_gradtest("Embedding", embedding, [1,3,5], 5 => 2)
127-
gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5 => 2)
127+
gpu_gradtest("Embedding repeated indices", embedding, rand(1:50, 10^6), 50 => 2)
128128
gpu_gradtest("Embedding integer index", embedding, 1, 5 => 2)
129129
gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5 => 2)
130130
gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5 => 2)
131131
gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5 => 2)
132-
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5 => 2)
132+
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix(rand(1:50, 10^6), 50), 50 => 2)
133133

134134
@testset "function layers" begin
135135
x = rand(Float32, 3,3)

test/layers/basic.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,9 @@ import Flux: activations
202202
@test y isa Matrix{Float32}
203203
@test y m.weight[:,x]
204204
x2 = OneHotMatrix(x, vocab_size)
205-
y2 = m(x2)
206-
@test y2 isa Matrix{Float32}
207-
@test y2 y
205+
@test m(x2) isa Matrix{Float32}
206+
@test m(x2) y
207+
@test m(collect(x2)) y
208208
@test_throws DimensionMismatch m(OneHotMatrix(x, 1000))
209209

210210
x = rand(1:vocab_size, 3, 4)
@@ -215,5 +215,8 @@ import Flux: activations
215215
@test m(2) m.weight[:,2]
216216
@test m(OneHotVector(3, vocab_size)) m.weight[:,3]
217217
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
218+
219+
x = onehotbatch(rand(1:vocab_size, 4, 3, 4, 5), 1:vocab_size)
220+
@test m(x) m(onecold(x))
218221
end
219222
end

0 commit comments

Comments
 (0)