Skip to content

Commit 0f4853e

Browse files
committed
add constructors from UniformScaling, to replace eye
1 parent 372c1de commit 0f4853e

File tree

7 files changed

+56
-16
lines changed

7 files changed

+56
-16
lines changed

src/MArray.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ end
8282
@inline one(::Type{MArray{S,T}}) where {S,T} = one(MArray{S,T,tuple_length(S)})
8383
@inline eye(::Type{MArray{S,T}}) where {S,T} = eye(MArray{S,T,tuple_length(S)})
8484

85+
# MArray(I::UniformScaling) methods to replace eye
86+
(::Type{MA})(I::UniformScaling) where {MA<:MArray} = _eye(Size(MA), MA, I)
87+
8588
####################
8689
## MArray methods ##
8790
####################

src/SArray.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ end
6161
@inline one(::Type{SArray{S, T}}) where {S, T} = one(SArray{S, T, tuple_length(S)})
6262
@inline eye(::Type{SArray{S, T}}) where {S, T} = eye(SArray{S, T, tuple_length(S)})
6363

64+
# SArray(I::UniformScaling) methods to replace eye
65+
(::Type{SA})(I::UniformScaling) where {SA<:SArray} = _eye(Size(SA), SA, I)
66+
6467
####################
6568
## SArray methods ##
6669
####################

src/SDiagonal.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ function logdet(D::SDiagonal{N,T}) where {N,T<:Complex} #Make sure branch cut is
8282
end
8383

8484
eye(::Type{SDiagonal{N,T}}) where {N,T} = SDiagonal(ones(SVector{N,T}))
85+
# SDiagonal(I::UniformScaling) methods to replace eye
86+
(::Type{SD})(I::UniformScaling) where {N,SD<:SDiagonal{N}} = SD(ntuple(x->I.λ, Val(N)))
87+
8588
one(::Type{SDiagonal{N,T}}) where {N,T} = SDiagonal(ones(SVector{N,T}))
8689
one(::SDiagonal{N,T}) where {N,T} = SDiagonal(ones(SVector{N,T}))
8790
Base.zero(::SDiagonal{N,T}) where {N,T} = SDiagonal(zeros(SVector{N,T}))

src/linalg.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -169,20 +169,17 @@ end
169169
@deprecate eye(A::SM) where {SM<:StaticMatrix} eye(typeof(A))
170170
end
171171

172-
#if VERSION < v"0.7-"
173-
@inline eye(::Type{SM}) where {SM <: StaticMatrix} = _eye(Size(SM), SM)
174-
@generated function _eye(::Size{S}, ::Type{SM}) where {S, SM <: StaticArray}
175-
T = eltype(SM) # should be "hyperpure"
176-
if T == Any
177-
T = Float64
178-
end
179-
exprs = [i == j ? :(one($T)) : :(zero($T)) for i 1:S[1], j 1:S[2]]
180-
return quote
181-
$(Expr(:meta, :inline))
182-
SM(tuple($(exprs...)))
183-
end
172+
# StaticMatrix(I::UniformScaling) methods to replace eye
173+
(::Type{SM})(I::UniformScaling) where {N,M,SM<:StaticMatrix{N,M}} = _eye(Size(SM), SM, I)
174+
175+
@inline eye(::Type{SM}) where {SM<:StaticMatrix} = _eye(Size(SM), SM, 1.0I)
176+
@generated function _eye(::Size{S}, ::Type{SM}, I::UniformScaling{T}) where {S, SM <: StaticArray, T}
177+
exprs = [i == j ? :(I.λ) : :(zero($T)) for i 1:S[1], j 1:S[2]]
178+
return quote
179+
$(Expr(:meta, :inline))
180+
SM(tuple($(exprs...)))
184181
end
185-
#end
182+
end
186183

187184
@generated function diagm(kvs::Pair{<:Val,<:StaticVector}...)
188185
N = maximum(abs(kv.parameters[1].parameters[1]) + length(kv.parameters[2]) for kv in kvs)

test/MArray.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@
127127
@test !Base.mightalias(m, copy(m))
128128
@test Base.mightalias(m, view(m, :, 1))
129129
end
130-
130+
131131
if isdefined(Base, :dataids) # v0.7-
132-
@test Base.dataids(m) == (UInt(pointer(m)),)
132+
@test Base.dataids(m) == (UInt(pointer(m)),)
133133
end
134134
end
135135

