Skip to content

Commit 9fece75

Browse files
committed
more matrix types for multiplication
1 parent 49d1043 commit 9fece75

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

src/matrix_multiply.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ const StaticMatMulLike{s1, s2, T} = Union{
1515
Hermitian{T, <:StaticMatrix{s1, s2, T}},
1616
LowerTriangular{T, <:StaticMatrix{s1, s2, T}},
1717
UpperTriangular{T, <:StaticMatrix{s1, s2, T}},
18+
UnitLowerTriangular{T, <:StaticMatrix{s1, s2, T}},
19+
UnitUpperTriangular{T, <:StaticMatrix{s1, s2, T}},
20+
UpperHessenberg{T, <:StaticMatrix{s1, s2, T}},
1821
Adjoint{T, <:StaticMatrix{s1, s2, T}},
1922
Transpose{T, <:StaticMatrix{s1, s2, T}}}
2023

@@ -67,6 +70,15 @@ end
6770
function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix}}, asym = :a)
6871
return expr_gen(:lower_triangular)
6972
end
73+
function gen_by_access(expr_gen, a::Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}}, asym = :a)
74+
return expr_gen(:unit_upper_triangular)
75+
end
76+
function gen_by_access(expr_gen, a::Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}}, asym = :a)
77+
return expr_gen(:unit_lower_triangular)
78+
end
79+
function gen_by_access(expr_gen, a::Type{<:UpperHessenberg{<:Any, <:StaticMatrix}}, asym = :a)
80+
return expr_gen(:upper_hessenberg)
81+
end
7082
function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticVecOrMat}}, asym = :a)
7183
return expr_gen(:transpose)
7284
end
@@ -127,6 +139,27 @@ function gen_by_access(expr_gen, a::Type{<:LowerTriangular{<:Any, <:StaticMatrix
127139
end)
128140
end
129141
end
142+
function gen_by_access(expr_gen, a::Type{<:UnitUpperTriangular{<:Any, <:StaticMatrix}}, b::Type)
143+
return quote
144+
return $(gen_by_access(b, :b) do access_b
145+
expr_gen(:unit_upper_triangular, access_b)
146+
end)
147+
end
148+
end
149+
function gen_by_access(expr_gen, a::Type{<:UnitLowerTriangular{<:Any, <:StaticMatrix}}, b::Type)
150+
return quote
151+
return $(gen_by_access(b, :b) do access_b
152+
expr_gen(:unit_lower_triangular, access_b)
153+
end)
154+
end
155+
end
156+
function gen_by_access(expr_gen, a::Type{<:UpperHessenberg{<:Any, <:StaticMatrix}}, b::Type)
157+
return quote
158+
return $(gen_by_access(b, :b) do access_b
159+
expr_gen(:upper_hessenberg, access_b)
160+
end)
161+
end
162+
end
130163
function gen_by_access(expr_gen, a::Type{<:Transpose{<:Any, <:StaticMatrix}}, b::Type)
131164
return quote
132165
return $(gen_by_access(b, :b) do access_b
@@ -205,6 +238,28 @@ function uplo_access(sa, asym, k, j, uplo)
205238
else
206239
return :(zero(T))
207240
end
241+
elseif uplo == :unit_upper_triangular
242+
if k < j
243+
return :($asym[$(LinearIndices(sa)[k, j])])
244+
elseif k == j
245+
return :(oneunit(T))
246+
else
247+
return :(zero(T))
248+
end
249+
elseif uplo == :unit_lower_triangular
250+
if k > j
251+
return :($asym[$(LinearIndices(sa)[k, j])])
252+
elseif k == j
253+
return :(oneunit(T))
254+
else
255+
return :(zero(T))
256+
end
257+
elseif uplo == :upper_hessenberg
258+
if k <= j+1
259+
return :($asym[$(LinearIndices(sa)[k, j])])
260+
else
261+
return :(zero(T))
262+
end
208263
elseif uplo == :transpose
209264
return :($asym[$(LinearIndices(reverse(sa))[j, k])])
210265
elseif uplo == :ajoint

test/matrix_multiply.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ mul_wrappers = [
88
m -> Hermitian(m, :L),
99
m -> UpperTriangular(m),
1010
m -> LowerTriangular(m),
11+
m -> UnitUpperTriangular(m),
12+
m -> UnitLowerTriangular(m),
13+
m -> UpperHessenberg(m),
1114
m -> adjoint(m),
1215
m -> transpose(m)]
1316

0 commit comments

Comments
 (0)