|
1 |
| -@inline (\)(a::StaticMatrix{<:Any, <:Any, T}, b::StaticVector{<:Any, T}) where {T} = solve(Size(a), Size(b), a, b) |
| 1 | +@inline (\)(a::StaticMatrix, b::StaticVector) = solve(Size(a), Size(b), a, b) |
| 2 | +@inline (\)(a::Union{UpperTriangular{<:Any, <:StaticMatrix}, LowerTriangular{<:Any, <:StaticMatrix}}, b::StaticVecOrMat) = solve(Size(a.data), Size(b), a, b) |
2 | 3 |
|
3 | 4 | # TODO: Ineffecient but requires some infrastructure (e.g. LU or QR) to make efficient so we fall back on inv for now
|
4 | 5 | @inline solve(::Size, ::Size, a, b) = inv(a) * b
|
5 | 6 |
|
6 |
| -@inline solve(::Size{(1,1)}, ::Size{(1,)}, a, b) = similar_type(b, typeof(b[1] \ a[1]))(b[1] \ a[1]) |
| 7 | +@inline function solve(::Size{(1,1)}, ::Size{(1,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb} |
| 8 | + @inbounds return similar_type(b, typeof(b[1] \ a[1]))(b[1] \ a[1]) |
| 9 | +end |
7 | 10 |
|
8 | 11 | @inline function solve(::Size{(2,2)}, ::Size{(2,)}, a::StaticMatrix{<:Any, <:Any, Ta}, b::StaticVector{<:Any, Tb}) where {Ta, Tb}
|
9 | 12 | d = det(a)
|
|
26 | 29 | (a[1,2]*a[3,1] - a[1,1]*a[3,2])*b[2] +
|
27 | 30 | (a[1,1]*a[2,2] - a[1,2]*a[2,1])*b[3]) / d )
|
28 | 31 | end
|
| 32 | + |
| 33 | +@generated function solve(::Size{sa}, ::Size{sb}, a::UpperTriangular{Ta, <:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb} |
| 34 | + if sa[1] != sb[1] |
| 35 | + throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])")) |
| 36 | + end |
| 37 | + |
| 38 | + x = [Symbol("x$k") for k = 1:sb[1]] |
| 39 | + expr = [:($(x[i]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == i ? :(b[$j]) : :(a[$i, $j]*$(x[j])) for j = i:sa[1]]))/a[$i, $i]) for i = sb[1]:-1:1] |
| 40 | + |
| 41 | + quote |
| 42 | + @_inline_meta |
| 43 | + T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta)) |
| 44 | + @inbounds $(Expr(:block, expr...)) |
| 45 | + @inbounds return similar_type(b, T)(tuple($(x...))) |
| 46 | + end |
| 47 | +end |
| 48 | + |
| 49 | +@generated function solve(::Size{sa}, ::Size{sb}, a::UpperTriangular{Ta, <:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} |
| 50 | + if sa[1] != sb[1] |
| 51 | + throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])")) |
| 52 | + end |
| 53 | + |
| 54 | + x = [Symbol("x$k1$k2") for k1 = 1:sb[1], k2 = 1:sb[2]] |
| 55 | + expr = [:($(x[k1, k2]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == k1 ? :(b[$j, $k2]) : :(a[$k1, $j]*$(x[j, k2])) for j = k1:sa[1]]))/a[$k1, $k1]) for k1 = sb[1]:-1:1, k2 = 1:sb[2]] |
| 56 | + |
| 57 | + quote |
| 58 | + @_inline_meta |
| 59 | + T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta)) |
| 60 | + @inbounds $(Expr(:block, expr...)) |
| 61 | + @inbounds return similar_type(b, T)(tuple($(x...))) |
| 62 | + end |
| 63 | +end |
| 64 | + |
| 65 | +@generated function solve(::Size{sa}, ::Size{sb}, a::LowerTriangular{Ta, <:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticVector{<:Any, Tb}) where {sa, sb, Ta, Tb} |
| 66 | + if sa[1] != sb[1] |
| 67 | + throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])")) |
| 68 | + end |
| 69 | + |
| 70 | + x = [Symbol("x$k") for k = 1:sb[1]] |
| 71 | + expr = [:($(x[i]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == i ? :(b[$j]) : :(a[$i, $j]*$(x[j])) for j = i:-1:1]))/a[$i, $i]) for i = 1:sb[1]] |
| 72 | + |
| 73 | + quote |
| 74 | + @_inline_meta |
| 75 | + T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta)) |
| 76 | + @inbounds $(Expr(:block, expr...)) |
| 77 | + @inbounds return similar_type(b, T)(tuple($(x...))) |
| 78 | + end |
| 79 | +end |
| 80 | + |
| 81 | +@generated function solve(::Size{sa}, ::Size{sb}, a::LowerTriangular{Ta, <:StaticMatrix{<:Any, <:Any, Ta}}, b::StaticMatrix{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} |
| 82 | + if sa[1] != sb[1] |
| 83 | + throw(DimensionMismatch("right hand side b needs first dimension of size $(sa[1]), has size $(sb[1])")) |
| 84 | + end |
| 85 | + |
| 86 | + x = [Symbol("x$k1$k2") for k1 = 1:sb[1], k2 = 1:sb[2]] |
| 87 | + expr = [:($(x[k1, k2]) = $(reduce((ex1, ex2) -> :(-($ex1,$ex2)), [j == k1 ? :(b[$j, $k2]) : :(a[$k1, $j]*$(x[j, k2])) for j = k1:-1:1]))/a[$k1, $k1]) for k1 = 1:sb[1], k2 = 1:sb[2]] |
| 88 | + |
| 89 | + quote |
| 90 | + @_inline_meta |
| 91 | + T = typeof((zero(Ta)*zero(Tb) + zero(Ta)*zero(Tb))/one(Ta)) |
| 92 | + @inbounds $(Expr(:block, expr...)) |
| 93 | + @inbounds return similar_type(b, T)(tuple($(x...))) |
| 94 | + end |
| 95 | +end |
0 commit comments