Skip to content
This repository was archived by the owner on Mar 23, 2025. It is now read-only.

Commit d6b5a1f

Browse files
authored
Improve contractions involving scalar-like tensors (#57)
* Add special code path for contractions involving scalar-like tensors
1 parent a05eca8 commit d6b5a1f

File tree

13 files changed

+335
-137
lines changed

13 files changed

+335
-137
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NDTensors"
22
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
33
authors = ["Matthew Fishman <mfishman@flatironinstitute.org>"]
4-
version = "0.1.19"
4+
version = "0.1.20"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/NDTensors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ include("deprecated.jl")
6666
# A global timer used with TimerOutputs.jl
6767
#
6868

69-
const GLOBAL_TIMER = TimerOutput()
69+
const timer = TimerOutput()
7070

7171
#####################################
7272
# Optional TBLIS contraction backend

src/blocksparse/blocksparsetensor.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,7 @@ function permutedims!!(R::BlockSparseTensor{ElR,N},
611611
T::BlockSparseTensor{ElT,N},
612612
perm::NTuple{N,Int},
613613
f::Function=(r,t)->t) where {ElR,ElT,N}
614+
@timeit_debug timer "block sparse permutedims!!" begin
614615
bofsRR = blockoffsets(R)
615616
bofsT = blockoffsets(T)
616617

@@ -643,6 +644,7 @@ function permutedims!!(R::BlockSparseTensor{ElR,N},
643644

644645
permutedims!(R, T, perm, f)
645646
return R
647+
end
646648
end
647649

648650
# Version where it is known that R has the same blocks
@@ -781,6 +783,7 @@ function contraction_output(T1::TensorT1,
781783
labelsT2,
782784
labelsR) where {TensorT1<:BlockSparseTensor,
783785
TensorT2<:BlockSparseTensor}
786+
784787
indsR = contract_inds(inds(T1),labelsT1,inds(T2),labelsT2,labelsR)
785788
TensorR = contraction_output_type(TensorT1,TensorT2,typeof(indsR))
786789
blockoffsetsR,contraction_plan = contract_blockoffsets(blockoffsets(T1),inds(T1),labelsT1,
@@ -795,9 +798,11 @@ function contract(T1::BlockSparseTensor{<:Any,N1},
795798
T2::BlockSparseTensor{<:Any,N2},
796799
labelsT2,
797800
labelsR = contract_labels(labelsT1,labelsT2)) where {N1,N2}
801+
@timeit_debug timer "Block sparse contract" begin
798802
R,contraction_plan = contraction_output(T1,labelsT1,T2,labelsT2,labelsR)
799803
R = contract!(R,labelsR,T1,labelsT1,T2,labelsT2,contraction_plan)
800804
return R
805+
end
801806
end
802807

803808
function contract!(R::BlockSparseTensor{ElR, NR},

src/blocksparse/combiner.jl

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11

2-
function contract(T::BlockSparseTensor,
3-
labelsT,
4-
C::CombinerTensor,
5-
labelsC)
2+
function contract(T::BlockSparseTensor, labelsT,
3+
C::CombinerTensor, labelsC)
4+
@timeit_debug timer "Block sparse (un)combiner" begin
65
# Get the label marking the combined index
76
# By convention the combined index is the first one
87
# TODO: consider storing the location of the combined
@@ -40,16 +39,14 @@ function contract(T::BlockSparseTensor,
4039
Ruc = uncombine(T,indsRuc,cpos_in_labelsRc,blockperm(C),blockcomb(C))
4140
return Ruc
4241
end
42+
end
4343
end
4444

45-
contract(C::CombinerTensor,
46-
labelsC,
47-
T::BlockSparseTensor,
48-
labelsT) = contract(T,labelsT,C,labelsC)
45+
contract(C::CombinerTensor, labelsC, T::BlockSparseTensor, labelsT) =
46+
contract(T,labelsT,C,labelsC)
4947

5048
# Special case when no indices are combined
51-
contract(T::BlockSparseTensor,
52-
labelsT,
53-
C::CombinerTensor{<:Any,0},
54-
labelsC) = copy(T)
49+
# XXX: no copy
50+
contract(T::BlockSparseTensor, labelsT,
51+
C::CombinerTensor{<:Any,0}, labelsC) = copy(T)
5552

src/blocksparse/linearalgebra.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ function LinearAlgebra.svd(T::BlockSparseMatrix{ElT};
3636

3737
truncate = haskey(kwargs, :maxdim) || haskey(kwargs, :cutoff)
3838

39+
@timeit_debug timer "block sparse svd" begin
3940
Us = Vector{DenseTensor{ElT, 2}}(undef, nnzblocks(T))
4041
Ss = Vector{DiagTensor{real(ElT), 2}}(undef, nnzblocks(T))
4142
Vs = Vector{DenseTensor{ElT, 2}}(undef, nnzblocks(T))
@@ -180,6 +181,7 @@ function LinearAlgebra.svd(T::BlockSparseMatrix{ElT};
180181
end
181182

182183
return U,S,V,Spectrum(d,truncerr)
184+
end # @timeit_debug
183185
end
184186

185187
_eigen_eltypes(T::Hermitian{ElT,<:BlockSparseMatrix{ElT}}) where {ElT} = real(ElT), ElT

src/combiner.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ data(::Combiner) = error("Combiner storage has no data")
1717
blockperm(C::Combiner) = C.perm
1818
blockcomb(C::Combiner) = C.comb
1919

20-
Base.eltype(::Type{<:Combiner}) = Number
20+
eltype(::Type{<:Combiner}) = Number
2121

22-
Base.eltype(::Combiner) = eltype(Combiner)
22+
eltype(::Combiner) = eltype(Combiner)
2323

24-
Base.promote_rule(::Type{<:Combiner},
24+
promote_rule(::Type{<:Combiner},
2525
StorageT::Type{<:Dense}) = StorageT
2626

2727
#
@@ -36,7 +36,7 @@ uncombinedinds(T::CombinerTensor) = popfirst(inds(T))
3636
blockperm(C::CombinerTensor) = blockperm(store(C))
3737
blockcomb(C::CombinerTensor) = blockcomb(store(C))
3838

39-
Base.conj(T::CombinerTensor; always_copy = false) = T
39+
conj(T::CombinerTensor; always_copy = false) = T
4040

4141
function contraction_output(::TensorT1,
4242
::TensorT2,
@@ -117,16 +117,12 @@ function contract!!(R::Tensor{<:Number,NR},
117117
return contract!!(R,labelsR,T2,labelsT2,T1,labelsT1)
118118
end
119119

120-
function Base.show(io::IO,
121-
mime::MIME"text/plain",
122-
S::Combiner)
120+
function show(io::IO, mime::MIME"text/plain", S::Combiner)
123121
println(io, "Permutation of blocks: ", S.perm)
124122
println(io, "Combination of blocks: ", S.comb)
125123
end
126124

127-
function Base.show(io::IO,
128-
mime::MIME"text/plain",
129-
T::CombinerTensor)
125+
function show(io::IO, mime::MIME"text/plain", T::CombinerTensor)
130126
summary(io, T)
131127
println(io)
132128
show(io, mime, store(T))

src/contraction_logic.jl

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,33 +41,62 @@ function _contract_inds!(Ris,
4141
T2is,
4242
T2labels::Labels{N2},
4343
Rlabels::Labels{NR}) where {N1,N2,NR}
44-
ncont = 0
45-
for i in T1labels
46-
i < 0 && (ncont += 1)
47-
end
48-
IndT = promote_type(eltype(T1is), eltype(T2is))
49-
u = 1
50-
# TODO: use Rlabels, don't assume ncon convention
51-
for i1 1:N1
52-
if T1labels[i1] > 0
53-
Ris[u] = T1is[i1]
54-
u += 1
55-
else
56-
# This is to check that T1is and T2is
57-
# can contract
58-
i2 = findfirst(==(T1labels[i1]),T2labels)
59-
dir(T1is[i1]) == -dir(T2is[i2]) || error("Attempting to contract index:\n\n$(T1is[i1])\nwith index:\n\n$(T2is[i2])\nIndices must have opposite directions to contract.")
44+
for n in 1:NR
45+
Rlabel = @inbounds Rlabels[n]
46+
found = false
47+
for n1 in 1:N1
48+
if Rlabel == @inbounds T1labels[n1]
49+
@inbounds Ris[n] = @inbounds T1is[n1]
50+
found = true
51+
break
52+
end
6053
end
61-
end
62-
for i2 1:N2
63-
if T2labels[i2] > 0
64-
Ris[u] = T2is[i2]
65-
u += 1
54+
if !found
55+
for n2 in 1:N2
56+
if Rlabel == @inbounds T2labels[n2]
57+
@inbounds Ris[n] = @inbounds T2is[n2]
58+
break
59+
end
60+
end
6661
end
6762
end
6863
return nothing
6964
end
7065

66+
# Old version that doesn't take into account Rlabels
67+
#function _contract_inds!(Ris,
68+
# T1is,
69+
# T1labels::Labels{N1},
70+
# T2is,
71+
# T2labels::Labels{N2},
72+
# Rlabels::Labels{NR}) where {N1,N2,NR}
73+
# ncont = 0
74+
# for i in T1labels
75+
# i < 0 && (ncont += 1)
76+
# end
77+
# IndT = promote_type(eltype(T1is), eltype(T2is))
78+
# u = 1
79+
# # TODO: use Rlabels, don't assume ncon convention
80+
# for i1 ∈ 1:N1
81+
# if T1labels[i1] > 0
82+
# Ris[u] = T1is[i1]
83+
# u += 1
84+
# else
85+
# # This is to check that T1is and T2is
86+
# # can contract
87+
# i2 = findfirst(==(T1labels[i1]),T2labels)
88+
# dir(T1is[i1]) == -dir(T2is[i2]) || error("Attempting to contract index:\n\n$(T1is[i1])\nwith index:\n\n$(T2is[i2])\nIndices must have opposite directions to contract.")
89+
# end
90+
# end
91+
# for i2 ∈ 1:N2
92+
# if T2labels[i2] > 0
93+
# Ris[u] = T2is[i2]
94+
# u += 1
95+
# end
96+
# end
97+
# return nothing
98+
#end
99+
71100
function contract_inds(T1is,
72101
T1labels::Labels{N1},
73102
T2is,

0 commit comments

Comments
 (0)