Skip to content

Commit 2507126

Browse files
committed
adding optimization of multiplication by adjoint
1 parent 5f7ce6b commit 2507126

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/onehot.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,11 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotMatrix{<:Any, L}) where L
231231
return NNlib.gather(A, B.indices)
232232
end
233233

234+
function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix{<:Any, L}}) where L
235+
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
236+
return NNlib.scatter(+, A, parent(B).indices, dstsize=(size(A,1), size(B,2)))
237+
end
238+
234239
for wrapper in [:Adjoint, :Transpose]
235240
@eval begin
236241
function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector{<:Any, L}) where {L, T}

test/onehot.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ end
5454
@test A * b1' == A * Array(b1')
5555
@test transpose(X) * b2 == transpose(X) * Array(b2)
5656
@test_throws DimensionMismatch A*b2
57+
@test_throws DimensionMismatch A*b2'
5758
end
5859

5960
@testset "OneHotArray" begin

0 commit comments

Comments
 (0)