Skip to content

Commit 64884e9

Browse files
bors[bot]gxyddarsnack
authored
Merge #1424
1424: multiplication of {Transpose, Adjoint} of Array and OneHotVector r=darsnack a=gxyd Also fixes #777 . Tests have been added. Co-authored-by: Gaurav Dhingra <gauravdhingra.gxyd@gmail.com> Co-authored-by: Kyle Daruwalla <daruwalla@wisc.edu>
2 parents 4be8443 + e42a0d5 commit 64884e9

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

src/onehot.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,20 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
187187
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
188188
return A[:, onecold(B)]
189189
end
190+
for wrapper in [:Adjoint, :Transpose]
191+
@eval begin
192+
function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector{<:Any, L}) where {L, T}
193+
size(A, 2) == L ||
194+
throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
195+
196+
return A[:, onecold(b)]
197+
end
198+
199+
function Base.:*(A::$wrapper{<:Number, <:AbstractVector{T}}, b::OneHotVector{<:Any, L}) where {L, T}
200+
size(A, 2) == L ||
201+
throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
202+
203+
return A[onecold(b)]
204+
end
205+
end
206+
end

test/onehot.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,17 @@ end
2727

2828
@testset "abstractmatrix onehotvector multiplication" begin
2929
A = [1 3 5; 2 4 6; 3 6 9]
30+
v = [1, 2, 3, 4, 5]
31+
X = reshape(v, (5, 1))
3032
b1 = Flux.OneHotVector(1, 3)
3133
b2 = Flux.OneHotVector(3, 5)
3234

3335
@test A*b1 == A[:,1]
36+
@test b1' * A == Array(b1') * A
37+
@test A' * b1 == A' * Array(b1)
38+
@test v' * b2 == v' * Array(b2)
39+
@test transpose(X) * b2 == transpose(X) * Array(b2)
40+
@test transpose(v) * b2 == transpose(v) * Array(b2)
3441
@test_throws DimensionMismatch A*b2
3542
end
3643

@@ -132,4 +139,4 @@ end
132139
@test map(identity, oa) == oa
133140
@test map(x -> 2 * x, oa) == 2 .* oa
134141
end
135-
end
142+
end

0 commit comments

Comments
 (0)