Skip to content

Commit e725e4d

Browse files
committed
adjoint and transpose wrappers for multiplication; more documentation
1 parent 56c9ca5 commit e725e4d

File tree

3 files changed

+90
-11
lines changed

3 files changed

+90
-11
lines changed

src/matrix_multiply.jl

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,20 @@ import LinearAlgebra: BlasFloat, matprod, mul!
44
# Manage dispatch of * and mul!
55
# TODO Adjoint? (Inner product?)
66

7+
"""
8+
StaticMatMulLike
9+
10+
Static wrappers used for multiplication dispatch.
11+
"""
712
const StaticMatMulLike{s1, s2, T} = Union{
813
StaticMatrix{s1, s2, T},
914
Symmetric{T, <:StaticMatrix{s1, s2, T}},
1015
Hermitian{T, <:StaticMatrix{s1, s2, T}},
1116
LowerTriangular{T, <:StaticMatrix{s1, s2, T}},
12-
UpperTriangular{T, <:StaticMatrix{s1, s2, T}}}
17+
UpperTriangular{T, <:StaticMatrix{s1, s2, T}},
18+
Adjoint{T, <:StaticMatrix{s1, s2, T}},
19+
Transpose{T, <:StaticMatrix{s1, s2, T}}}
20+
1321

1422
@inline *(A::StaticMatMulLike, B::AbstractVector) = _mul(Size(A), A, B)
1523
@inline *(A::StaticMatMulLike, B::StaticVector) = _mul(Size(A), Size(B), A, B)
@@ -20,6 +28,18 @@ const StaticMatMulLike{s1, s2, T} = Union{
2028
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Adjoint{<:Any,<:StaticVector}) where {N} = vec(A) * B
2129
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N} = vec(A) * B
2230

31+
"""
32+
gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :a)
33+
34+
Statically generate outer code for fully unrolled multiplication loops.
35+
Returned code does wrapper-specific tests (for example if a symmetric matrix view is
36+
`U` or `L`) and the body of the if expression is then generated by function `expr_gen`.
37+
The function `expr_gen` receives access pattern description symbol as its argument
38+
and this symbol is then consumed by uplo_access to generate the right code for matrix
39+
element access.
40+
41+
The name of the matrix to test is indicated by `asym`.
42+
"""
2343
function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, asym = :a)
2444
return expr_gen(:any)
2545
end
@@ -47,6 +67,19 @@ end
4767
function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, asym = :a)
4868
return expr_gen(:lower_triangular)
4969
end
70+
function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticMatrix}}, asym = :a)
71+
return expr_gen(:transpose)
72+
end
73+
function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticMatrix}}, asym = :a)
74+
return expr_gen(:adjoint)
75+
end
76+
"""
77+
gen_by_access(expr_gen, a::Type{<:AbstractArray}, b::Type{<:AbstractArray})
78+
79+
Simiar to gen_by_access with only one type argument. The difference is that tests for both
80+
arrays of type `a` and `b` are generated and `expr_gen` receives two access arguments,
81+
first for matrix `a` and the second for matrix `b`.
82+
"""
5083
function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, b::Type)
5184
return quote
5285
return $(gen_by_access(b, :b) do access_b
@@ -94,6 +127,20 @@ function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix
94127
end)
95128
end
96129
end
130+
function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticMatrix}}, b::Type)
131+
return quote
132+
return $(gen_by_access(b, :b) do access_b
133+
expr_gen(:transpose, access_b)
134+
end)
135+
end
136+
end
137+
function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticMatrix}}, b::Type)
138+
return quote
139+
return $(gen_by_access(b, :b) do access_b
140+
expr_gen(:adjoint, access_b)
141+
end)
142+
end
143+
end
97144

