Skip to content

Commit ef64037

Browse files
authored
Improve performance of Kronecker products (#126)
1 parent 51f5997 commit ef64037

File tree

6 files changed

+23
-14
lines changed

6 files changed

+23
-14
lines changed

src/kronecker.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ Base.kron(A::KroneckerMap, B::LinearMap) =
5252
KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A.maps..., B))
5353
Base.kron(A::KroneckerMap, B::KroneckerMap) =
5454
KroneckerMap{promote_type(eltype(A), eltype(B))}(tuple(A.maps..., B.maps...))
55+
Base.kron(A::ScaledMap, B::LinearMap) = A.λ * kron(A.lmap, B)
56+
Base.kron(A::LinearMap{<:RealOrComplex}, B::ScaledMap) = B.λ * kron(A, B.lmap)
57+
Base.kron(A::ScaledMap, B::ScaledMap) = (A.λ * B.λ) * kron(A.lmap, B.lmap)
5558
Base.kron(A::LinearMap, B::LinearMap, C::LinearMap, Ds::LinearMap...) =
5659
kron(kron(A, B), C, Ds...)
5760
Base.kron(A::AbstractMatrix, B::LinearMap) = kron(LinearMap(A), B)
@@ -104,9 +107,10 @@ Base.:(==)(A::KroneckerMap, B::KroneckerMap) = (eltype(A) == eltype(B) && A.maps
104107
# multiplication helper functions
105108
#################
106109

107-
@inline function _kronmul!(y, B, X, At, T)
110+
@inline function _kronmul!(y, B, x, At, T)
108111
na, ma = size(At)
109112
mb, nb = size(B)
113+
X = reshape(x, (nb, na))
110114
v = zeros(T, ma)
111115
temp1 = similar(y, na)
112116
temp2 = similar(y, nb)
@@ -119,14 +123,23 @@ Base.:(==)(A::KroneckerMap, B::KroneckerMap) = (eltype(A) == eltype(B) && A.maps
119123
end
120124
return y
121125
end
122-
@inline function _kronmul!(y, B, X, At::Union{MatrixMap, UniformScalingMap}, T)
126+
@inline function _kronmul!(y, B, x, At::UniformScalingMap, _)
127+
na, ma = size(At)
128+
mb, nb = size(B)
129+
X = reshape(x, (nb, na))
130+
Y = reshape(y, (mb, ma))
131+
_unsafe_mul!(Y, B, X, At.λ, false)
132+
return y
133+
end
134+
@inline function _kronmul!(y, B, x, At::MatrixMap, _)
123135
na, ma = size(At)
124136
mb, nb = size(B)
137+
X = reshape(x, (nb, na))
125138
Y = reshape(y, (mb, ma))
126139
if nb*ma < mb*na
127-
_unsafe_mul!(Y, B, Matrix(X*At))
140+
_unsafe_mul!(Y, B, X * At.lmap)
128141
else
129-
_unsafe_mul!(Y, Matrix(B*X), _parent(At))
142+
_unsafe_mul!(Y, Matrix(B*X), At.lmap)
130143
end
131144
return y
132145
end
@@ -140,18 +153,14 @@ const KroneckerMap2{T} = KroneckerMap{T, <:Tuple{LinearMap, LinearMap}}
140153
function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap2, x::AbstractVector)
141154
require_one_based_indexing(y)
142155
A, B = L.maps
143-
X = LinearMap(reshape(x, (size(B, 2), size(A, 2)));
144-
issymmetric = false, ishermitian = false, isposdef = false)
145-
_kronmul!(y, B, X, transpose(A), eltype(L))
156+
_kronmul!(y, B, x, transpose(A), eltype(L))
146157
return y
147158
end
148159
function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap, x::AbstractVector)
149160
require_one_based_indexing(y)
150161
A = first(L.maps)
151162
B = kron(Base.tail(L.maps)...)
152-
X = LinearMap(reshape(x, (size(B, 2), size(A, 2)));
153-
issymmetric = false, ishermitian = false, isposdef = false)
154-
_kronmul!(y, B, X, transpose(A), eltype(L))
163+
_kronmul!(y, B, x, transpose(A), eltype(L))
155164
return y
156165
end
157166
# mixed-product rule, prefer the right if possible

src/uniformscalingmap.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ MulStyle(::UniformScalingMap) = FiveArg()
1717

1818
# properties
1919
Base.size(A::UniformScalingMap) = (A.M, A.M)
20-
_parent(A::UniformScalingMap) = A.λ
2120
Base.isreal(A::UniformScalingMap) = isreal(A.λ)
2221
LinearAlgebra.issymmetric(::UniformScalingMap) = true
2322
LinearAlgebra.ishermitian(A::UniformScalingMap) = isreal(A)

src/wrappedmap.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ Base.:(==)(A::MatrixMap, B::MatrixMap) =
3838

3939
# properties
4040
Base.size(A::WrappedMap) = size(A.lmap)
41-
_parent(A::WrappedMap) = A.lmap
4241
LinearAlgebra.issymmetric(A::WrappedMap) = A._issymmetric
4342
LinearAlgebra.ishermitian(A::WrappedMap) = A._ishermitian
4443
LinearAlgebra.isposdef(A::WrappedMap) = A._isposdef

test/kronecker.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays
88
LA = LinearMap(A)
99
LB = LinearMap(B)
1010
LK = @inferred kron(LA, LB)
11+
@test kron(LA, 2LB) isa LinearMaps.ScaledMap
12+
@test kron(3LA, LB) isa LinearMaps.ScaledMap
13+
@test kron(3LA, 2LB) isa LinearMaps.ScaledMap
14+
@test kron(3LA, 2LB).λ == 6
1115
@test_throws ErrorException LinearMaps.KroneckerMap{Float64}((LA, LB))
1216
@test occursin("6×6 LinearMaps.KroneckerMap{$(eltype(LK))}", sprint((t, s) -> show(t, "text/plain", s), LK))
1317
@test @inferred size(LK) == size(K)

test/uniformscalingmap.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools
1111
w = similar(v)
1212
Id = @inferred LinearMap(I, 10)
1313
@test occursin("10×10 LinearMaps.UniformScalingMap{Bool}", sprint((t, s) -> show(t, "text/plain", s), Id))
14-
@test LinearMaps._parent(Id) == true
1514
@test_throws ErrorException LinearMaps.UniformScalingMap(1, 10, 20)
1615
@test_throws ErrorException LinearMaps.UniformScalingMap(1, (10, 20))
1716
@test size(Id) == (10, 10)

test/wrappedmap.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ using Test, LinearMaps, LinearAlgebra
77
SB = B'B + I
88
L = @inferred LinearMap{Float64}(A)
99
@test occursin("10×20 LinearMaps.WrappedMap{Float64}", sprint((t, s) -> show(t, "text/plain", s), L))
10-
@test LinearMaps._parent(L) === A
1110
MA = @inferred LinearMap(SA)
1211
MB = @inferred LinearMap(SB)
1312
@test eltype(Matrix{Complex{Float32}}(LinearMap(A))) <: Complex

0 commit comments

Comments
 (0)