Skip to content

Commit 0a67472

Browse files
committed
using different method for onehot vector and onehot matrix
1 parent 7ce132f commit 0a67472

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

src/onehot.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,13 @@ end
220220

221221
@nograd OneHotArray, onecold, onehot, onehotbatch
222222

223-
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
223+
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L, 0}) where L
224+
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
225+
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
226+
return A[:, onecold(B)]
227+
end
228+
229+
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L, 1}) where L
224230
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
225231
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
226232
return NNlib.gather(A, _indices(B))

test/onehot.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ end
3232
b1 = Flux.OneHotVector(1, 3)
3333
b2 = Flux.OneHotVector(3, 5)
3434

35-
@test A*b1 == A[:,1]
35+
@test A * b1 == A[:,1]
3636
@test b1' * A == Array(b1') * A
3737
@test A' * b1 == A' * Array(b1)
3838
@test v' * b2 == v' * Array(b2)
@@ -49,7 +49,7 @@ end
4949
b2 = Flux.OneHotMatrix([2, 4, 1, 3], 5)
5050
b3 = Flux.OneHotMatrix([1, 1, 2], 4)
5151

52-
@test A*b1 == A[:,[1, 1, 2, 2]]
52+
@test A * b1 == A[:,[1, 1, 2, 2]]
5353
@test b1' * A == Array(b1') * A
5454
@test A' * b1 == A' * Array(b1)
5555
@test A * b3' == A * Array(b3')

0 commit comments

Comments
 (0)