Skip to content

Commit 6eaf1e1

Browse files
committed
updating triangular matrix multiplication to the new scheme
1 parent 6dc5694 commit 6eaf1e1

File tree

3 files changed

+95
-553
lines changed

3 files changed

+95
-553
lines changed

src/matrix_multiply.jl

Lines changed: 78 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import LinearAlgebra: BlasFloat, matprod, mul!
77
# Union of the matrix types accepted by the multiplication methods below: a plain
# `StaticMatrix{s1, s2, T}`, or such a matrix wrapped in `Symmetric`, `Hermitian`,
# `LowerTriangular` or `UpperTriangular`.
const StaticMatMulLike{s1, s2, T} = Union{
88
StaticMatrix{s1, s2, T},
99
Symmetric{T, <:StaticMatrix{s1, s2, T}},
10-
Hermitian{T, <:StaticMatrix{s1, s2, T}}}
10+
Hermitian{T, <:StaticMatrix{s1, s2, T}},
11+
LowerTriangular{T, <:StaticMatrix{s1, s2, T}},
12+
UpperTriangular{T, <:StaticMatrix{s1, s2, T}}}
1113

1214
# Matrix * vector for any `AbstractVector`: forwards to `_mul` with the matrix's `Size`.
@inline *(A::StaticMatMulLike, B::AbstractVector) = _mul(Size(A), A, B)
1315
# Matrix * vector for a `StaticVector`: forwards to `_mul` with the `Size` of both operands.
@inline *(A::StaticMatMulLike, B::StaticVector) = _mul(Size(A), Size(B), A, B)
@@ -39,48 +41,75 @@ function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, as
3941
end
4042
end
4143
end
42-
function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, b::Type{<:StaticMatrix})
43-
return expr_gen(:any, :any)
44+
# Single-matrix access dispatch for an `UpperTriangular`-wrapped static matrix type:
# calls `expr_gen` with the access tag `:upper_triangular` and returns its result.
# `asym` is accepted but not referenced in this method body.
function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, asym = :a)
45+
return expr_gen(:upper_triangular)
46+
end
47+
# Single-matrix access dispatch for a `LowerTriangular`-wrapped static matrix type:
# calls `expr_gen` with the access tag `:lower_triangular` and returns its result.
# `asym` is accepted but not referenced in this method body.
function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, asym = :a)
48+
return expr_gen(:lower_triangular)
4449
end
4550
# Two-matrix access dispatch when `a` is a plain static matrix type.
# `a` always gets the access tag `:any`; the tag for `b` is obtained by calling the
# single-matrix `gen_by_access(b, :b)`. The result is a quoted block that returns the
# expression built by `expr_gen(:any, access_b)`.
# Lines under `NN-` gutter markers are the pre-commit body shown by the diff; lines under
# `NN+` markers are the body after this commit.
function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, b::Type)
46-
return gen_by_access(a) do access_a
47-
return quote
48-
return $(gen_by_access(b, :b) do access_b
49-
expr_gen(:any, access_b)
50-
end)
51-
end
51+
return quote
52+
return $(gen_by_access(b, :b) do access_b
53+
expr_gen(:any, access_b)
54+
end)
5255
end
5356
end
5457
# Two-matrix access dispatch when `a` is a `Symmetric`-wrapped static matrix type.
# The returned quoted code checks `a.uplo == 'U'` when it runs and picks between two
# pre-generated expressions: one built with access tag `:up` for `a`, the other with
# `:lo`. In both branches the tag for `b` comes from the single-matrix
# `gen_by_access(b, :b)`.
# Lines under `NN-` gutter markers are the pre-commit body shown by the diff; lines under
# `NN+` markers are the body after this commit.
function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, b::Type)
55-
return gen_by_access(a) do access_a
56-
return quote
57-
if a.uplo == 'U'
58-
return $(gen_by_access(b, :b) do access_b
59-
expr_gen(:up, access_b)
60-
end)
61-
else
62-
return $(gen_by_access(b, :b) do access_b
63-
expr_gen(:lo, access_b)
64-
end)
65-
end
58+
return quote
59+
if a.uplo == 'U'
60+
return $(gen_by_access(b, :b) do access_b
61+
expr_gen(:up, access_b)
62+
end)
63+
else
64+
return $(gen_by_access(b, :b) do access_b
65+
expr_gen(:lo, access_b)
66+
end)
6667
end
6768
end
6869
end
6970
# Two-matrix access dispatch when `a` is a `Hermitian`-wrapped static matrix type.
# The returned quoted code checks `a.uplo == 'U'` when it runs and picks between two
# pre-generated expressions: one built with access tag `:up_herm` for `a`, the other with
# `:lo_herm`. In both branches the tag for `b` comes from the single-matrix
# `gen_by_access(b, :b)`.
# Lines under `NN-` gutter markers are the pre-commit body shown by the diff; lines under
# `NN+` markers are the body after this commit.
function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, b::Type)
70-
return gen_by_access(a) do access_a
71-
return quote
72-
if a.uplo == 'U'
73-
return $(gen_by_access(b, :b) do access_b
74-
expr_gen(:up_herm, access_b)
75-
end)
76-
else
77-
return $(gen_by_access(b, :b) do access_b
78-
expr_gen(:lo_herm, access_b)
79-
end)
80-
end
71+
return quote
72+
if a.uplo == 'U'
73+
return $(gen_by_access(b, :b) do access_b
74+
expr_gen(:up_herm, access_b)
75+
end)
76+
else
77+
return $(gen_by_access(b, :b) do access_b
78+
expr_gen(:lo_herm, access_b)
79+
end)
8180
end
8281
end
8382
end
83+
# Two-matrix access dispatch when `a` is an `UpperTriangular`-wrapped static matrix type.
# `a` gets the access tag `:upper_triangular`; the tag for `b` is obtained from the
# single-matrix `gen_by_access(b, :b)`. Returns a quoted block that returns the
# expression built by `expr_gen(:upper_triangular, access_b)`.
function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, b::Type)
84+
return quote
85+
return $(gen_by_access(b, :b) do access_b
86+
expr_gen(:upper_triangular, access_b)
87+
end)
88+
end
89+
end
90+
# Two-matrix access dispatch when `a` is a `LowerTriangular`-wrapped static matrix type.
# `a` gets the access tag `:lower_triangular`; the tag for `b` is obtained from the
# single-matrix `gen_by_access(b, :b)`. Returns a quoted block that returns the
# expression built by `expr_gen(:lower_triangular, access_b)`.
function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, b::Type)
91+
return quote
92+
return $(gen_by_access(b, :b) do access_b
93+
expr_gen(:lower_triangular, access_b)
94+
end)
95+
end
96+
end
97+
98+
"""
99+
mul_result_structure(a::Type, b::Type)
100+
101+
Get a structure wrapper that should be applied to the result of multiplication of matrices
102+
of given types (a*b).
103+
"""
104+
function mul_result_structure(a, b)
105+
return identity
106+
end
107+
function mul_result_structure(::UpperTriangular{<:Any, <:StaticMatrix}, ::UpperTriangular{<:Any, <:StaticMatrix})
108+
return UpperTriangular
109+
end
110+
function mul_result_structure(::LowerTriangular{<:Any, <:StaticMatrix}, ::LowerTriangular{<:Any, <:StaticMatrix})
111+
return LowerTriangular
112+
end
84113

