Skip to content

Commit e6667bf

Browse files
committed
adding Diagonal to the new matrix multiplication scheme
1 parent f295ba6 commit e6667bf

File tree

6 files changed

+48
-13
lines changed

6 files changed

+48
-13
lines changed

benchmark/bench_mat_mul.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ mul_wrappers = [
1414
(m -> UnitUpperTriangular(m), "uup-tri"),
1515
(m -> UnitLowerTriangular(m), "ulo-tri"),
1616
(m -> Adjoint(m), "adjoint"),
17-
(m -> Transpose(m), "transpo")]
17+
(m -> Transpose(m), "transpo"),
18+
(m -> Diagonal(m), "diag ")]
1819

1920
for N in [2, 4, 8, 10, 16]
2021

src/SDiagonal.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ size(::Type{<:SDiagonal{N}}) where {N} = (N,N)
1818
size(::Type{<:SDiagonal{N}}, d::Int) where {N} = d > 2 ? 1 : N
1919

2020
# define specific methods to avoid allocating mutable arrays
21-
*(A::StaticMatrix, D::SDiagonal) = A .* transpose(D.diag)
22-
*(D::SDiagonal, A::StaticMatrix) = D.diag .* A
2321
\(D::SDiagonal, b::AbstractVector) = D.diag .\ b
2422
\(D::SDiagonal, b::StaticVector) = D.diag .\ b # catch ambiguity
2523

src/matrix_multiply.jl

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ end
6666
function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticVecOrMat}}, asym = :wrapped_a)
6767
return expr_gen(:adjoint)
6868
end
69+
function gen_by_access(expr_gen, a::Type{<:SDiagonal}, asym = :wrapped_a)
70+
return expr_gen(:diagonal)
71+
end
6972
"""
7073
gen_by_access(expr_gen, a::Type{<:AbstractArray}, b::Type{<:AbstractArray})
7174
@@ -148,6 +151,13 @@ function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticMatrix}}, b::T
148151
end)
149152
end
150153
end
154+
function gen_by_access(expr_gen, a::Type{<:SDiagonal}, b::Type)
155+
return quote
156+
return $(gen_by_access(b, :wrapped_b) do access_b
157+
expr_gen(:diagonal, access_b)
158+
end)
159+
end
160+
end
151161

152162
"""
153163
mul_result_structure(a::Type, b::Type)
@@ -164,6 +174,21 @@ end
164174
function mul_result_structure(::LowerTriangular{<:Any, <:StaticMatrix}, ::LowerTriangular{<:Any, <:StaticMatrix})
165175
return LowerTriangular
166176
end
177+
function mul_result_structure(::UpperTriangular{<:Any, <:StaticMatrix}, ::SDiagonal)
178+
return UpperTriangular
179+
end
180+
function mul_result_structure(::LowerTriangular{<:Any, <:StaticMatrix}, ::SDiagonal)
181+
return LowerTriangular
182+
end
183+
function mul_result_structure(::SDiagonal, ::UpperTriangular{<:Any, <:StaticMatrix})
184+
return UpperTriangular
185+
end
186+
function mul_result_structure(::SDiagonal, ::LowerTriangular{<:Any, <:StaticMatrix})
187+
return LowerTriangular
188+
end
189+
function mul_result_structure(::SDiagonal, ::SDiagonal)
190+
return Diagonal
191+
end
167192

