Skip to content

Commit c37be94

Browse files
committed
Remove type constraints from \ and add more comprehensive tests
1 parent ac580a6 commit c37be94

File tree

2 files changed

+42
-33
lines changed

2 files changed

+42
-33
lines changed

src/solve.jl

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
@inline (\)(a::StaticMatrix{<:Any, <:Any, T}, b::StaticVector{<:Any, T}) where {T} = solve(Size(a), Size(b), a, b)
2-
@inline (\)(a::Union{UpperTriangular{T, S}, LowerTriangular{T, S}} where {S<:StaticMatrix{<:Any, <:Any, T}}, b::StaticVector{<:Any, T}) where {T} = solve(Size(a.data), Size(b), a, b)
3-
@inline (\)(a::Union{UpperTriangular{T, S}, LowerTriangular{T, S}} where {S<:StaticMatrix{<:Any, <:Any, T}}, b::StaticMatrix{<:Any, <:Any, T}) where {T} = solve(Size(a.data), Size(b), a, b)
1+
@inline (\)(a::StaticMatrix, b::StaticVector) = solve(Size(a), Size(b), a, b)
2+
@inline (\)(a::Union{UpperTriangular{<:Any, <:StaticMatrix}, LowerTriangular{<:Any, <:StaticMatrix}}, b::StaticVecOrMat) = solve(Size(a.data), Size(b), a, b)
43

54
# TODO: Ineffecient but requires some infrastructure (e.g. LU or QR) to make efficient so we fall back on inv for now
65
@inline solve(::Size, ::Size, a, b) = inv(a) * b
@@ -31,7 +30,7 @@ end
3130
(a[1,1]*a[2,2] - a[1,2]*a[2,1])*b[3]) / d )
3231
end
3332

34-
@generated function solve(::Size{sa}, ::Size{sb}, a::UpperTriangular{Ta, Sa} where {Sa<:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb}
33+
@generated function solve(::Size{sa}, ::Size{sb}, a::UpperTriangular{Ta, <:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb}
3534
if sa[1] != sb[1]
3635
throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])"))
3736
end
@@ -40,14 +39,14 @@ end
4039
expr = [:($(x[i]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == i ? :(b[$j]) : :(a[$i, $j]*$(x[j])) for j = i:sa[1]]))/a[$i, $i]) for i = sb[1]:-1:1]
4140

4241
quote
43-
@_inline_meta
44-
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
45-
@inbounds $(Expr(:block, expr...))
46-
@inbounds return similar_type(b, T)(tuple($(x...)))
42+
@_inline_meta
43+
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
44+
@inbounds $(Expr(:block, expr...))
45+
@inbounds return similar_type(b, T)(tuple($(x...)))
4746
end
4847
end
4948

50-
@generated function solve(::Size{sa}, ::Size{sb}, a::UpperTriangular{Ta, Sa} where {Sa<:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
49+
@generated function solve(::Size{sa}, ::Size{sb}, a::UpperTriangular{Ta, <:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
5150
if sa[1] != sb[1]
5251
throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])"))
5352
end
@@ -56,14 +55,14 @@ end
5655
expr = [:($(x[k1, k2]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == k1 ? :(b[$j, $k2]) : :(a[$k1, $j]*$(x[j, k2])) for j = k1:sa[1]]))/a[$k1, $k1]) for k1 = sb[1]:-1:1, k2 = 1:sb[2]]
5756

5857
quote
59-
@_inline_meta
60-
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
61-
@inbounds $(Expr(:block, expr...))
62-
@inbounds return similar_type(b, T)(tuple($(x...)))
58+
@_inline_meta
59+
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
60+
@inbounds $(Expr(:block, expr...))
61+
@inbounds return similar_type(b, T)(tuple($(x...)))
6362
end
6463
end
6564

66-
@generated function solve(::Size{sa}, ::Size{sb}, a::LowerTriangular{Ta, Sa} where {Sa<:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb}
65+
@generated function solve(::Size{sa}, ::Size{sb}, a::LowerTriangular{Ta, <:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb}
6766
if sa[1] != sb[1]
6867
throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])"))
6968
end
@@ -72,14 +71,14 @@ end
7271
expr = [:($(x[i]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == i ? :(b[$j]) : :(a[$i, $j]*$(x[j])) for j = i:-1:1]))/a[$i, $i]) for i = 1:sb[1]]
7372

