|
1 | 1 | import Base: +, -, *, /, \
|
2 | 2 |
|
3 |
| -# TODO: more operators, like AbstractArray |
| 3 | +#-------------------------------------------------- |
| 4 | +# Vector space algebra |
4 | 5 |
|
5 | 6 | # Unary ops
|
6 | 7 | @inline -(a::StaticArray) = map(-, a)
|
@@ -30,10 +31,7 @@ import Base: +, -, *, /, \
|
30 | 31 | @inline -(a::UniformScaling, b::StaticMatrix) = _plus_uniform(Size(b), -b, a.λ)
|
31 | 32 |
|
32 | 33 | @generated function _plus_uniform(::Size{S}, a::StaticMatrix, λ) where {S}
|
33 |
| - if S[1] != S[2] |
34 |
| - throw(DimensionMismatch("matrix is not square: dimensions are $S")) |
35 |
| - end |
36 |
| - n = S[1] |
| 34 | + n = checksquare(a) |
37 | 35 | exprs = [i == j ? :(a[$(LinearIndices(S)[i, j])] + λ) : :(a[$(LinearIndices(S)[i, j])]) for i = 1:n, j = 1:n]
|
38 | 36 | return quote
|
39 | 37 | $(Expr(:meta, :inline))
|
|
46 | 44 | @inline \(a::UniformScaling, b::Union{StaticMatrix,StaticVector}) = a.λ \ b
|
47 | 45 | @inline /(a::StaticMatrix, b::UniformScaling) = a / b.λ
|
48 | 46 |
|
| 47 | +#-------------------------------------------------- |
| 48 | +# Matrix algebra |
49 | 49 |
|
50 | 50 | # Transpose, conjugate, etc
|
51 | 51 | @inline conj(a::StaticArray) = map(conj, a)
|
|
85 | 85 | @inline Base.zero(a::SA) where {SA <: StaticArray} = zeros(SA)
|
86 | 86 | @inline Base.zero(a::Type{SA}) where {SA <: StaticArray} = zeros(SA)
|
87 | 87 |
|
88 |
| -@inline one(::SM) where {SM <: StaticMatrix} = _one(Size(SM), SM) |
89 |
| -@inline one(::Type{SM}) where {SM <: StaticMatrix} = _one(Size(SM), SM) |
90 |
| -@generated function _one(::Size{S}, ::Type{SM}) where {S, SM <: StaticArray} |
91 |
| - if (length(S) != 2) || (S[1] != S[2]) |
92 |
| - error("multiplicative identity defined only for square matrices") |
93 |
| - end |
94 |
| - T = eltype(SM) # should be "hyperpure" |
95 |
| - if T == Any |
96 |
| - T = Float64 |
97 |
| - end |
98 |
| - exprs = [i == j ? :(one($T)) : :(zero($T)) for i ∈ 1:S[1], j ∈ 1:S[2]] |
99 |
| - return quote |
100 |
| - $(Expr(:meta, :inline)) |
101 |
| - SM(tuple($(exprs...))) |
| 88 | +@inline one(m::StaticMatrixLike) = _one(Size(m), m) |
| 89 | +@inline one(::Type{SM}) where {SM<:StaticMatrixLike}= _one(Size(SM), SM) |
| 90 | +function _one(s::Size, m_or_SM) |
| 91 | + if (length(s) != 2) || (s[1] != s[2]) |
| 92 | + throw(DimensionMismatch("multiplicative identity defined only for square matrices")) |
102 | 93 | end
|
| 94 | + _scalar_matrix(s, m_or_SM, one(_eltype_or(m_or_SM, Float64))) |
103 | 95 | end
|
104 | 96 |
|
105 |
| -# StaticMatrix(I::UniformScaling) methods to replace eye |
106 |
| -(::Type{SM})(I::UniformScaling) where {N,M,SM<:StaticMatrix{N,M}} = _eye(Size(SM), SM, I) |
107 |
| - |
108 |
| -@generated function _eye(::Size{S}, ::Type{SM}, I::UniformScaling{T}) where {S, SM <: StaticArray, T} |
109 |
| - exprs = [i == j ? :(I.λ) : :(zero($T)) for i ∈ 1:S[1], j ∈ 1:S[2]] |
| 97 | +# StaticMatrix(I::UniformScaling) |
| 98 | +(::Type{SM})(I::UniformScaling) where {SM<:StaticMatrix} = _scalar_matrix(Size(SM), SM, I.λ) |
| 99 | +# The following oddity is needed if we want `SArray{Tuple{2,3}}(I)` to work |
| 100 | +# because we do not have `SArray{Tuple{2,3}} <: StaticMatrix`. |
| 101 | +(::Type{SM})(I::UniformScaling) where {SM<:(StaticArray{Tuple{N,M}} where {N,M})} = |
| 102 | + _scalar_matrix(Size(SM), SM, I.λ) |
| 103 | + |
| 104 | +# Construct a matrix with the scalar λ on the diagonal and zeros off the |
| 105 | +# diagonal. The matrix can be non-square. |
| 106 | +@generated function _scalar_matrix(s::Size{S}, m_or_SM, λ) where {S} |
| 107 | + elements = Symbol[i == j ? :λ : :λzero for i in 1:S[1], j in 1:S[2]] |
110 | 108 | return quote
|
111 | 109 | $(Expr(:meta, :inline))
|
112 |
| - SM(tuple($(exprs...))) |
| 110 | + λzero = zero(λ) |
| 111 | + _construct_similar(m_or_SM, s, tuple($(elements...))) |
113 | 112 | end
|
114 | 113 | end
|
115 | 114 |
|
|
145 | 144 | end
|
146 | 145 | end
|
147 | 146 |
|
| 147 | +#-------------------------------------------------- |
| 148 | +# Vector products |
148 | 149 | @inline cross(a::StaticVector, b::StaticVector) = _cross(same_size(a, b), a, b)
|
149 | 150 | _cross(::Size{S}, a::StaticVector, b::StaticVector) where {S} = error("Cross product not defined for $(S[1])-vectors")
|
150 | 151 | @inline function _cross(::Size{(2,)}, a::StaticVector, b::StaticVector)
|
|
179 | 180 | return ret
|
180 | 181 | end
|
181 | 182 |
|
| 183 | +#-------------------------------------------------- |
| 184 | +# Norms |
182 | 185 | @inline LinearAlgebra.norm_sqr(v::StaticVector) = mapreduce(abs2, +, v; init=zero(real(eltype(v))))
|
183 | 186 |
|
184 | 187 | @inline norm(a::StaticArray) = _norm(Size(a), a)
|
|
240 | 243 |
|
241 | 244 | @inline tr(a::StaticMatrix) = _tr(Size(a), a)
|
242 | 245 | @generated function _tr(::Size{S}, a::StaticMatrix) where {S}
|
243 |
| - if S[1] != S[2] |
244 |
| - throw(DimensionMismatch("matrix is not square")) |
245 |
| - end |
| 246 | + checksquare(a) |
246 | 247 |
|
247 | 248 | if S[1] == 0
|
248 | 249 | return :(zero(eltype(a)))
|
|
257 | 258 | end
|
258 | 259 | end
|
259 | 260 |
|
| 261 | + |
| 262 | +#-------------------------------------------------- |
| 263 | +# Outer products |
| 264 | + |
260 | 265 | const _length_limit = Length(200)
|
261 | 266 |
|
262 | 267 | @inline kron(a::StaticMatrix, b::StaticMatrix) = _kron(_length_limit, Size(a), Size(b), a, b)
|
|
414 | 419 | end
|
415 | 420 | end
|
416 | 421 |
|
417 |
| -# some micro-optimizations (TODO check these make sense for v0.6+) |
418 |
| -@inline LinearAlgebra.checksquare(::SM) where {SM<:StaticMatrix} = _checksquare(Size(SM)) |
419 |
| -@inline LinearAlgebra.checksquare(::Type{SM}) where {SM<:StaticMatrix} = _checksquare(Size(SM)) |
420 | 422 |
|
421 |
| -@pure _checksquare(::Size{S}) where {S} = (S[1] == S[2] || throw(DimensionMismatch("matrix is not square: dimensions are $S")); S[1]) |
| 423 | +#-------------------------------------------------- |
| 424 | +# Some shimming for special linear algebra matrix types |
| 425 | +@inline LinearAlgebra.Symmetric(A::StaticMatrix, uplo::Char='U') = (checksquare(A); Symmetric{eltype(A),typeof(A)}(A, uplo)) |
| 426 | +@inline LinearAlgebra.Hermitian(A::StaticMatrix, uplo::Char='U') = (checksquare(A); Hermitian{eltype(A),typeof(A)}(A, uplo)) |
422 | 427 |
|
423 |
| -@inline LinearAlgebra.Symmetric(A::StaticMatrix, uplo::Char='U') = (LinearAlgebra.checksquare(A);Symmetric{eltype(A),typeof(A)}(A, uplo)) |
424 |
| -@inline LinearAlgebra.Hermitian(A::StaticMatrix, uplo::Char='U') = (LinearAlgebra.checksquare(A);Hermitian{eltype(A),typeof(A)}(A, uplo)) |
|
0 commit comments