85114
function uplo_access(sa, asym, k, j, uplo)
86115
if uplo == :any
@@ -92,7 +121,7 @@ function uplo_access(sa, asym, k, j, uplo)
92121
return :($asym[$(LinearIndices(sa)[j, k])])
93122
end
94123
elseif uplo == :lo
95-
if j <= k
124+
if k >= j
96125
return :($asym[$(LinearIndices(sa)[k, j])])
97126
else
98127
return :($asym[$(LinearIndices(sa)[j, k])])
@@ -104,11 +133,23 @@ function uplo_access(sa, asym, k, j, uplo)
104133
return :(adjoint($asym[$(LinearIndices(sa)[j, k])]))
105134
end
106135
elseif uplo == :lo_herm
107-
if j <= k
136+
if k >= j
108137
return :($asym[$(LinearIndices(sa)[k, j])])
109138
else
110139
return :(adjoint($asym[$(LinearIndices(sa)[j, k])]))
111140
end
141+
elseif uplo == :upper_triangular
142+
if k <= j
143+
return :($asym[$(LinearIndices(sa)[k, j])])
144+
else
145+
return :(zero(T))
146+
end
147+
elseif uplo == :lower_triangular
148+
if k >= j
149+
return :($asym[$(LinearIndices(sa)[k, j])])
150+
else
151+
return :(zero(T))
152+
end
112153
end
113154
end
114155

