Skip to content

Commit 35ab120

Browse files
committed
fixed tests, using only onehot for adjoint
1 parent 51b7c00 commit 35ab120

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/onehot.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L, 1}) where L
232232
return NNlib.gather(A, _indices(B))
233233
end
234234

235-
function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotLike{<:Any, L, 1}}) where L
235+
function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix})
236236
B_dim = length(_indices(parent(B)))
237237
size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
238238
return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2)))

test/onehot.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ end
5959
@test A * b3' == A * Array(b3')
6060
@test transpose(X) * b2 == transpose(X) * Array(b2)
6161
@test A * b4 == A[:,[1, 2, 2, 2, 3, 1]]
62-
@test_broken A * b5' == A[:,[1, 2, 2, 2, 3, 1]]
62+
@test A * b5' == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3))
6363
@test A * b6 == hcat(A[:,1], 2*A[:,2], A[:,2], A[:,1]+A[:,2])
64-
@test_broken A * b7'
64+
@test A * b7' == A[:,[1, 2, 3, 1, 2, 3]]
6565

6666
@test_throws DimensionMismatch A*b1'
6767
@test_throws DimensionMismatch A*b2

0 commit comments

Comments
 (0)