98145
"""
99146
mul_result_structure(a::Type, b::Type)
@@ -111,6 +158,14 @@ function mul_result_structure(::LowerTriangular{<:Any, <:StaticMatrix}, ::LowerT
111158
return LowerTriangular
112159
end
113160

161+
"""
162+
uplo_access(sa, asym, k, j, uplo)
163+
164+
Generate code for matrix element access, for a matrix of size `sa` locally referred to
165+
as `asym` in the context where the result will be used. Both indices `k` and `j` need to be
166+
statically known for this function to work. `uplo` is the access pattern mode generated
167+
by the `gen_by_access` function.
168+
"""
114169
function uplo_access(sa, asym, k, j, uplo)
115170
if uplo == :any
116171
return :($asym[$(LinearIndices(sa)[k, j])])
@@ -150,6 +205,10 @@ function uplo_access(sa, asym, k, j, uplo)
150205
else
151206
return :(zero(T))
152207
end
208+
elseif uplo == :transpose
209+
return :($asym[$(LinearIndices(sa)[j, k])])
210+
elseif uplo == :ajoint
211+
return :(adjoint($asym[$(LinearIndices(sa)[j, k])]))
153212
end
154213
end
155214

@@ -216,23 +275,35 @@ end
216275
end
217276
end
218277

278+
_unstatic_array(::Type{TSA}) where {S, T, N, TSA<:StaticArray{S,T,N}} = AbstractArray{T,N}
279+
for TWR in [Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTriangular]
280+
@eval _unstatic_array(::Type{$TWR{T,TSA}}) where {S, T, N, TSA<:StaticArray{S,T,N}} = $TWR{T,<:AbstractArray{T,N}}
281+
end
282+
219283
@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
220284
# Heuristic choice for amount of codegen
221285
if sa[1]*sa[2]*sb[2] <= 8*8*8 || !(a <: StaticMatrix) || !(b <: StaticMatrix)
222286
return quote
223287
@_inline_meta
224288
return mul_unrolled(Sa, Sb, a, b)
225289
end
226-
elseif sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14
290+
elseif a <: StaticMatrix && b <:StaticMatrix && sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14
227291
return quote
228292
@_inline_meta
229293
return mul_unrolled_chunks(Sa, Sb, a, b)
230294
end
231-
else
295+
elseif a <: StaticMatrix && b <:StaticMatrix
232296
return quote
233297
@_inline_meta
234298
return mul_loop(Sa, Sb, a, b)
235299
end
300+
else
301+
# we don't have any special code for handling this case so let's fall back to
302+
# the generic implementation of matrix multiplication
303+
return quote
304+
@_inline_meta
305+
return invoke(*, Tuple{$(_unstatic_array(a)),$(_unstatic_array(b))}, a, b)
306+
end
236307
end
237308
end
238309

src/triangular.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
LinearAlgebra.LowerTriangular(transpose(A.data))
77
@inline adjoint(A::LinearAlgebra.UpperTriangular{<:Any,<:StaticMatrix}) =
88
LinearAlgebra.LowerTriangular(adjoint(A.data))
9-
@inline Base.:*(A::Adjoint{<:Any,<:StaticVecOrMat}, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) =
9+
@inline Base.:*(A::Adjoint{<:Any,<:StaticVector}, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) =
1010
adjoint(adjoint(B) * adjoint(A))
11-
@inline Base.:*(A::Transpose{<:Any,<:StaticVecOrMat}, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) =
11+
@inline Base.:*(A::Transpose{<:Any,<:StaticVector}, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) =
1212
transpose(transpose(B) * transpose(A))
13-
@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::Adjoint{<:Any,<:StaticVecOrMat}) =
13+
@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::Adjoint{<:Any,<:StaticVector}) =
1414
adjoint(adjoint(B) * adjoint(A))
15-
@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::Transpose{<:Any,<:StaticVecOrMat}) =
15+
@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::Transpose{<:Any,<:StaticVector}) =
1616
transpose(transpose(B) * transpose(A))
1717

1818
const StaticULT = Union{UpperTriangular{<:Any,<:StaticMatrix},LowerTriangular{<:Any,<:StaticMatrix}}

test/matrix_multiply.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ mul_wrappers = [
77
m -> Hermitian(m, :U),
88
m -> Hermitian(m, :L),
99
m -> UpperTriangular(m),
10-
m -> LowerTriangular(m)]
10+
m -> LowerTriangular(m),
11+
m -> adjoint(m),
12+
m -> transpose(m)]
1113

1214
@testset "Matrix multiplication" begin
1315
@testset "Matrix-vector" begin
@@ -149,13 +151,19 @@ mul_wrappers = [
149151
# check different sizes because there are multiple implementations for matrices of different sizes
150152
for (mm, nn) in [
151153
(m, n),
152-
#(SMatrix{10, 10}(collect(1:100)), SMatrix{10, 10}(collect(1:100))),
153-
(SMatrix{15, 15}(collect(1:225)), SMatrix{15, 15}(collect(1:225)))]
154+
(SMatrix{10, 10}(collect(1:100)), SMatrix{10, 10}(collect(1:100))),
155+
(SMatrix{15, 15}(collect(1:225)), SMatrix{15, 15}(collect(1:225)))
156+
]
154157
for wrapper_m in mul_wrappers, wrapper_n in mul_wrappers
155158
wm = wrapper_m(mm)
156159
wn = wrapper_n(nn)
160+
if length(mm) >= 100 && (!isa(wm, StaticArray) || !isa(wn, StaticArray))
161+
continue
162+
end
157163
res_structure = StaticArrays.mul_result_structure(wm, wn)
158-
expected_type = if res_structure == identity
164+
expected_type = if length(m) >= 100
165+
Matrix{Int}
166+
elseif res_structure == identity
159167
typeof(mm)
160168
elseif res_structure == LowerTriangular
161169
LowerTriangular{Int,typeof(mm)}

0 commit comments

Comments
 (0)