Skip to content

Commit f77df91

Browse files
authored
Merge pull request #192 from bshall/solve-triangular
Add simple forward and backward substitution to solve triangular systems
2 parents dd2289c + f746211 commit f77df91

File tree

2 files changed

+94
-2
lines changed

2 files changed

+94
-2
lines changed

src/solve.jl

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
@inline (\)(a::StaticMatrix{<:Any, <:Any, T}, b::StaticVector{<:Any, T}) where {T} = solve(Size(a), Size(b), a, b)
1+
@inline (\)(a::StaticMatrix, b::StaticVector) = solve(Size(a), Size(b), a, b)
2+
@inline (\)(a::Union{UpperTriangular{<:Any, <:StaticMatrix}, LowerTriangular{<:Any, <:StaticMatrix}}, b::StaticVecOrMat) = solve(Size(a.data), Size(b), a, b)
23

34
# TODO: Ineffecient but requires some infrastructure (e.g. LU or QR) to make efficient so we fall back on inv for now
45
@inline solve(::Size, ::Size, a, b) = inv(a) * b
56

6-
@inline solve(::Size{(1,1)}, ::Size{(1,)}, a, b) = similar_type(b, typeof(b[1] \ a[1]))(b[1] \ a[1])
7+
@inline function solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
8+
@inbounds return similar_type(b, typeof(b[1] \ a[1]))(b[1] \ a[1])
9+
end
710

811
@inline function solve(::Size{(2,2)}, ::Size{(2,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
912
d = det(a)
@@ -26,3 +29,67 @@ end
2629
(a[1,2]*a[3,1] - a[1,1]*a[3,2])*b[2] +
2730
(a[1,1]*a[2,2] - a[1,2]*a[2,1])*b[3]) / d )
2831
end
32+
33+
@generated function solve(::Size{sa}, ::Size{sb}, a::UpperTriangular{Ta, <:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb}
34+
if sa[1] != sb[1]
35+
throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])"))
36+
end
37+
38+
x = [Symbol("x$k") for k = 1:sb[1]]
39+
expr = [:($(x[i]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == i ? :(b[$j]) : :(a[$i, $j]*$(x[j])) for j = i:sa[1]]))/a[$i, $i]) for i = sb[1]:-1:1]
40+
41+
quote
42+
@_inline_meta
43+
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
44+
@inbounds $(Expr(:block, expr...))
45+
@inbounds return similar_type(b, T)(tuple($(x...)))
46+
end
47+
end
48+
49+
@generated function solve(::Size{sa}, ::Size{sb}, a::UpperTriangular{Ta, <:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
50+
if sa[1] != sb[1]
51+
throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])"))
52+
end
53+
54+
x = [Symbol("x$k1$k2") for k1 = 1:sb[1], k2 = 1:sb[2]]
55+
expr = [:($(x[k1, k2]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == k1 ? :(b[$j, $k2]) : :(a[$k1, $j]*$(x[j, k2])) for j = k1:sa[1]]))/a[$k1, $k1]) for k1 = sb[1]:-1:1, k2 = 1:sb[2]]
56+
57+
quote
58+
@_inline_meta
59+
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
60+
@inbounds $(Expr(:block, expr...))
61+
@inbounds return similar_type(b, T)(tuple($(x...)))
62+
end
63+
end
64+
65+
@generated function solve(::Size{sa}, ::Size{sb}, a::LowerTriangular{Ta, <:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb}
66+
if sa[1] != sb[1]
67+
throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])"))
68+
end
69+
70+
x = [Symbol("x$k") for k = 1:sb[1]]
71+
expr = [:($(x[i]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == i ? :(b[$j]) : :(a[$i, $j]*$(x[j])) for j = i:-1:1]))/a[$i, $i]) for i = 1:sb[1]]
72+
73+
quote
74+
@_inline_meta
75+
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
76+
@inbounds $(Expr(:block, expr...))
77+
@inbounds return similar_type(b, T)(tuple($(x...)))
78+
end
79+
end
80+
81+
@generated function solve(::Size{sa}, ::Size{sb}, a::LowerTriangular{Ta, <:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
82+
if sa[1] != sb[1]
83+
throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])"))
84+
end
85+
86+
x = [Symbol("x$k1$k2") for k1 = 1:sb[1], k2 = 1:sb[2]]
87+
expr = [:($(x[k1, k2]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == k1 ? :(b[$j, $k2]) : :(a[$k1, $j]*$(x[j, k2])) for j = k1:-1:1]))/a[$k1, $k1]) for k1 = 1:sb[1], k2 = 1:sb[2]]
88+
89+
quote
90+
@_inline_meta
91+
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
92+
@inbounds $(Expr(:block, expr...))
93+
@inbounds return similar_type(b, T)(tuple($(x...)))
94+
end
95+
end

test/solve.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,28 @@
3131
@test_throws DimensionMismatch m1\v
3232
@test_throws DimensionMismatch m1\m2
3333
end
34+
35+
@testset "Solving triangular system" begin
36+
for n in (1, 2, 3, 4),
37+
(t, uplo) in ((UpperTriangular, :U),
38+
(LowerTriangular, :L)),
39+
(m, v, u) in ((SMatrix{n,n}, SVector{n}, SMatrix{n,2}),
40+
(MMatrix{n,n}, MVector{n}, SMatrix{n,2})),
41+
eltya in (Float32, Float64, BigFloat, Complex64, Complex128, Complex{BigFloat}, Int),
42+
eltyb in (Float32, Float64, BigFloat, Complex64, Complex128, Complex{BigFloat})
43+
44+
A = t(eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, (eltya <: Complex ? complex.(randn(n,n), randn(n,n)) : randn(n,n)) |> z -> chol(z'z) |> z -> uplo == :U ? z : ctranspose(z)))
45+
b = convert(Matrix{eltyb}, eltya <: Complex ? real(A)*ones(n,2) : A*ones(n,2))
46+
SA = t(m(A.data))
47+
Sx = SA \ v(b[:, 1])
48+
x = A \ b[:, 1]
49+
@test Sx isa StaticVector # test not falling back to Base
50+
@test Sx x
51+
@test eltype(Sx) == eltype(x)
52+
SX = SA \ u(b)
53+
X = A \ b
54+
@test SX isa StaticMatrix # test not falling back to Base
55+
@test SX X
56+
@test eltype(SX) == eltype(X)
57+
end
58+
end

0 commit comments

Comments
 (0)