Skip to content

Commit 775efee

Browse files
committed
dispatching on OneHotLike of dimension 2
1 parent 6e2da25 commit 775efee

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/onehot.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L, 1}) where L
232232
return NNlib.gather(A, _indices(B))
233233
end
234234

235-
function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix})
235+
function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotLike{<:Any, L, 1}}) where L
236236
B_dim = length(parent(B).indices)
237237
size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
238238
return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2)))

0 commit comments

Comments
 (0)