Skip to content

Commit 5ac7a3d

Browse files
committed
fixed dimension check, added tests to check different dimensionality
1 parent 644bd1a commit 5ac7a3d

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/onehot.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,9 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotMatrix{<:Any, L}) where L
231231
return NNlib.gather(A, B.indices)
232232
end
233233

234-
function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix{<:Any, L}}) where L
235-
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
234+
function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix})
235+
B_dim = length(parent(B).indices)
236+
size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
236237
return NNlib.scatter(+, A, parent(B).indices, dstsize=(size(A,1), size(B,2)))
237238
end
238239

test/onehot.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,16 @@ end
4545
A = [1 3 5; 2 4 6; 3 6 9]
4646
v = [1, 2, 3, 4, 5]
4747
X = reshape(v, (5, 1))
48-
b1 = Flux.OneHotMatrix([1, 1, 2], 3)
48+
b1 = Flux.OneHotMatrix([1, 1, 2, 2], 3)
4949
b2 = Flux.OneHotMatrix([2, 4, 1, 3], 5)
50+
b3 = Flux.OneHotMatrix([1, 1, 2], 4)
5051

51-
@test A*b1 == A[:,[1, 1, 2]]
52+
@test A*b1 == A[:,[1, 1, 2, 2]]
5253
@test b1' * A == Array(b1') * A
5354
@test A' * b1 == A' * Array(b1)
54-
@test A * b1' == A * Array(b1')
55+
@test A * b3' == A * Array(b3')
5556
@test transpose(X) * b2 == transpose(X) * Array(b2)
57+
@test_throws DimensionMismatch A*b1'
5658
@test_throws DimensionMismatch A*b2
5759
@test_throws DimensionMismatch A*b2'
5860
end

0 commit comments

Comments
 (0)