Skip to content

Commit ac580a6

Browse files
committed
Add simple forward and backward substitution to solve triangular systems
1 parent 3be9661 commit ac580a6

File tree

2 files changed

+86
-1
lines changed

2 files changed

+86
-1
lines changed

src/solve.jl

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
@inline (\)(a::StaticMatrix{<:Any, <:Any, T}, b::StaticVector{<:Any, T}) where {T} = solve(Size(a), Size(b), a, b)
2+
@inline (\)(a::Union{UpperTriangular{T, S}, LowerTriangular{T, S}} where {S<:StaticMatrix{<:Any, <:Any, T}}, b::StaticVector{<:Any, T}) where {T} = solve(Size(a.data), Size(b), a, b)
3+
@inline (\)(a::Union{UpperTriangular{T, S}, LowerTriangular{T, S}} where {S<:StaticMatrix{<:Any, <:Any, T}}, b::StaticMatrix{<:Any, <:Any, T}) where {T} = solve(Size(a.data), Size(b), a, b)
24

35
# TODO: Ineffecient but requires some infrastructure (e.g. LU or QR) to make efficient so we fall back on inv for now
46
@inline solve(::Size, ::Size, a, b) = inv(a) * b
57

6-
@inline solve(::Size{(1,1)}, ::Size{(1,)}, a, b) = similar_type(b, typeof(b[1] \ a[1]))(b[1] \ a[1])
8+
@inline function solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
9+
@inbounds return similar_type(b, typeof(b[1] \ a[1]))(b[1] \ a[1])
10+
end
711

812
@inline function solve(::Size{(2,2)}, ::Size{(2,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
913
d = det(a)
@@ -26,3 +30,67 @@ end
2630
(a[1,2]*a[3,1] - a[1,1]*a[3,2])*b[2] +
2731
(a[1,1]*a[2,2] - a[1,2]*a[2,1])*b[3]) / d )
2832
end
33+
34+
@generated function solve(::Size{sa}, ::Size{sb}, a::UpperTriangular{Ta, Sa} where {Sa<:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb}
35+
if sa[1] != sb[1]
36+
throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])"))
37+
end
38+
39+
x = [Symbol("x$k") for k = 1:sb[1]]
40+
expr = [:($(x[i]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == i ? :(b[$j]) : :(a[$i, $j]*$(x[j])) for j = i:sa[1]]))/a[$i, $i]) for i = sb[1]:-1:1]
41+
42+
quote
43+
@_inline_meta
44+
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
45+
@inbounds $(Expr(:block, expr...))
46+
@inbounds return similar_type(b, T)(tuple($(x...)))
47+
end
48+
end
49+
50+
@generated function solve(::Size{sa}, ::Size{sb}, a::UpperTriangular{Ta, Sa} where {Sa<:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
51+
if sa[1] != sb[1]
52+
throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])"))
53+
end
54+
55+
x = [Symbol("x$k1$k2") for k1 = 1:sb[1], k2 = 1:sb[2]]
56+
expr = [:($(x[k1, k2]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == k1 ? :(b[$j, $k2]) : :(a[$k1, $j]*$(x[j, k2])) for j = k1:sa[1]]))/a[$k1, $k1]) for k1 = sb[1]:-1:1, k2 = 1:sb[2]]
57+
58+
quote
59+
@_inline_meta
60+
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
61+
@inbounds $(Expr(:block, expr...))
62+
@inbounds return similar_type(b, T)(tuple($(x...)))
63+
end
64+
end
65+
66+
@generated function solve(::Size{sa}, ::Size{sb}, a::LowerTriangular{Ta, Sa} where {Sa<:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb}
67+
if sa[1] != sb[1]
68+
throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])"))
69+
end
70+
71+
x = [Symbol("x$k") for k = 1:sb[1]]
72+
expr = [:($(x[i]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == i ? :(b[$j]) : :(a[$i, $j]*$(x[j])) for j = i:-1:1]))/a[$i, $i]) for i = 1:sb[1]]
73+
74+
quote
75+
@_inline_meta
76+
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
77+
@inbounds $(Expr(:block, expr...))
78+
@inbounds return similar_type(b, T)(tuple($(x...)))
79+
end
80+
end
81+
82+
@generated function solve(::Size{sa}, ::Size{sb}, a::LowerTriangular{Ta, Sa} where {Sa<:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
83+
if sa[1] != sb[1]
84+
throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])"))
85+
end
86+
87+
x = [Symbol("x$k1$k2") for k1 = 1:sb[1], k2 = 1:sb[2]]
88+
expr = [:($(x[k1, k2]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == k1 ? :(b[$j, $k2]) : :(a[$k1, $j]*$(x[j, k2])) for j = k1:-1:1]))/a[$k1, $k1]) for k1 = 1:sb[1], k2 = 1:sb[2]]
89+
90+
quote
91+
@_inline_meta
92+
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
93+
@inbounds $(Expr(:block, expr...))
94+
@inbounds return similar_type(b, T)(tuple($(x...)))
95+
end
96+
end

test/solve.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,20 @@
2525
@test m(A)\v(b) ≈ A\b
2626
end =#
2727
end
28+
29+
@testset "Solving triangular system" begin
30+
for n in (1,2,3,4),
31+
(t1, uplo1) in ((UpperTriangular, :U),
32+
(LowerTriangular, :L)),
33+
(m, v, u) in ((SMatrix{n, n}, SVector{n}, SMatrix{n, 2}), (MMatrix{n,n}, MVector{n}, SMatrix{n, 2})),
34+
elty in (Float32, Float64, Int)
35+
36+
eval(quote
37+
A = $(t1)($elty == Int ? rand(1:7, $n, $n) : convert(Matrix{$elty}, randn($n, $n)) |> t -> chol(t't) |> t -> $(uplo1 == :U) ? t : ctranspose(t))
38+
b = convert(Matrix{$elty}, A*ones($n, 2))
39+
SA = $t1($m(A.data))
40+
@test SA \ $v(b[:, 1]) A \ b[:, 1]
41+
@test SA \ $u(b) A \ b
42+
end)
43+
end
44+
end

0 commit comments

Comments
 (0)