Skip to content

Commit 69afb67

Browse files
bors[bot]racinmat
andauthored
Merge #1756
1756: Speedup and fix of multiplication by OneHotMatrix r=CarloLucibello a=racinmat ### PR Checklist - [x] Tests are added - [x] Entry in NEWS.md - [ ] Documentation, if applicable - [ ] API changes require approval from a committer (different from the author, if applicable) Fixes #1355 . Also fixes bug mentioned in #1355 (comment). Adds tests for both gpu and cpu. Adds multiplication by sparse matrix to benchmarks. Co-authored-by: Matěj Račinský <matej.racinsky@avast.com> Co-authored-by: Matěj Račinský <racinsky.matej@seznam.cz>
2 parents f66be89 + 35ab120 commit 69afb67

File tree

4 files changed

+50
-1
lines changed

4 files changed

+50
-1
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Flux Release Notes
22

3+
## v0.12.8
4+
* Optimized inference and gradient calculation of `OneHotMatrix` [pr](https://github.com/FluxML/Flux.jl/pull/1756)
5+
36
## v0.12.7
47
* Added support for [`GRUv3`](https://github.com/FluxML/Flux.jl/pull/1675)
58
* The layers within `Chain` and `Parallel` may now [have names](https://github.com/FluxML/Flux.jl/issues/1680).

src/onehot.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import Adapt
22
import .CUDA
3+
using LinearAlgebra, NNlib
34

45
"""
56
OneHotArray{T,L,N,M,I} <: AbstractArray{Bool,M}
@@ -224,6 +225,19 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
224225
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
225226
return A[:, onecold(B)]
226227
end
228+
229+
# Product of a matrix `A` with a one-hot-like array `B` whose second type
# parameter is `L` and whose third type parameter is 1.
# Throws `DimensionMismatch` when `size(A, 2) != L`.
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L, 1}) where L
230+
# `B` has the right type but fails the `_isonehot` check: defer to the
# generic AbstractMatrix * AbstractMatrix method instead of the fast path.
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
231+
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
232+
# Select from `A` by index rather than performing a dense multiplication.
# NOTE(review): presumably `_indices(B)` holds the hot position of each
# column of `B` -- confirm against its definition.
return NNlib.gather(A, _indices(B))
233+
end
234+
235+
# Product of a matrix `A` with the adjoint of a `OneHotMatrix`.
# Throws `DimensionMismatch` when `size(A, 2)` differs from the number of
# indices stored in the parent one-hot matrix.
function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix})
236+
B_dim = length(_indices(parent(B)))
237+
size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
238+
# Scatter `A` with `+` as the reduction into an output of size
# (size(A, 1), size(B, 2)), addressed by the parent's indices.
return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2)))
239+
end
240+
227241
for wrapper in [:Adjoint, :Transpose]
228242
@eval begin
229243
function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector{<:Any, L}) where {L, T}

test/cuda/cuda.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ end
4242
# Checks for one-hot arrays moved with `gpu`: text display works and the
# gradient of a product with a one-hot batch is a `CuArray`.
@testset "onehot gpu" begin
4343
y = Flux.onehotbatch(ones(3), 1:2) |> gpu;
4444
# Only verifies that rendering `y` as text/plain does not throw.
@test (repr("text/plain", y); true)
45+
46+
gA = rand(3, 2) |> gpu;
47+
# The gradient with respect to the left factor of `A * y` must be a `CuArray`.
@test gradient(A -> sum(A * y), gA)[1] isa CuArray
4548
end
4649

4750
@testset "onecold gpu" begin

test/onehot.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ end
3232
b1 = Flux.OneHotVector(1, 3)
3333
b2 = Flux.OneHotVector(3, 5)
3434

35-
@test A*b1 == A[:,1]
35+
@test A * b1 == A[:,1]
3636
@test b1' * A == Array(b1') * A
3737
@test A' * b1 == A' * Array(b1)
3838
@test v' * b2 == v' * Array(b2)
@@ -41,6 +41,35 @@ end
4141
@test_throws DimensionMismatch A*b2
4242
end
4343

44+
# Multiplication of ordinary matrices by `OneHotMatrix` values, their adjoints,
# and reshaped one-hot arrays, compared against explicit column selection or
# against the same product computed with a materialised `Array`.
@testset "AbstractMatrix-OneHotMatrix multiplication" begin
45+
A = [1 3 5; 2 4 6; 3 6 9]
46+
v = [1, 2, 3, 4, 5]
47+
X = reshape(v, (5, 1))
48+
b1 = Flux.OneHotMatrix([1, 1, 2, 2], 3)
49+
b2 = Flux.OneHotMatrix([2, 4, 1, 3], 5)
50+
b3 = Flux.OneHotMatrix([1, 1, 2], 4)
51+
# b4-b7 are reshaped one-hot arrays. The expected products below for b5' and
# b6 contain summed, scaled and all-zero columns, so those operands are not
# single-hot per column after the reshape.
b4 = reshape(Flux.OneHotMatrix([1 2 3; 2 2 1], 3), 3, :)
52+
b5 = reshape(b4, 6, :)
53+
b6 = reshape(Flux.OneHotMatrix([1 2 2; 2 2 1], 2), 3, :)
54+
b7 = reshape(Flux.OneHotMatrix([1 2 3; 1 2 3], 3), 6, :)
55+
56+
# Multiplying by a one-hot matrix selects the indexed columns of `A`.
@test A * b1 == A[:,[1, 1, 2, 2]]
57+
@test b1' * A == Array(b1') * A
58+
@test A' * b1 == A' * Array(b1)
59+
@test A * b3' == A * Array(b3')
60+
@test transpose(X) * b2 == transpose(X) * Array(b2)
61+
@test A * b4 == A[:,[1, 2, 2, 2, 3, 1]]
62+
@test A * b5' == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3))
63+
@test A * b6 == hcat(A[:,1], 2*A[:,2], A[:,2], A[:,1]+A[:,2])
64+
@test A * b7' == A[:,[1, 2, 3, 1, 2, 3]]
65+
66+
# Incompatible inner dimensions must raise `DimensionMismatch`.
@test_throws DimensionMismatch A*b1'
67+
@test_throws DimensionMismatch A*b2
68+
@test_throws DimensionMismatch A*b2'
69+
@test_throws DimensionMismatch A*b6'
70+
@test_throws DimensionMismatch A*b7
71+
end
72+
4473
@testset "OneHotArray" begin
4574
using Flux: OneHotArray, OneHotVector, OneHotMatrix, OneHotLike
4675

0 commit comments

Comments
 (0)