7473
quote
75-
@_inline_meta
76-
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
77-
@inbounds $(Expr(:block, expr...))
78-
@inbounds return similar_type(b, T)(tuple($(x...)))
74+
@_inline_meta
75+
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
76+
@inbounds $(Expr(:block, expr...))
77+
@inbounds return similar_type(b, T)(tuple($(x...)))
7978
end
8079
end
8180

82-
@generated function solve(::Size{sa}, ::Size{sb}, a::LowerTriangular{Ta, Sa} where {Sa<:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
81+
@generated function solve(::Size{sa}, ::Size{sb}, a::LowerTriangular{Ta, <:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb}
8382
if sa[1] != sb[1]
8483
throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])"))
8584
end
@@ -88,9 +87,9 @@ end
8887
expr = [:($(x[k1, k2]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == k1 ? :(b[$j, $k2]) : :(a[$k1, $j]*$(x[j, k2])) for j = k1:-1:1]))/a[$k1, $k1]) for k1 = 1:sb[1], k2 = 1:sb[2]]
8988

9089
quote
91-
@_inline_meta
92-
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
93-
@inbounds $(Expr(:block, expr...))
94-
@inbounds return similar_type(b, T)(tuple($(x...)))
90+
@_inline_meta
91+
T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta))
92+
@inbounds $(Expr(:block, expr...))
93+
@inbounds return similar_type(b, T)(tuple($(x...)))
9594
end
9695
end

test/solve.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,28 @@
2727
end
2828

2929
@testset "Solving triangular system" begin
30-
for n in (1,2,3,4),
31-
(t1, uplo1) in ((UpperTriangular, :U),
32-
(LowerTriangular, :L)),
33-
(m, v, u) in ((SMatrix{n, n}, SVector{n}, SMatrix{n, 2}), (MMatrix{n,n}, MVector{n}, SMatrix{n, 2})),
34-
elty in (Float32, Float64, Int)
30+
for n in (1, 2, 3, 4),
31+
(t, uplo) in ((UpperTriangular, :U),
32+
(LowerTriangular, :L)),
33+
(m, v, u) in ((SMatrix{n,n}, SVector{n}, SMatrix{n, 2}),
34+
(MMatrix{n,n}, MVector{n}, SMatrix{n, 2})),
35+
eltya in (Float32, Float64, BigFloat, Complex64, Complex128, Complex{BigFloat}, Int),
36+
eltyb in (Float32, Float64, BigFloat, Complex64, Complex128, Complex{BigFloat})
3537

3638
eval(quote
37-
A = $(t1)($elty == Int ? rand(1:7, $n, $n) : convert(Matrix{$elty}, randn($n, $n)) |> t -> chol(t't) |> t -> $(uplo1 == :U) ? t : ctranspose(t))
38-
b = convert(Matrix{$elty}, A*ones($n, 2))
39-
SA = $t1($m(A.data))
40-
@test SA \ $v(b[:, 1]) A \ b[:, 1]
41-
@test SA \ $u(b) A \ b
39+
A = $t($eltya == Int ? rand(1:7, $n, $n) : convert(Matrix{$eltya}, ($eltya <: Complex ? complex.(randn($n, $n), randn($n, $n)) : randn($n, $n)) |> z -> chol(z'z) |> z -> $(uplo == :U) ? z : ctranspose(z)))
40+
b = convert(Matrix{$eltyb}, $eltya <: Complex ? real(A)*ones($n, 2) : A*ones($n, 2))
41+
SA = $t($m(A.data))
42+
Sx = SA \ $v(b[:, 1])
43+
x = A \ b[:, 1]
44+
@test typeof(Sx) <: StaticVector # test not falling back to Base
45+
@test Sx x
46+
@test eltype(Sx) == eltype(x)
47+
SX = SA \ $u(b)
48+
X = A \ b
49+
@test typeof(SX) <: StaticMatrix # test not falling back to Base
50+
@test SX X
51+
@test eltype(SX) == eltype(X)
4252
end)
4353
end
4454
end

0 commit comments

Comments
 (0)