Skip to content

Commit de5fba9

Browse files
authored
improvements to flattening (#103)
1 parent 47ecdb8 commit de5fba9

File tree

5 files changed

+39
-8
lines changed

5 files changed

+39
-8
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# News
22

3+
## v0.4.7 - 2025-02-25
4+
5+
- Improvements to default flattening of expressions having nested sums and scalings.
6+
37
## v0.4.6 - 2025-01-18
48

59
- Migrate `express` functionality and representation types to QuantumInterface.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "QuantumSymbolics"
22
uuid = "efa7fd63-0460-4890-beb7-be1bbdfbaeae"
33
authors = ["QuantumSymbolics.jl contributors"]
4-
version = "0.4.6"
4+
version = "0.4.7"
55

66
[deps]
77
Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"

src/QSymbolicsBase/basic_ops_homogeneous.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ julia> k₁ + k₂
8585
_arguments_precomputed
8686
end
8787
function SAdd{S}(d) where S
88+
isempty(d) && return SZero{S}()
8889
terms = [c*obj for (obj,c) in d]
8990
length(d)==1 ? first(terms) : SAdd{S}(d,Set(terms),terms)
9091
end
@@ -99,7 +100,7 @@ function Base.:(+)(x::Symbolic{T}, xs::Vararg{Symbolic{T}, N}) where {T<:QObj, N
99100
xs = collect(xs)
100101
f = first(xs)
101102
nonzero_terms = filter!(x->!iszero(x),xs)
102-
isempty(nonzero_terms) ? f : SAdd{T}(countmap_flatten(nonzero_terms, SScaled{T}))
103+
isempty(nonzero_terms) ? f : SAdd{T}(countmap_flatten(nonzero_terms, SAdd{T}, SScaled{T}))
103104
end
104105
basis(x::SAdd) = basis(first(x.dict).first)
105106

src/QSymbolicsBase/utils.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,30 @@ function countmap(samples) # A simpler version of StatsBase.countmap, because St
3434
counts
3535
end
3636

37-
function countmap_flatten(samples, flattenhead)
37+
function countmap_flatten(samples, flattenadd, flattenmul)
3838
counts = Dict{Any,Any}()
3939
for s in samples
40-
if isexpr(s) && s isa flattenhead # TODO Could you use the TermInterface `operation` here instead of `flattenhead`?
40+
if s isa flattenadd
41+
for (term,coef) in pairs(s.dict)
42+
counts[term] = get(counts, term, 0)+coef
43+
end
44+
elseif s isa flattenmul
4145
coef, term = arguments(s)
42-
counts[term] = get(counts, term, 0)+coef
46+
if term isa flattenadd
47+
for (_term,_coef) in pairs(term.dict)
48+
counts[_term] = get(counts, _term, 0)+coef*_coef
49+
end
50+
else
51+
counts[term] = get(counts, term, 0)+coef
52+
end
4353
else
4454
counts[s] = get(counts, s, 0)+1
4555
end
4656
end
57+
for (term,coef) in pairs(counts)
58+
if iszero(coef)===true # iszero might return symbolic expressions instead of true/false # TODO make into a proper function like isdefinitelyzero, see whether upstream Symbolics has it
59+
delete!(counts, term)
60+
end
61+
end
4762
counts
48-
end
63+
end

test/test_sym_expressions.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,24 @@
22
@test +(Z1) == Z1
33
@test +(Z) == Z
44
@test isequal(Z1 - Z2, Z1 + (-Z2))
5-
@test_broken isequal(Z1 - 2*Z2 + 2*X1, -2*Z2 + Z1 + 2*X1)
6-
@test_broken isequal(Z1 - 2*Z2 + 2*X1, Z1 + 2*(-Z2+X1))
5+
@test isequal(Z1 - 2*Z2 + 2*X1, -2*Z2 + Z1 + 2*X1)
6+
@test isequal(Z1 - 2*Z2 + 2*X1, Z1 + 2*(-Z2+X1))
77

88
state1 = XBasisState(1, SpinBasis(1//2))
99
state2 = XBasisState(1, SpinBasis(1//2))
1010
state3 = XBasisState(2, SpinBasis(1//2))
1111

1212
@test isequal(state1, state2)
1313
@test !isequal(state1, state3)
14+
15+
@op A
16+
@op B
17+
18+
@test isequal(A+B+A, 2A+B)
19+
@test isequal(A+B-A, B)
20+
@test A+B-A === B
21+
@test 0*A === A-A-2A+2A
22+
@test isequal(A+B-A+B, 2B)
23+
@test isequal(2A-3B+2(A+B), 4A-B)
24+
@test isequal(2A-3B+2(A+2B), 4A+B)
1425
end

0 commit comments

Comments
 (0)