168193
"""
169194
uplo_access(sa, asym, k, j, uplo)
@@ -247,6 +272,12 @@ function uplo_access(sa, asym, k, j, uplo)
247272
return :(transpose($asym[$(LinearIndices(reverse(sa))[j, k])]))
248273
elseif uplo == :adjoint
249274
return :(adjoint($asym[$(LinearIndices(reverse(sa))[j, k])]))
275+
elseif uplo == :diagonal
276+
if k == j
277+
return :($asym[$k])
278+
else
279+
return :(zero($TAsym))
280+
end
250281
else
251282
error("Unknown uplo: $uplo")
252283
end
@@ -347,12 +378,12 @@ end
347378

348379
@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
349380
# Heuristic choice for amount of codegen
350-
if sa[1]*sa[2]*sb[2] <= 8*8*8 || !(a <: StaticMatrix) || !(b <: StaticMatrix)
381+
if sa[1]*sa[2]*sb[2] <= 8*8*8 || a <: Diagonal || b <: Diagonal
351382
return quote
352383
@_inline_meta
353384
return mul_unrolled(Sa, Sb, a, b)
354385
end
355-
elseif sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14
386+
elseif (sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14) || !(a <: StaticMatrix) || !(b <: StaticMatrix)
356387
return quote
357388
@_inline_meta
358389
return mul_unrolled_chunks(Sa, Sb, a, b)
@@ -436,7 +467,7 @@ end
436467
tmp_type_out = :(SVector{$(sa[1]), T})
437468

438469
retexpr = gen_by_access(wrapped_a, wrapped_b) do access_a, access_b
439-
vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply($(Size{sa}()), $(Size{(sb[1],)}()),
470+
vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply($(Size{sa}()), $(Size{(sb[1],)}()),
440471
a, $(Expr(:call, tmp_type_in, [uplo_access(sb, :b, i, k2, access_b) for i = 1:sb[1]]...)), $(Val(access_a)))::$tmp_type_out) for k2 = 1:sb[2]]
441472

442473
exprs = [:($(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]]

src/matrix_multiply_add.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ const StaticMatMulLike{s1, s2, T} = Union{
2727
UnitLowerTriangular{T, <:StaticMatrix{s1, s2, T}},
2828
UnitUpperTriangular{T, <:StaticMatrix{s1, s2, T}},
2929
Adjoint{T, <:StaticMatrix{s1, s2, T}},
30-
Transpose{T, <:StaticMatrix{s1, s2, T}}}
30+
Transpose{T, <:StaticMatrix{s1, s2, T}},
31+
SDiagonal{s1, T}}
3132

3233

3334
""" Size that stores whether a Matrix is a Transpose
@@ -188,7 +189,7 @@ end
188189
can_blas = Tc == Ta && Tc == Tb && Tc <: BlasFloat && a <: Union{StaticMatrix,Transpose} && b <: Union{StaticMatrix,Transpose}
189190

190191
mult_dim = multiplied_dimension(a,b)
191-
if mult_dim < 4*4*4
192+
if mult_dim < 4*4*4 || a <: Diagonal || b <: Diagonal
192193
return quote
193194
@_inline_meta
194195
muladd_unrolled_all!(Sc, c, Sa, Sb, a, b, _add)

test/matrix_multiply.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ mul_wrappers = [
1111
m -> UnitUpperTriangular(m),
1212
m -> UnitLowerTriangular(m),
1313
m -> Adjoint(m),
14-
m -> Transpose(m)]
14+
m -> Transpose(m),
15+
m -> Diagonal(m)]
1516

1617
@testset "Matrix multiplication" begin
1718
@testset "Matrix-vector" begin
@@ -172,6 +173,8 @@ mul_wrappers = [
172173
LowerTriangular{Int,typeof(mm)}
173174
elseif res_structure == UpperTriangular
174175
UpperTriangular{Int,typeof(mm)}
176+
elseif res_structure == Diagonal
177+
Diagonal{Int,<:SVector}
175178
else
176179
error("Unknown structure: ", res_structure)
177180
end

test/matrix_multiply_add.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ mul_add_wrappers = [
2222
m -> LowerTriangular(m),
2323
m -> UnitUpperTriangular(m),
2424
m -> UnitLowerTriangular(m),
25-
m -> adjoint(m),
26-
m -> transpose(m)]
25+
m -> Adjoint(m),
26+
m -> Transpose(m),
27+
m -> Diagonal(m)]
2728

2829

2930
# check_dims
@@ -224,15 +225,15 @@ function test_wrappers_for_size(N, test_block)
224225
# matrix-vector
225226
for wrapper in mul_add_wrappers
226227
# LinearAlgebra can't handle these
227-
if all(T -> !isa(wrapper([1 2; 3 4]), T), [Symmetric, Hermitian])
228+
if all(T -> !isa(wrapper([1 2; 3 4]), T), [Symmetric, Hermitian, Diagonal])
228229
mul!(Cv_block, wrapper(A_block), bv_block)
229230
@test Cv_block == wrapper(Array(A_block))*Array(bv_block)
230231
end
231232
end
232233

233234
# matrix-matrix
234235
for wrapper_a in mul_add_wrappers, wrapper_b in mul_add_wrappers
235-
if all(T -> !isa(wrapper_a([1 2; 3 4]), T) && !isa(wrapper_b([1 2; 3 4]), T), [Symmetric, Hermitian])
236+
if all(T -> !isa(wrapper_a([1 2; 3 4]), T) && !isa(wrapper_b([1 2; 3 4]), T), [Symmetric, Hermitian, Diagonal])
236237
mul!(C_block, wrapper_a(A_block), wrapper_b(B_block))
237238
@test C_block == wrapper_a(Array(A_block))*wrapper_b(Array(B_block))
238239
end

0 commit comments

Comments
 (0)