Skip to content

Commit f39afe0

Browse files
committed
fixed code for symmetric and hermitian multiplication and small cleanup
1 parent 246c256 commit f39afe0

File tree

3 files changed

+82
-78
lines changed

3 files changed

+82
-78
lines changed

src/matrix_multiply.jl

Lines changed: 60 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,10 @@ import LinearAlgebra: BlasFloat, matprod, mul!
44
# Manage dispatch of * and mul!
55
# TODO Adjoint? (Inner product?)
66

7-
"""
8-
StaticMatMulLike
9-
10-
Static wrappers used for multiplication dispatch.
11-
"""
12-
const StaticMatMulLike{s1, s2, T} = Union{
13-
StaticMatrix{s1, s2, T},
14-
Symmetric{T, <:StaticMatrix{s1, s2, T}},
15-
Hermitian{T, <:StaticMatrix{s1, s2, T}},
16-
LowerTriangular{T, <:StaticMatrix{s1, s2, T}},
17-
UpperTriangular{T, <:StaticMatrix{s1, s2, T}},
18-
UnitLowerTriangular{T, <:StaticMatrix{s1, s2, T}},
19-
UnitUpperTriangular{T, <:StaticMatrix{s1, s2, T}},
20-
UpperHessenberg{T, <:StaticMatrix{s1, s2, T}},
21-
Adjoint{T, <:StaticMatrix{s1, s2, T}},
22-
Transpose{T, <:StaticMatrix{s1, s2, T}}}
23-
24-
25-
@inline *(A::StaticMatMulLike, B::AbstractVector) = _mul(Size(A), A, B)
7+
# *(A::StaticMatMulLike, B::AbstractVector) causes an ambiguity with SparseArrays
8+
@inline *(A::StaticMatrix, B::AbstractVector) = _mul(Size(A), A, B)
269
@inline *(A::StaticMatMulLike, B::StaticVector) = _mul(Size(A), Size(B), A, B)
10+
@inline *(A::StaticMatrix, B::StaticVector) = _mul(Size(A), Size(B), A, B)
2711
@inline *(A::StaticMatMulLike, B::StaticMatMulLike) = _mul(Size(A), Size(B), A, B)
2812
@inline *(A::StaticVector, B::StaticMatMulLike) = *(reshape(A, Size(Size(A)[1], 1)), B)
2913
@inline *(A::StaticVector, B::Transpose{<:Any, <:StaticVector}) = _mul(Size(A), Size(B), A, B)
@@ -32,7 +16,7 @@ const StaticMatMulLike{s1, s2, T} = Union{
3216
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N} = vec(A) * B
3317

