Skip to content

Commit cabbf44

Browse files
authored
DiagonalTensorMap performance and convenience specializations (#249)
1 parent c37cf86 commit cabbf44

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

src/tensors/diagonal.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@ storagetype(::Type{<:DiagonalTensorMap{T,S,A}}) where {T,S,A<:DenseVector{T}} =
3737
3838
Construct a `DiagonalTensorMap` with uninitialized data.
3939
"""
40+
function DiagonalTensorMap{T}(::UndefInitializer, V::TensorMapSpace) where {T}
41+
(numin(V) == numout(V) == 1 && domain(V) == codomain(V)) ||
42+
throw(ArgumentError("DiagonalTensorMap requires a space with equal domain and codomain and 2 indices"))
43+
return DiagonalTensorMap{T}(undef, domain(V))
44+
end
45+
function DiagonalTensorMap{T}(::UndefInitializer, V::ProductSpace) where {T}
46+
length(V) == 1 ||
47+
throw(ArgumentError("DiagonalTensorMap requires `numin(d) == numout(d) == 1`"))
48+
return DiagonalTensorMap{T}(undef, only(V))
49+
end
4050
function DiagonalTensorMap{T}(::UndefInitializer, V::S) where {T,S<:IndexSpace}
4151
return DiagonalTensorMap{T,S,Vector{T}}(undef, V)
4252
end
@@ -265,6 +275,22 @@ function LinearAlgebra.mul!(dC::DiagonalTensorMap,
265275
return dC
266276
end
267277

278+
function LinearAlgebra.lmul!(D::DiagonalTensorMap, t::AbstractTensorMap)
279+
domain(D) == codomain(t) || throw(SpaceMismatch())
280+
for (c, b) in blocks(t)
281+
lmul!(block(D, c), b)
282+
end
283+
return t
284+
end
285+
286+
function LinearAlgebra.rmul!(t::AbstractTensorMap, D::DiagonalTensorMap)
287+
codomain(D) == domain(t) || throw(SpaceMismatch())
288+
for (c, b) in blocks(t)
289+
rmul!(b, block(D, c))
290+
end
291+
return t
292+
end
293+
268294
Base.inv(d::DiagonalTensorMap) = DiagonalTensorMap(inv.(d.data), d.domain)
269295
function Base.:\(d1::DiagonalTensorMap, d2::DiagonalTensorMap)
270296
d1.domain == d2.domain || throw(SpaceMismatch())
@@ -339,6 +365,17 @@ function _compute_svddata!(d::DiagonalTensorMap, alg::Union{SVD,SDD})
339365
return SVDdata, dims
340366
end
341367

368+
function LinearAlgebra.svdvals(d::DiagonalTensorMap)
369+
return SectorDict(c => LinearAlgebra.svdvals(b) for (c, b) in blocks(d))
370+
end
371+
function LinearAlgebra.eigvals(d::DiagonalTensorMap)
372+
return SectorDict(c => LinearAlgebra.eigvals(b) for (c, b) in blocks(d))
373+
end
374+
375+
function LinearAlgebra.cond(d::DiagonalTensorMap, p::Real=2)
376+
return LinearAlgebra.cond(Diagonal(d.data), p)
377+
end
378+
342379
# matrix functions
343380
for f in
344381
(:exp, :cos, :sin, :tan, :cot, :cosh, :sinh, :tanh, :coth, :atan, :acot, :asinh, :sqrt,

test/diagonal.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,16 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
55
@testset "DiagonalTensor with domain $V" for V in diagspacelist
66
@timedtestset "Basic properties and algebra" begin
77
for T in (Float32, Float64, ComplexF32, ComplexF64, BigFloat)
8+
# constructors
89
t = @constinferred DiagonalTensorMap{T}(undef, V)
910
t = @constinferred DiagonalTensorMap(rand(T, reduceddim(V)), V)
11+
t2 = @constinferred DiagonalTensorMap{T}(undef, space(t))
12+
@test space(t2) == space(t)
13+
@test_throws ArgumentError DiagonalTensorMap{T}(undef, V^2 V)
14+
t2 = @constinferred DiagonalTensorMap{T}(undef, domain(t))
15+
@test space(t2) == space(t)
16+
@test_throws ArgumentError DiagonalTensorMap{T}(undef, V^2)
17+
# properties
1018
@test @constinferred(hash(t)) == hash(deepcopy(t))
1119
@test scalartype(t) == T
1220
@test codomain(t) == ProductSpace(V)
@@ -135,6 +143,16 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
135143
@test u / t1 u / TensorMap(t1)
136144
@test t1 * u' TensorMap(t1) * u'
137145
@test t1 \ u' TensorMap(t1) \ u'
146+
147+
t3 = rand(Float64, V V^2)
148+
t4 = rand(ComplexF64, V V^2)
149+
@test t1 * t3 lmul!(t1, copy(t3))
150+
@test t2 * t4 lmul!(t2, copy(t4))
151+
152+
t3 = rand(Float64, V^2 V)
153+
t4 = rand(ComplexF64, V^2 V)
154+
@test t3 * t1 rmul!(copy(t3), t1)
155+
@test t4 * t2 rmul!(copy(t4), t2)
138156
end
139157
@timedtestset "Tensor contraction" begin
140158
d = DiagonalTensorMap(rand(ComplexF64, reduceddim(V)), V)
@@ -175,6 +193,12 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
175193
VdV2 = V2' * V2
176194
@test VdV2 one(VdV2)
177195
@test t2 * V2 V2 * D2
196+
197+
@test rank(D) rank(t)
198+
@test cond(D) cond(t)
199+
@test all(((s, t),) -> isapprox(s, t),
200+
zip(values(LinearAlgebra.eigvals(D)),
201+
values(LinearAlgebra.eigvals(t))))
178202
end
179203
@testset "leftorth with $alg" for alg in (TensorKit.QR(), TensorKit.QL())
180204
Q, R = @constinferred leftorth(t; alg=alg)
@@ -201,6 +225,12 @@ diagspacelist = ((ℂ^4)', ℂ[Z2Irrep](0 => 2, 1 => 3),
201225
VdV = Vᴴ * Vᴴ'
202226
@test VdV one(VdV)
203227
@test U * S * Vᴴ t
228+
229+
@test rank(S) rank(t)
230+
@test cond(S) cond(t)
231+
@test all(((s, t),) -> isapprox(s, t),
232+
zip(values(LinearAlgebra.svdvals(S)),
233+
values(LinearAlgebra.svdvals(t))))
204234
end
205235
end
206236
end

0 commit comments

Comments
 (0)