Skip to content

Commit f563f48

Browse files
authored
add missing mul! methods with transpose (#56)
* some more tests * reorganize gpu tests * oops, fix tests
1 parent 4e2d74a commit f563f48

File tree

4 files changed

+49
-12
lines changed

4 files changed

+49
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "OneHotArrays"
22
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
3-
version = "0.2.8"
3+
version = "0.2.9"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/linalg.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,6 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, 1})
1010
return NNlib.gather(A, _indices(B))
1111
end
1212

13-
function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix})
14-
B_dim = length(_indices(parent(B)))
15-
size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
16-
return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2)))
17-
end
18-
1913
for wrapper in [:Adjoint, :Transpose]
2014
@eval begin
2115
function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector) where T
@@ -31,6 +25,18 @@ for wrapper in [:Adjoint, :Transpose]
3125

3226
return A[onecold(b)]
3327
end
28+
29+
# note that the fill! is the same thing done by NNlib.scatter so it is not more expensive
30+
function LinearAlgebra.mul!(Y::AbstractMatrix, A::AbstractMatrix, B::$wrapper{Bool,<:OneHotMatrix})
31+
if size(A,2) size(B,1)
32+
throw(DimensionMismatch("Matrix column must correspond with the OneHot Size $(size(A,2))$(size(B,1))"))
33+
end
34+
if !(size(Y,1) == size(A,1) && size(Y,2) == size(B,2))
35+
throw(DimensionMismatch("Invalid output matrix size for multiplication of matrix sizes $(size(A)) and $(size(B))"))
36+
end
37+
fill!(Y, zero(eltype(Y)))
38+
return NNlib.scatter!(+, Y, A, _indices(parent(B)))
39+
end
3440
end
3541
end
3642

test/gpu.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,13 @@ end
4848
y = onehotbatch(ones(3), 1:2) |> cu;
4949
@test (repr("text/plain", y); true)
5050

51-
gA = rand(3, 2) |> cu;
52-
53-
#NOTE: this would require something that can copute gradient... we don't have that here?
51+
#NOTE: this would require something that can compute gradient... we don't have that here?
5452
#@test gradient(A -> sum(A * y), gA)[1] isa CuArray
53+
end
54+
55+
@testset "LinearAlgebra" begin
56+
y = onehotbatch(ones(3), 1:2) |> cu;
57+
gA = rand(3, 2) |> cu;
5558

5659
# some specialized implementations call only mul! and not *, so we must ensure this works
5760
@test LinearAlgebra.mul!(similar(gA, 3, 3), gA, y) gA*y
@@ -66,6 +69,14 @@ end
6669
y = reshape(y, 3, 2)
6770
gA = rand(2, 3) |> cu
6871
@test_broken LinearAlgebra.mul!(similar(gA, 2, 2), gA, y) gA*y
72+
73+
A = cu([1 3 5; 2 4 6; 3 6 9])
74+
b3_dense = cu(Array(OneHotMatrix([1, 1, 2], 4)))
75+
b3 = OneHotMatrix(cu([1, 1, 2]), 4)
76+
77+
d1 = fill(NaN, 3, 4) |> cu
78+
@test mul!(d1, A, b3') == A * b3_dense'
79+
@test mul!(d1, A, transpose(b3)) == A * transpose(b3_dense)
6980
end
7081

7182
@testset "onehotbatch(::CuArray, ::UnitRange)" begin

test/linalg.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,36 @@ end
5959
c1 = fill(NaN, 3, 4)
6060
@test mul!(c1, A, b1) == A * b1
6161
@test c1 == A * b1
62-
62+
6363
c4 = fill(NaN, 3, 6)
6464
@test mul!(c4, A, b4) == A * b4 # b4 is reshaped but still one-hot
6565
@test mul!(c4, A', b4) == A' * b4
6666
c6 = fill(NaN, 3, 4)
6767
@test mul!(c6, A, b6) == A * b6 # b4 is reshaped and not one-hot
6868
@test mul!(c6, A', b6) == A' * b6
69-
69+
7070
@test_throws DimensionMismatch mul!(c1, A, b2)
7171
@test_throws DimensionMismatch mul!(c1, A, b4)
7272
@test_throws DimensionMismatch mul!(c4, A, b1)
7373
@test_throws DimensionMismatch mul!(zeros(10, 3), A, b1)
74+
75+
# note that we have separate implementations for a couple of mul! for the time being
76+
77+
d1 = fill(NaN, 3, 4)
78+
@test mul!(d1, A, b3') == A * Array(b3')
79+
@test mul!(d1, A, transpose(b3)) == A * Array(transpose(b3))
80+
81+
d2 = fill(NaN, 3, 6)
82+
@test mul!(d2, A, b5') == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3))
83+
@test mul!(d2, A, transpose(b5)) == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3))
84+
85+
d3 = fill(NaN, 3, 6)
86+
@test mul!(d3, A, b7') == A[:,[1, 2, 3, 1, 2, 3]]
87+
@test mul!(d3, A, transpose(b7)) == A[:,[1, 2, 3, 1, 2, 3]]
88+
89+
d4 = fill(NaN, 4, 4)
90+
@test_throws DimensionMismatch mul!(d4, A, b3')
91+
@test_throws DimensionMismatch mul!(d4, A, transpose(b3))
92+
@test_throws DimensionMismatch mul!(d1, fill(1, (4,4)), b3')
93+
@test_throws DimensionMismatch mul!(d1, fill(1, (4,4)), transpose(b3))
7494
end

0 commit comments

Comments
 (0)