Skip to content

Commit 51f5997

Browse files
authored
Restrict ScaledMap to RealOrComplex (#127)
1 parent 7ed1d16 commit 51f5997

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

src/scaledmap.jl

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
1-
struct ScaledMap{T, S<:RealOrComplex, A<:LinearMap} <: LinearMap{T}
1+
"""
2+
Lazy representation of a scaled map `λ * A = A * λ` with real or complex map
3+
`A <: LinearMap{RealOrComplex}` and real or complex scaling factor
4+
`λ <: RealOrComplex`.
5+
"""
6+
struct ScaledMap{T, S<:RealOrComplex, L<:LinearMap} <: LinearMap{T}
27
λ::S
3-
lmap::A
4-
function ScaledMap{T}::S, lmap::A) where {T, S <: RealOrComplex, A <: LinearMap}
5-
@assert Base.promote_op(*, S, eltype(lmap)) == T "target type $T cannot hold products of $S and $(eltype(lmap)) objects"
6-
new{T,S,A}(λ, lmap)
8+
lmap::L
9+
function ScaledMap{T}::S, A::L) where {T, S <: RealOrComplex, L <: LinearMap{<:RealOrComplex}}
10+
@assert Base.promote_op(*, S, eltype(A)) == T "target type $T cannot hold products of $S and $(eltype(A)) objects"
11+
new{T,S,L}(λ, A)
712
end
813
end
914

1015
# constructor
11-
ScaledMap::S, lmap::A) where {S<:RealOrComplex,A<:LinearMap} =
12-
ScaledMap{Base.promote_op(*, S, eltype(lmap))}(λ, lmap)
16+
ScaledMap::RealOrComplex, lmap::LinearMap{<:RealOrComplex}) =
17+
ScaledMap{Base.promote_op(*, typeof(λ), eltype(lmap))}(λ, lmap)
1318

1419
# basic methods
1520
Base.size(A::ScaledMap) = size(A.lmap)
@@ -26,8 +31,8 @@ Base.:(==)(A::ScaledMap, B::ScaledMap) =
2631
eltype(A) == eltype(B) && A.lmap == B.lmap && A.λ == B.λ
2732

2833
# scalar multiplication and division
29-
Base.:(*)(α::RealOrComplex, A::LinearMap) = ScaledMap(α, A)
30-
Base.:(*)(A::LinearMap, α::RealOrComplex) = ScaledMap(α, A)
34+
Base.:(*)(α::RealOrComplex, A::LinearMap{<:RealOrComplex}) = ScaledMap(α, A)
35+
Base.:(*)(A::LinearMap{<:RealOrComplex}, α::RealOrComplex) = ScaledMap(α, A)
3136

3237
Base.:(*)(α::Number, A::ScaledMap) =* A.λ) * A.lmap
3338
Base.:(*)(A::ScaledMap, α::Number) = A.lmap * (A.λ * α)
@@ -42,13 +47,19 @@ Base.:(*)(A::ScaledMap, B::LinearMap) = A.λ * (A.lmap * B)
4247
Base.:(*)(A::LinearMap, B::ScaledMap) = (A * B.lmap) * B.λ
4348

4449
# multiplication with vectors/matrices
45-
for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractMatrix))
50+
for (In, Out) in ((AbstractVector, AbstractVecOrMat),
51+
(AbstractMatrix, AbstractMatrix))
4652
@eval begin
47-
function _unsafe_mul!(y::$Out, A::ScaledMap, x::$In)
53+
# commutative case
54+
function _unsafe_mul!(y::$Out, A::ScaledMap, x::$In{<:RealOrComplex})
4855
return _unsafe_mul!(y, A.lmap, x, A.λ, false)
4956
end
50-
function _unsafe_mul!(y::$Out, A::ScaledMap, x::$In, α::Number, β::Number)
57+
function _unsafe_mul!(y::$Out, A::ScaledMap, x::$In{<:RealOrComplex}, α::Number, β::Number)
5158
return _unsafe_mul!(y, A.lmap, x, A.λ * α, β)
5259
end
60+
# non-commutative case
61+
function _unsafe_mul!(y::$Out, A::ScaledMap, x::$In)
62+
return lmul!(A.λ, _unsafe_mul!(y, A.lmap, x))
63+
end
5364
end
5465
end

test/numbertypes.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
using Test, LinearMaps, LinearAlgebra, Quaternions
22

3+
# type piracy because Quaternions.jl doesn't have it right
4+
Base.:(*)(z::Complex{T}, q::Quaternion{T}) where {T<:Real} = quat(z) * q
5+
Base.:(*)(q::Quaternion{T}, z::Complex{T}) where {T<:Real} = q * quat(z)
6+
37
@testset "noncommutative number type" begin
48
x = Quaternion.(rand(10), rand(10), rand(10), rand(10))
59
v = rand(10)
610
A = Quaternion.(rand(10,10), rand(10,10), rand(10,10), rand(10,10))
11+
B = rand(ComplexF64, 10, 10)
712
γ = Quaternion.(rand(4)...) # "Number"
813
α = UniformScaling(γ)
914
β = UniformScaling(Quaternion.(rand(4)...))
15+
λ = rand(ComplexF64)
1016
L = LinearMap(A)
1117
@test Array(L) == A
1218
@test Array(L') == A'
@@ -21,8 +27,14 @@ using Test, LinearMaps, LinearAlgebra, Quaternions
2127
@test L' * x A' * x
2228
@test α * (L * x) α * (A * x)
2329
@test α * L * x α * A * x
30+
@test L * α * x A * α * x
2431
@test 3L * x 3A * x
2532
@test 3L' * x 3A' * x
33+
@test λ*L isa LinearMaps.CompositeMap
34+
@test γ ** LinearMap(B)) isa LinearMaps.CompositeMap
35+
@test* LinearMap(B)) * γ isa LinearMaps.CompositeMap
36+
@test λ*L * x λ*A * x
37+
@test λ*L' * x λ*A' * x
2638
@test α * (3L * x) α * (3A * x)
2739
@test (@inferred α * 3L) * x α * 3A * x
2840
@test (@inferred 3L * α) * x 3A * α * x
@@ -53,6 +65,13 @@ using Test, LinearMaps, LinearAlgebra, Quaternions
5365
@test Array(-L) == -A
5466
@test Array\ L) γ \ A
5567
@test Array(L / γ) A / γ
68+
M = rand(ComplexF64, 10, 10); α = rand(ComplexF64);
69+
y = α * M * x; Y = α * M * A
70+
@test* LinearMap(M)) * x (quat(α) * LinearMap(M)) * x y
71+
@test mul!(copy(y), α * LinearMap(M), x, α, false) α * M * x * α
72+
@test mul!(copy(y), α * LinearMap(M), x, quat(α), false) α * M * x * α
73+
@test mul!(copy(Y), α * LinearMap(M), A) α * M * A
74+
@test mul!(copy(Y), α * LinearMap(M), A, α, false) α * M * A * α
5675
end
5776

5877
@testset "nonassociative number type" begin

0 commit comments

Comments
 (0)