Skip to content

Commit 301e082

Browse files
authored
Change order of function definition in matrix multiplication code (#833)
* rearrange order of function definitions in matrix multiplication This should safeguard against generated functions being generated earlier than method definitions they need. * moving gen_by_access as well
1 parent fe8e4fd commit 301e082

File tree

2 files changed

+260
-262
lines changed

2 files changed

+260
-262
lines changed

src/matrix_multiply.jl

Lines changed: 0 additions & 262 deletions
Original file line numberDiff line numberDiff line change
@@ -15,150 +15,6 @@ import LinearAlgebra: BlasFloat, matprod, mul!
1515
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Adjoint{<:Any,<:StaticVector}) where {N} = vec(A) * B
1616
@inline *(A::StaticArray{Tuple{N,1},<:Any,2}, B::Transpose{<:Any,<:StaticVector}) where {N} = vec(A) * B
1717

18-
"""
19-
gen_by_access(expr_gen, a::Type{<:AbstractArray}, asym = :wrapped_a)
20-
21-
Statically generate outer code for fully unrolled multiplication loops.
22-
Returned code does wrapper-specific tests (for example if a symmetric matrix view is
23-
`U` or `L`) and the body of the if expression is then generated by function `expr_gen`.
24-
The function `expr_gen` receives access pattern description symbol as its argument
25-
and this symbol is then consumed by uplo_access to generate the right code for matrix
26-
element access.
27-
28-
The name of the matrix to test is indicated by `asym`.
29-
"""
30-
function gen_by_access(expr_gen, a::Type{<:StaticVecOrMat}, asym = :wrapped_a)
31-
return expr_gen(:any)
32-
end
33-
function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, asym = :wrapped_a)
34-
return quote
35-
if $(asym).uplo == 'U'
36-
$(expr_gen(:up))
37-
else
38-
$(expr_gen(:lo))
39-
end
40-
end
41-
end
42-
function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, asym = :wrapped_a)
43-
return quote
44-
if $(asym).uplo == 'U'
45-
$(expr_gen(:up_herm))
46-
else
47-
$(expr_gen(:lo_herm))
48-
end
49-
end
50-
end
51-
function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a)
52-
return expr_gen(:upper_triangular)
53-
end
54-
function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a)
55-
return expr_gen(:lower_triangular)
56-
end
57-
function gen_by_access(expr_gen, a::Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a)
58-
return expr_gen(:unit_upper_triangular)
59-
end
60-
function gen_by_access(expr_gen, a::Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}}, asym = :wrapped_a)
61-
return expr_gen(:unit_lower_triangular)
62-
end
63-
function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticVecOrMat}}, asym = :wrapped_a)
64-
return expr_gen(:transpose)
65-
end
66-
function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticVecOrMat}}, asym = :wrapped_a)
67-
return expr_gen(:adjoint)
68-
end
69-
function gen_by_access(expr_gen, a::Type{<:SDiagonal}, asym = :wrapped_a)
70-
return expr_gen(:diagonal)
71-
end
72-
"""
73-
gen_by_access(expr_gen, a::Type{<:AbstractArray}, b::Type{<:AbstractArray})
74-
75-
Simiar to gen_by_access with only one type argument. The difference is that tests for both
76-
arrays of type `a` and `b` are generated and `expr_gen` receives two access arguments,
77-
first for matrix `a` and the second for matrix `b`.
78-
"""
79-
function gen_by_access(expr_gen, a::Type{<:StaticMatrix}, b::Type)
80-
return quote
81-
return $(gen_by_access(b, :wrapped_b) do access_b
82-
expr_gen(:any, access_b)
83-
end)
84-
end
85-
end
86-
function gen_by_access(expr_gen, a::Type{<:Symmetric{<:Any, <:StaticMatrix}}, b::Type)
87-
return quote
88-
if wrapped_a.uplo == 'U'
89-
return $(gen_by_access(b, :wrapped_b) do access_b
90-
expr_gen(:up, access_b)
91-
end)
92-
else
93-
return $(gen_by_access(b, :wrapped_b) do access_b
94-
expr_gen(:lo, access_b)
95-
end)
96-
end
97-
end
98-
end
99-
function gen_by_access(expr_gen, a::Type{<:Hermitian{<:Any, <:StaticMatrix}}, b::Type)
100-
return quote
101-
if wrapped_a.uplo == 'U'
102-
return $(gen_by_access(b, :wrapped_b) do access_b
103-
expr_gen(:up_herm, access_b)
104-
end)
105-
else
106-
return $(gen_by_access(b, :wrapped_b) do access_b
107-
expr_gen(:lo_herm, access_b)
108-
end)
109-
end
110-
end
111-
end
112-
function gen_by_access(expr_gen, a::Type{<:UpperTriangular{<:Any, <:StaticMatrix}}, b::Type)
113-
return quote
114-
return $(gen_by_access(b, :wrapped_b) do access_b
115-
expr_gen(:upper_triangular, access_b)
116-
end)
117-
end
118-
end
119-
function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, b::Type)
120-
return quote
121-
return $(gen_by_access(b, :wrapped_b) do access_b
122-
expr_gen(:lower_triangular, access_b)
123-
end)
124-
end
125-
end
126-
function gen_by_access(expr_gen, a::Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}}, b::Type)
127-
return quote
128-
return $(gen_by_access(b, :wrapped_b) do access_b
129-
expr_gen(:unit_upper_triangular, access_b)
130-
end)
131-
end
132-
end
133-
function gen_by_access(expr_gen, a::Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}}, b::Type)
134-
return quote
135-
return $(gen_by_access(b, :wrapped_b) do access_b
136-
expr_gen(:unit_lower_triangular, access_b)
137-
end)
138-
end
139-
end
140-
function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticMatrix}}, b::Type)
141-
return quote
142-
return $(gen_by_access(b, :wrapped_b) do access_b
143-
expr_gen(:transpose, access_b)
144-
end)
145-
end
146-
end
147-
function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticMatrix}}, b::Type)
148-
return quote
149-
return $(gen_by_access(b, :wrapped_b) do access_b
150-
expr_gen(:adjoint, access_b)
151-
end)
152-
end
153-
end
154-
function gen_by_access(expr_gen, a::Type{<:SDiagonal}, b::Type)
155-
return quote
156-
return $(gen_by_access(b, :wrapped_b) do access_b
157-
expr_gen(:diagonal, access_b)
158-
end)
159-
end
160-
end
161-
16218
"""
16319
mul_result_structure(a::Type, b::Type)
16420
@@ -202,99 +58,6 @@ function mul_result_structure(::SDiagonal, ::SDiagonal)
20258
return Diagonal
20359
end
20460

205-
"""
206-
uplo_access(sa, asym, k, j, uplo)
207-
208-
Generate code for matrix element access, for a matrix of size `sa` locally referred to
209-
as `asym` in the context where the result will be used. Both indices `k` and `j` need to be
210-
statically known for this function to work. `uplo` is the access pattern mode generated
211-
by the `gen_by_access` function.
212-
"""
213-
function uplo_access(sa, asym, k, j, uplo)
214-
TAsym = Symbol("T"*string(asym))
215-
if uplo == :any
216-
return :($asym[$(LinearIndices(sa)[k, j])])
217-
elseif uplo == :up
218-
if k < j
219-
return :($asym[$(LinearIndices(sa)[k, j])])
220-
elseif k == j
221-
return :(LinearAlgebra.symmetric($asym[$(LinearIndices(sa)[k, j])], :U))
222-
else
223-
return :(transpose($asym[$(LinearIndices(sa)[j, k])]))
224-
end
225-
elseif uplo == :lo
226-
if k > j
227-
return :($asym[$(LinearIndices(sa)[k, j])])
228-
elseif k == j
229-
return :(LinearAlgebra.symmetric($asym[$(LinearIndices(sa)[k, j])], :L))
230-
else
231-
return :(transpose($asym[$(LinearIndices(sa)[j, k])]))
232-
end
233-
elseif uplo == :up_herm
234-
if k < j
235-
return :($asym[$(LinearIndices(sa)[k, j])])
236-
elseif k == j
237-
return :(LinearAlgebra.hermitian($asym[$(LinearIndices(sa)[k, j])], :U))
238-
else
239-
return :(adjoint($asym[$(LinearIndices(sa)[j, k])]))
240-
end
241-
elseif uplo == :lo_herm
242-
if k > j
243-
return :($asym[$(LinearIndices(sa)[k, j])])
244-
elseif k == j
245-
return :(LinearAlgebra.hermitian($asym[$(LinearIndices(sa)[k, j])], :L))
246-
else
247-
return :(adjoint($asym[$(LinearIndices(sa)[j, k])]))
248-
end
249-
elseif uplo == :upper_triangular
250-
if k <= j
251-
return :($asym[$(LinearIndices(sa)[k, j])])
252-
else
253-
return :(zero($TAsym))
254-
end
255-
elseif uplo == :lower_triangular
256-
if k >= j
257-
return :($asym[$(LinearIndices(sa)[k, j])])
258-
else
259-
return :(zero($TAsym))
260-
end
261-
elseif uplo == :unit_upper_triangular
262-
if k < j
263-
return :($asym[$(LinearIndices(sa)[k, j])])
264-
elseif k == j
265-
return :(oneunit($TAsym))
266-
else
267-
return :(zero($TAsym))
268-
end
269-
elseif uplo == :unit_lower_triangular
270-
if k > j
271-
return :($asym[$(LinearIndices(sa)[k, j])])
272-
elseif k == j
273-
return :(oneunit($TAsym))
274-
else
275-
return :(zero($TAsym))
276-
end
277-
elseif uplo == :upper_hessenberg
278-
if k <= j+1
279-
return :($asym[$(LinearIndices(sa)[k, j])])
280-
else
281-
return :(zero($TAsym))
282-
end
283-
elseif uplo == :transpose
284-
return :(transpose($asym[$(LinearIndices(reverse(sa))[j, k])]))
285-
elseif uplo == :adjoint
286-
return :(adjoint($asym[$(LinearIndices(reverse(sa))[j, k])]))
287-
elseif uplo == :diagonal
288-
if k == j
289-
return :($asym[$k])
290-
else
291-
return :(zero($TAsym))
292-
end
293-
else
294-
error("Unknown uplo: $uplo")
295-
end
296-
end
297-
29861
# Implementations
29962

30063
function mul_smat_vec_exprs(sa, access_a)
@@ -369,31 +132,6 @@ for TWR in [Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTria
369132
@eval _unstatic_array(::Type{$TWR{T,TSA}}) where {S, T, N, TSA<:StaticArray{S,T,N}} = $TWR{T,<:AbstractArray{T,N}}
370133
end
371134

372-
function combine_products(expr_list)
373-
filtered = filter(expr_list) do expr
374-
if expr.head != :call || expr.args[1] != :*
375-
error("expected call to *")
376-
end
377-
for arg in expr.args[2:end]
378-
if isa(arg, Expr) && arg.head == :call && arg.args[1] == :zero
379-
return false
380-
end
381-
end
382-
return true
383-
end
384-
if isempty(filtered)
385-
return :(zero(T))
386-
else
387-
return reduce(filtered) do ex1, ex2
388-
if ex2.head != :call || ex2.args[1] != :*
389-
error("expected call to *")
390-
end
391-
392-
return :(muladd($(ex2.args[2]), $(ex2.args[3]), $ex1))
393-
end
394-
end
395-
end
396-
397135
@generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
398136
S = Size(sa[1], sb[2])
399137
# Heuristic choice for amount of codegen

0 commit comments

Comments
 (0)