test/SArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
@test ((@SArray [i*j*k*l for i = 1:2, j = 2:3, k = 3:4, l = 1:2])::SArray{Tuple{2,2,2,2}}).data === (6, 12, 9, 18, 8, 16, 12, 24, 12, 24, 18, 36, 16, 32, 24, 48)
4444
@test ((@SArray [i*j*k*l*m for i = 1:2, j = 2:3, k = 3:4, l = 1:2, m = 1:2])::SArray{Tuple{2,2,2,2,2}}).data === (6, 12, 9, 18, 8, 16, 12, 24, 12, 24, 18, 36, 16, 32, 24, 48, 2*6, 2*12, 2*9, 2*18, 2*8, 2*16, 2*12, 2*24, 2*12, 2*24, 2*18, 2*36, 2*16, 2*32, 2*24, 2*48)
4545
@test ((@SArray [1 for i = 1:2, j = 2:3, k = 3:4, l = 1:2, m = 1:2, n = 1:2])::SArray{Tuple{2,2,2,2,2,2}}).data === ntuple(i->1, 64)
46-
@test ((@SArray [1 for i = 1:2, j = 2:3, k = 3:4, l = 1:2, m = 1:2, n = 1:2, o = 1:2])::SArray{Tuple{2,2,2,2,2,2,2}}).data === ntuple(i->1, 128)
46+
@test ((@SArray [1 for i = 1:2, j = 2:3, k = 3:4, l = 1:2, m = 1:2, n = 1:2, o = 1:2])::SArray{Tuple{2,2,2,2,2,2,2}}).data === ntuple(i->1, 128)
4747
@test ((@SArray [1 for i = 1:2, j = 2:3, k = 3:4, l = 1:2, m = 1:2, n = 1:2, o = 1:2, p = 1:2])::SArray{Tuple{2,2,2,2,2,2,2,2}}).data === ntuple(i->1, 256)
4848
test_expand_error(:(@SArray [1 for i = 1:2, j = 2:3, k = 3:4, l = 1:2, m = 1:2, n = 1:2, o = 1:2, p = 1:2, q = 1:2]))
4949
@test ((@SArray Float64[i for i = 1:2])::SArray{Tuple{2}}).data === (1.0, 2.0)

test/linalg.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,40 @@ using StaticArrays, Test, LinearAlgebra
3939
@test @inferred(I \ @SMatrix([0 1; 2 3])) === @SMatrix [0.0 1.0; 2.0 3.0]
4040
end
4141

42+
@testset "Constructors from UniformScaling" begin
43+
I3x3 = Matrix(I, 3, 3)
44+
I3x2 = Matrix(I, 3, 2)
45+
# SArray
46+
## eltype from I
47+
@test @inferred(SArray{Tuple{3,3}}(I))::SMatrix{3,3,Bool,9} == I3x3
48+
@test @inferred(SArray{Tuple{3,3}}(2.0I))::SMatrix{3,3,Float64,9} == 2I3x3
49+
## eltype from constructor
50+
@test @inferred(SArray{Tuple{3,3},Float64}(I))::SMatrix{3,3,Float64,9} == I3x3
51+
@test @inferred(SArray{Tuple{3,3},Float32}(2.0I))::SMatrix{3,3,Float32,9} == 2I3x3
52+
## non-square
53+
@test @inferred(SArray{Tuple{3,2}}(I))::SMatrix{3,2,Bool,6} == I3x2
54+
# SMatrix
55+
@test @inferred(SMatrix{3,3}(I))::SMatrix{3,3,Bool,9} == I3x3
56+
@test @inferred(SMatrix{3,3}(2.0I))::SMatrix{3,3,Float64,9} == 2I3x3
57+
# MArray
58+
## eltype from I
59+
@test @inferred(MArray{Tuple{3,3}}(I))::MMatrix{3,3,Bool,9} == I3x3
60+
@test @inferred(MArray{Tuple{3,3}}(2.0I))::MMatrix{3,3,Float64,9} == 2I3x3
61+
## eltype from constructor
62+
@test @inferred(MArray{Tuple{3,3},Float64}(I))::MMatrix{3,3,Float64,9} == I3x3
63+
@test @inferred(MArray{Tuple{3,3},Float32}(2.0I))::MMatrix{3,3,Float32,9} == 2I3x3
64+
## non-square
65+
@test @inferred(MArray{Tuple{3,2}}(I))::MMatrix{3,2,Bool,6} == I3x2
66+
# MMatrix
67+
@test @inferred(MMatrix{3,3}(I))::MMatrix{3,3,Bool,9} == I3x3
68+
@test @inferred(MMatrix{3,3}(2.0I))::MMatrix{3,3,Float64,9} == 2I3x3
69+
# SDiagonal
70+
@test @inferred(SDiagonal{3}(I))::SDiagonal{3,Bool} == I3x3
71+
@test @inferred(SDiagonal{3}(2.0I))::SDiagonal{3,Float64} == 2I3x3
72+
@test @inferred(SDiagonal{3,Float64}(I))::SDiagonal{3,Float64} == I3x3
73+
@test @inferred(SDiagonal{3,Float32}(2.0I))::SDiagonal{3,Float32} == 2I3x3
74+
end
75+
4276
@testset "diagm()" begin
4377
@test @inferred(diagm(Val(0) => SVector(1,2))) === @SMatrix [1 0; 0 2]
4478
@test @inferred(diagm(Val(2) => SVector(1,2,3)))::SMatrix == diagm(2 => [1,2,3])

0 commit comments

Comments
 (0)