3418
"""
35-
gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :a)
19+
gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :wrapped_a)
3620
3721
Statically generate outer code for fully unrolled multiplication loops.
3822
Returned code does wrapper-specific tests (for example if a symmetric matrix view is
@@ -43,10 +27,10 @@ element access.
4327
4428
The name of the matrix to test is indicated by `asym`.
4529
"""
46-
function gen_by_access(expr_gen, a::Type{<:StaticVecOrMat}, asym = :a)
30+
function gen_by_access(expr_gen, a::Type{<:StaticVecOrMat}, asym = :wrapped_a)
4731
return expr_gen(:any)
4832
end
49-
function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, asym = :a)
33+
function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, asym = :wrapped_a)
5034
return quote
5135
if $(asym).uplo == 'U'
5236
$(expr_gen(:up))
@@ -55,7 +39,7 @@ function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, as
5539
end
5640
end
5741
end
58-
function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, asym = :a)
42+
function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, asym = :wrapped_a)
5943
return quote
6044
if $(asym).uplo == 'U'
6145
$(expr_gen(:up_herm))
@@ -64,25 +48,22 @@ function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, as
6448
end
6549
end
6650
end
67-
function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, asym = :a)
51+
function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a)
6852
return expr_gen(:upper_triangular)
6953
end
70-
function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, asym = :a)
54+
function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a)
7155
return expr_gen(:lower_triangular)
7256
end
73-
function gen_by_access(expr_gen, a::Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}}, asym = :a)
57+
function gen_by_access(expr_gen, a::Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a)
7458
return expr_gen(:unit_upper_triangular)
7559
end
76-
function gen_by_access(expr_gen, a::Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}}, asym = :a)
60+
function gen_by_access(expr_gen, a::Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a)
7761
return expr_gen(:unit_lower_triangular)
7862
end
79-
function gen_by_access(expr_gen, a::Type{<:UpperHessenberg{<:Any, <:StaticMatrix}}, asym = :a)
80-
return expr_gen(:upper_hessenberg)
81-
end
82-
function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticVecOrMat}}, asym = :a)
63+
function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticVecOrMat}}, asym = :wrapped_a)
8364
return expr_gen(:transpose)
8465
end
85-
function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticVecOrMat}}, asym = :a)
66+
function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticVecOrMat}}, asym = :wrapped_a)
8667
return expr_gen(:adjoint)
8768
end
8869
"""
@@ -94,82 +75,75 @@ first for matrix `a` and the second for matrix `b`.
9475
"""
9576
function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, b::Type)
9677
return quote
97-
return $(gen_by_access(b, :b) do access_b
78+
return $(gen_by_access(b, :wrapped_b) do access_b
9879
expr_gen(:any, access_b)
9980
end)
10081
end
10182
end
10283
function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, b::Type)
10384
return quote
104-
if a.uplo == 'U'
105-
return $(gen_by_access(b, :b) do access_b
85+
if wrapped_a.uplo == 'U'
86+
return $(gen_by_access(b, :wrapped_b) do access_b
10687
expr_gen(:up, access_b)
10788
end)
10889
else
109-
return $(gen_by_access(b, :b) do access_b
90+
return $(gen_by_access(b, :wrapped_b) do access_b
11091
expr_gen(:lo, access_b)
11192
end)
11293
end
11394
end
11495
end
11596
function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, b::Type)
11697
return quote
117-
if a.uplo == 'U'
118-
return $(gen_by_access(b, :b) do access_b
98+
if wrapped_a.uplo == 'U'
99+
return $(gen_by_access(b, :wrapped_b) do access_b
119100
expr_gen(:up_herm, access_b)
120101
end)
121102
else
122-
return $(gen_by_access(b, :b) do access_b
103+
return $(gen_by_access(b, :wrapped_b) do access_b
123104
expr_gen(:lo_herm, access_b)
124105
end)
125106
end
126107
end
127108
end
128109
function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, b::Type)
129110
return quote
130-
return $(gen_by_access(b, :b) do access_b
111+
return $(gen_by_access(b, :wrapped_b) do access_b
131112
expr_gen(:upper_triangular, access_b)
132113
end)
133114
end
134115
end
135116
function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, b::Type)
136117
return quote
137-
return $(gen_by_access(b, :b) do access_b
118+
return $(gen_by_access(b, :wrapped_b) do access_b
138119
expr_gen(:lower_triangular, access_b)
139120
end)
140121
end
141122
end
142123
function gen_by_access(expr_gen, a::Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}}, b::Type)
143124
return quote
144-
return $(gen_by_access(b, :b) do access_b
125+
return $(gen_by_access(b, :wrapped_b) do access_b
145126
expr_gen(:unit_upper_triangular, access_b)
146127
end)
147128
end
148129
end
149130
function gen_by_access(expr_gen, a::Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}}, b::Type)
150131
return quote
151-
return $(gen_by_access(b, :b) do access_b
132+
return $(gen_by_access(b, :wrapped_b) do access_b
152133
expr_gen(:unit_lower_triangular, access_b)
153134
end)
154135
end
155136
end
156-
function gen_by_access(expr_gen, a::Type{<:UpperHessenberg{<:Any, <:StaticMatrix}}, b::Type)
157-
return quote
158-
return $(gen_by_access(b, :b) do access_b
159-
expr_gen(:upper_hessenberg, access_b)
160-
end)
161-
end
162-
end
163137
function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticMatrix}}, b::Type)
164138
return quote
165-
return $(gen_by_access(b, :b) do access_b
139+
return $(gen_by_access(b, :wrapped_b) do access_b
166140
expr_gen(:transpose, access_b)
167141
end)
168142
end
169143
end
170144
function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticMatrix}}, b::Type)
171145
return quote
172-
return $(gen_by_access(b, :b) do access_b
146+
return $(gen_by_access(b, :wrapped_b) do access_b
173147
expr_gen(:adjoint, access_b)
174148
end)
175149
end
@@ -200,65 +174,74 @@ statically known for this function to work. `uplo` is the access pattern mode ge
200174
by the `gen_by_access` function.
201175
"""
202176
function uplo_access(sa, asym, k, j, uplo)
177+
TAsym = Symbol("T"*string(asym))
203178
if uplo == :any
204179
return :($asym[$(LinearIndices(sa)[k, j])])
205180
elseif uplo == :up
206-
if k <= j
181+
if k < j
207182
return :($asym[$(LinearIndices(sa)[k, j])])
183+
elseif k == j
184+
return :(LinearAlgebra.symmetric($asym[$(LinearIndices(sa)[k, j])], :U))
208185
else
209-
return :($asym[$(LinearIndices(sa)[j, k])])
186+
return :(transpose($asym[$(LinearIndices(sa)[j, k])]))
210187
end
211188
elseif uplo == :lo
212-
if k >= j
189+
if k > j
213190
return :($asym[$(LinearIndices(sa)[k, j])])
191+
elseif k == j
192+
return :(LinearAlgebra.symmetric($asym[$(LinearIndices(sa)[k, j])], :L))
214193
else
215-
return :($asym[$(LinearIndices(sa)[j, k])])
194+
return :(transpose($asym[$(LinearIndices(sa)[j, k])]))
216195
end
217196
elseif uplo == :up_herm
218-
if k <= j
197+
if k < j
219198
return :($asym[$(LinearIndices(sa)[k, j])])
199+
elseif k == j
200+
return :(LinearAlgebra.hermitian($asym[$(LinearIndices(sa)[k, j])], :U))
220201
else
221202
return :(adjoint($asym[$(LinearIndices(sa)[j, k])]))
222203
end
223204
elseif uplo == :lo_herm
224-
if k >= j
205+
if k > j
225206
return :($asym[$(LinearIndices(sa)[k, j])])
207+
elseif k == j
208+
return :(LinearAlgebra.hermitian($asym[$(LinearIndices(sa)[k, j])], :L))
226209
else
227210
return :(adjoint($asym[$(LinearIndices(sa)[j, k])]))
228211
end
229212
elseif uplo == :upper_triangular
230213
if k <= j
231214
return :($asym[$(LinearIndices(sa)[k, j])])
232215
else
233-
return :(zero(T))
216+
return :(zero($TAsym))
234217
end
235218
elseif uplo == :lower_triangular
236219
if k >= j
237220
return :($asym[$(LinearIndices(sa)[k, j])])
238221
else
239-
return :(zero(T))
222+
return :(zero($TAsym))
240223
end
241224
elseif uplo == :unit_upper_triangular
242225
if k < j
243226
return :($asym[$(LinearIndices(sa)[k, j])])
244227
elseif k == j
245-
return :(oneunit(T))
228+
return :(oneunit($TAsym))
246229
else
247-
return :(zero(T))
230+
return :(zero($TAsym))
248231
end
249232
elseif uplo == :unit_lower_triangular
250233
if k > j
251234
return :($asym[$(LinearIndices(sa)[k, j])])
252235
elseif k == j
253-
return :(oneunit(T))
236+
return :(oneunit($TAsym))
254237
else
255-
return :(zero(T))
238+
return :(zero($TAsym))
256239
end
257240
elseif uplo == :upper_hessenberg
258241
if k <= j+1
259242
return :($asym[$(LinearIndices(sa)[k, j])])
260243
else
261-
return :(zero(T))
244+
return :(zero($TAsym))
262245
end
263246
elseif uplo == :transpose
264247
return :($asym[$(LinearIndices(reverse(sa))[j, k])])
@@ -273,9 +256,9 @@ function mul_smat_vec_exprs(sa, access_a)
273256
return [reduce((ex1,ex2) -> :(+($ex1,$ex2)), [:($(uplo_access(sa, :a, k, j, access_a))*b[$j]) for j = 1:sa[2]]) for k = 1:sa[1]]
274257
end
275258

276-
@generated function _mul(::Size{sa}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::AbstractVector{Tb}) where {sa, Ta, Tb}
259+
@generated function _mul(::Size{sa}, wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, b::AbstractVector{Tb}) where {sa, Ta, Tb}
277260
if sa[2] != 0
278-
retexpr = gen_by_access(a) do access_a
261+
retexpr = gen_by_access(wrapped_a) do access_a
279262
exprs = mul_smat_vec_exprs(sa, access_a)
280263
return :(@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...))))
281264
end
@@ -290,17 +273,18 @@ end
290273
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $(size(b))"))
291274
end
292275
T = promote_op(matprod,Ta,Tb)
276+
a = mul_parent(wrapped_a)
293277
$retexpr
294278
end
295279
end
296280

297-
@generated function _mul(::Size{sa}, ::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb}
281+
@generated function _mul(::Size{sa}, ::Size{sb}, wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb}
298282
if sb[1] != sa[2]
299283
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))
300284
end
301285

302286
if sa[2] != 0
303-
retexpr = gen_by_access(a) do access_a
287+
retexpr = gen_by_access(wrapped_a) do access_a
304288
exprs = mul_smat_vec_exprs(sa, access_a)
305289
return :(@inbounds similar_type(b, T, Size(sa[1]))(tuple($(exprs...))))
306290
end
@@ -312,6 +296,7 @@ end
312296
return quote
313297
@_inline_meta
314298
T = promote_op(matprod,Ta,Tb)
299+
a = mul_parent(wrapped_a)
315300
$retexpr
316301
end
317302
end
@@ -362,28 +347,30 @@ end
362347
end
363348
end
364349

365-
@generated function mul_unrolled(::Size{sa}, ::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
350+
@generated function mul_unrolled(::Size{sa}, ::Size{sb}, wrapped_a::StaticMatMulLike{<:Any, <:Any, Ta}, wrapped_b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
366351
if sb[1] != sa[2]
367352
throw(DimensionMismatch("Tried to multiply arrays of size $sa and $sb"))
368353
end
369354

370355
S = Size(sa[1], sb[2])
371356

372357
if sa[2] != 0
373-
retexpr = gen_by_access(a, b) do access_a, access_b
358+
retexpr = gen_by_access(wrapped_a, wrapped_b) do access_a, access_b
374359
exprs = [reduce((ex1,ex2) -> :(+($ex1,$ex2)),
375360
[:($(uplo_access(sa, :a, k1, j, access_a))*$(uplo_access(sb, :b, j, k2, access_b))) for j = 1:sa[2]]
376361
) for k1 = 1:sa[1], k2 = 1:sb[2]]
377-
return :((mul_result_structure(a, b))(similar_type(a, T, $S)(tuple($(exprs...)))))
362+
return :((mul_result_structure(wrapped_a, wrapped_b))(similar_type(a, T, $S)(tuple($(exprs...)))))
378363
end
379364
else
380365
exprs = [:(zero(T)) for k1 = 1:sa[1], k2 = 1:sb[2]]
381-
retexpr = :(return (mul_result_structure(a, b))(similar_type(a, T, $S)(tuple($(exprs...)))))
366+
retexpr = :(return (mul_result_structure(wrapped_a, wrapped_b))(similar_type(a, T, $S)(tuple($(exprs...)))))
382367
end
383368

384369
return quote
385370
@_inline_meta
386371
T = promote_op(matprod,Ta,Tb)
372+
a = mul_parent(wrapped_a)
373+
b = mul_parent(wrapped_b)
387374
@inbounds $retexpr
388375
end
389376
end

src/matrix_multiply_add.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,23 @@ struct NoMulAdd{T} <: MulAddMul{T} end
1717
@inline alpha(ma::NoMulAdd{T}) where T = one(T)
1818
@inline beta(ma::NoMulAdd{T}) where T = zero(T)
1919

20+
"""
21+
StaticMatMulLike
22+
23+
Static wrappers used for multiplication dispatch.
24+
"""
25+
const StaticMatMulLike{s1, s2, T} = Union{
26+
StaticMatrix{s1, s2, T},
27+
Symmetric{T, <:StaticMatrix{s1, s2, T}},
28+
Hermitian{T, <:StaticMatrix{s1, s2, T}},
29+
LowerTriangular{T, <:StaticMatrix{s1, s2, T}},
30+
UpperTriangular{T, <:StaticMatrix{s1, s2, T}},
31+
UnitLowerTriangular{T, <:StaticMatrix{s1, s2, T}},
32+
UnitUpperTriangular{T, <:StaticMatrix{s1, s2, T}},
33+
Adjoint{T, <:StaticMatrix{s1, s2, T}},
34+
Transpose{T, <:StaticMatrix{s1, s2, T}}}
35+
36+
2037
""" Size that stores whether a Matrix is a Transpose
2138
Useful when selecting multiplication methods, and avoiding allocations when dealing with
2239
the `Transpose` type by passing around the original matrix.
@@ -40,8 +57,8 @@ Base.transpose(::TSize{S,:any}) where {S,T} = TSize{reverse(S),:transpose}()
4057

4158
# Get the parent of transposed arrays, or the array itself if it has no parent
4259
# Unwraps every StaticMatMulLike wrapper (Symmetric, Hermitian, triangular, Adjoint, Transpose) via Base.parent
43-
mul_parent(A::Union{<:Transpose{<:Any,<:StaticArray}, <:Adjoint{<:Any,<:StaticArray}}) = A.parent
44-
mul_parent(A::StaticArray) = A
60+
@inline mul_parent(A::Union{StaticMatMulLike, Adjoint{<:Any,<:StaticVector}, Transpose{<:Any,<:StaticVector}}) = Base.parent(A)
61+
@inline mul_parent(A::StaticArray) = A
4562

4663
# 5-argument matrix multiplication
4764
# To avoid allocations, strip away Transpose type and store transpose info in Size

0 commit comments

Comments
 (0)