@@ -147,7 +188,7 @@ end
147188
if sa[2] != 0
148189
retexpr = gen_by_access(a) do access_a
149190
exprs = mul_smat_vec_exprs(sa, access_a)
150-
return :(@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...))))
191+
return :(@inbounds similar_type(b, T, Size(sa[1]))(tuple($(exprs...))))
151192
end
152193
else
153194
exprs = [:(zero(T)) for k = 1:sa[1]]
@@ -195,39 +236,6 @@ end
195236
end
196237
end
197238

198-
@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::Union{SizedMatrix{T}, MMatrix{T}, MArray{T}}, b::Union{SizedMatrix{T}, MMatrix{T}, MArray{T}}) where {sa, sb, T <: BlasFloat}
199-
S = Size(sa[1], sb[2])
200-
201-
# Heuristic choice between BLAS and explicit unrolling (or chunk-based unrolling)
202-
if sa[1]*sa[2]*sb[2] >= 14*14*14
203-
Sa = TSize{size(S),false}()
204-
Sb = TSize{sa,false}()
205-
Sc = TSize{sb,false}()
206-
_add = MulAddMul(true,false)
207-
return quote
208-
@_inline_meta
209-
C = similar(a, T, $S)
210-
mul_blas!($Sa, C, $Sa, $Sb, a, b, $_add)
211-
return C
212-
end
213-
elseif sa[1]*sa[2]*sb[2] < 8*8*8
214-
return quote
215-
@_inline_meta
216-
return mul_unrolled(Sa, Sb, a, b)
217-
end
218-
elseif sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14
219-
return quote
220-
@_inline_meta
221-
return similar_type(a, T, $S)(mul_unrolled_chunks(Sa, Sb, a, b))
222-
end
223-
else
224-
return quote
225-
@_inline_meta
226-
return mul_loop(Sa, Sb, a, b)
227-
end
228-
end
229-
end
230-
231239
@generated function mul_unrolled(::Size{sa}, ::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
232240
if sb[1] != sa[2]
233241
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))
@@ -240,17 +248,17 @@ end
240248
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)),
241249
[:($(uplo_access(sa, :a, k1, j, access_a))*$(uplo_access(sb, :b, j, k2, access_b))) for j = 1:sa[2]]
242250
) for k1 = 1:sa[1], k2 = 1:sb[2]]
243-
return :(@inbounds return similar_type(a, T, $S)(tuple($(exprs...))))
251+
return :((mul_result_structure(a, b))(similar_type(a, T, $S)(tuple($(exprs...)))))
244252
end
245253
else
246254
exprs = [:(zero(T)) for k1 = 1:sa[1], k2 = 1:sb[2]]
247-
retexpr = :(@inbounds return similar_type(a, T, $S)(tuple($(exprs...))))
255+
retexpr = :(return (mul_result_structure(a, b))(similar_type(a, T, $S)(tuple($(exprs...)))))
248256
end
249257

250258
return quote
251259
@_inline_meta
252260
T = promote_op(matprod,Ta,Tb)
253-
$retexpr
261+
@inbounds $retexpr
254262
end
255263
end
256264

0 commit comments

Comments
 (0)