Skip to content

Commit 51b7c00

Browse files
committed
added many tests for reshaped matrices
1 parent 775efee commit 51b7c00

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/onehot.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L, 1}) where L
233233
end
234234

235235
function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotLike{<:Any, L, 1}}) where L
236-
B_dim = length(parent(B).indices)
236+
B_dim = length(_indices(parent(B)))
237237
size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
238238
return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2)))
239239
end

test/onehot.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,26 @@ end
4848
b1 = Flux.OneHotMatrix([1, 1, 2, 2], 3)
4949
b2 = Flux.OneHotMatrix([2, 4, 1, 3], 5)
5050
b3 = Flux.OneHotMatrix([1, 1, 2], 4)
51+
b4 = reshape(Flux.OneHotMatrix([1 2 3; 2 2 1], 3), 3, :)
52+
b5 = reshape(b4, 6, :)
53+
b6 = reshape(Flux.OneHotMatrix([1 2 2; 2 2 1], 2), 3, :)
54+
b7 = reshape(Flux.OneHotMatrix([1 2 3; 1 2 3], 3), 6, :)
5155

5256
@test A * b1 == A[:,[1, 1, 2, 2]]
5357
@test b1' * A == Array(b1') * A
5458
@test A' * b1 == A' * Array(b1)
5559
@test A * b3' == A * Array(b3')
5660
@test transpose(X) * b2 == transpose(X) * Array(b2)
61+
@test A * b4 == A[:,[1, 2, 2, 2, 3, 1]]
62+
@test_broken A * b5' == A[:,[1, 2, 2, 2, 3, 1]]
63+
@test A * b6 == hcat(A[:,1], 2*A[:,2], A[:,2], A[:,1]+A[:,2])
64+
@test_broken A * b7'
65+
5766
@test_throws DimensionMismatch A*b1'
5867
@test_throws DimensionMismatch A*b2
5968
@test_throws DimensionMismatch A*b2'
69+
@test_throws DimensionMismatch A*b6'
70+
@test_throws DimensionMismatch A*b7
6071
end
6172

6273
@testset "OneHotArray" begin

0 commit comments

Comments
 (0)