Skip to content

Commit 837effd

Browse files
Merge pull request #732 from alyst/refactor_get_degrees
Refactor get_degrees() to make it faster
2 parents f474cfa + d5b469d commit 837effd

File tree

3 files changed

+125
-29
lines changed

3 files changed

+125
-29
lines changed

benchmark/benchmarks.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Random
55

66
SUITE = BenchmarkGroup()
77

8-
@syms a b c d; Random.seed!(123);
8+
@syms a b c d x y[1:3] z[1:2, 1:2]; Random.seed!(123);
99

1010
let r = @rule(~x => ~x), rs = RuleSet([r]),
1111
acr = @rule(~x::is_literal_number + ~y => ~y)
@@ -67,7 +67,19 @@ let r = @rule(~x => ~x), rs = RuleSet([r]),
6767
subs_expr = (sin(a+b) + cos(b+c)) * (sin(b+c) + cos(c+a)) * (sin(c+a) + cos(a+b))
6868
end
6969

70+
overhead["get_degrees"] = BenchmarkGroup()
7071

72+
let y1 = term(getindex, y, 1, type=Number),
73+
y2 = term(getindex, y, 2, type=Number),
74+
y3 = term(getindex, y, 3, type=Number),
75+
z11 = term(getindex, z, 1, 1, type=Number),
76+
z12 = term(getindex, z, 1, 2, type=Number),
77+
z23 = term(getindex, z, 2, 3, type=Number)
78+
79+
# create a relatively large polynomial
80+
large_poly = SymbolicUtils.expand((x^2 + 2y1 + 3z12 + y2*z23 + x*y1*z12 - x^2*z12 + x*z11 + y3 + y2 + z23 + 1)^8)
81+
overhead["get_degrees"]["large_poly"] = @benchmarkable SymbolicUtils.get_degrees($large_poly)
82+
end
7183
end
7284

7385
let

src/ordering.jl

Lines changed: 106 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,49 +17,127 @@
1717
"""
1818
$(SIGNATURES)
1919
20-
Internal function used for printing symbolic expressions. This function determines
21-
the degrees of symbols within a given expression, implementing a variation on
22-
degree lexicographic order.
20+
Get the degrees of symbols within a given expression.
21+
22+
This internal function is used to define the order of terms in a symbolic expression,
23+
which is a variation on degree lexicographic order. It is used for printing and
24+
by [`sorted_arguments`](@ref).
25+
26+
Returns a tuple of degree and lexicographically sorted *multiplier* ⇒ *power* pairs,
27+
where the *multiplier* is a tuple of the symbol optionally followed by its indices.
28+
For a sum expression, returns the `get_degrees()` result for term with the highest degree.
29+
30+
See also `monomial_lt` and `lexlt`.
2331
"""
2432
function get_degrees(expr)
33+
degs_cache = Dict()
34+
res = _get_degrees(expr, degs_cache)
35+
if res isa AbstractVector
36+
return Tuple(res)
37+
else
38+
return res
39+
end
40+
end
41+
42+
function _get_degrees(expr, degs_cache::AbstractDict)
2543
if issym(expr)
26-
((Symbol(expr),) => 1,)
44+
return get!(() -> ((Symbol(expr),) => 1,), degs_cache, expr)
2745
elseif iscall(expr)
28-
op = operation(expr)
29-
args = sorted_arguments(expr)
30-
if op == (^) && args[2] isa Number
31-
return map(get_degrees(args[1])) do (base, pow)
32-
(base => pow * args[2])
33-
end
34-
elseif op == (*)
35-
return mapreduce(get_degrees,
36-
(x,y)->(x...,y...,), args)
37-
elseif op == (+)
38-
ds = map(get_degrees, args)
39-
_, idx = findmax(x->sum(last.(x), init=0), ds)
40-
return ds[idx]
41-
elseif op == (getindex)
42-
return ((Symbol.(args)...,) => 1,)
43-
else
44-
return ((Symbol("zzzzzzz", hash(expr)),) => 1,)
46+
# operation-specific degree handling
47+
return _get_degrees(operation(expr), expr, degs_cache)
48+
else
49+
return () # skip numbers and unsupported expressions
50+
end
51+
end
52+
53+
# fallback for unsupported operation
54+
_get_degrees(::Any, expr, degs_cache) =
55+
((Symbol("zzzzzzz", hash(expr)),) => 1,)
56+
57+
_getindex_symbol(arr, i) = Symbol(arr[i])
58+
59+
function _get_degrees(::typeof(getindex), expr, degs_cache)
60+
args = arguments(expr)
61+
@inbounds return get!(() -> (ntuple(Base.Fix1(_getindex_symbol, args), length(args)) => 1,),
62+
degs_cache, expr)
63+
end
64+
65+
function _get_degrees(::typeof(*), expr, degs_cache)
66+
args = arguments(expr)
67+
ds = sizehint!(Vector{Any}(), length(args))
68+
for arg in args
69+
degs = _get_degrees(arg, degs_cache)
70+
append!(ds, degs)
71+
end
72+
return sort!(ds)
73+
end
74+
75+
function _get_degrees(::typeof(+), expr, degs_cache)
76+
# among the terms find the best in terms of monomial_lt
77+
sel_degs = ()
78+
sel_degsum = 0
79+
for arg in arguments(expr)
80+
degs = _get_degrees(arg, degs_cache)
81+
degsum = sum(last, degs, init=0)
82+
if (sel_degs == ()) || (degsum > sel_degsum) ||
83+
(degsum == sel_degsum && lexlt(degs, sel_degs))
84+
sel_degs, sel_degsum = degs, degsum
85+
end
86+
end
87+
return sel_degs
88+
end
89+
90+
function _get_degrees(::typeof(^), expr, degs_cache)
91+
base_expr, pow_expr = arguments(expr)
92+
if pow_expr isa Number
93+
@inbounds degs = map(_get_degrees(base_expr, degs_cache)) do (base, pow)
94+
(base => pow * pow_expr)
95+
end
96+
if pow_expr < 0 && length(degs) > 1
97+
# fix the order after the powers were negated
98+
isa(degs, AbstractVector) || (degs = collect(degs))
99+
sort!(degs)
100+
end
101+
return degs
102+
else
103+
# expression in the power argument is not supported
104+
return _get_degrees(nothing, expr, degs_cache)
105+
end
106+
end
107+
108+
function _get_degrees(::typeof(/), expr, degs_cache)
109+
nom_expr, denom_expr = arguments(expr)
110+
if denom_expr isa Number # constant denominator
111+
return _get_degrees(nom_expr, degs_cache)
112+
elseif nom_expr isa Number # constant nominator
113+
@inbounds degs = map(_get_degrees(denom_expr, degs_cache)) do (base, pow)
114+
(base => -pow)
45115
end
116+
isa(degs, AbstractVector) || (degs = collect(degs))
117+
return sort!(degs)
46118
else
47-
return ()
119+
# TODO expressions in both nom and denom are not yet supported
120+
return _get_degrees(nothing, expr, degs_cache)
48121
end
49122
end
50123

51124
function monomial_lt(degs1, degs2)
52125
d1 = sum(last, degs1, init=0)
53126
d2 = sum(last, degs2, init=0)
54-
d1 != d2 ? d1 < d2 : lexlt(degs1, degs2)
127+
d1 != d2 ?
128+
# lower absolute degree first, or if equal, positive degree first
129+
(abs(d1) < abs(d2) || abs(d1) == abs(d2) && d1 > d2) :
130+
lexlt(degs1, degs2)
55131
end
56132

57133
function lexlt(degs1, degs2)
58-
for (a, b) in zip(degs1, degs2)
59-
if a[1] == b[1] && a[2] != b[2]
60-
return a[2] > b[2]
61-
elseif a[1] != b[1]
62-
return a < b
134+
for ((a_base, a_deg), (b_base, b_deg)) in zip(degs1, degs2)
135+
if a_base == b_base && a_deg != b_deg
136+
# same base, higher absolute degree first, positive degree first
137+
return abs(a_deg) > abs(b_deg) || abs(a_deg) == abs(b_deg) && a_deg > b_deg
138+
elseif a_base != b_base
139+
# lexicographic order for the base
140+
return a_base < b_base
63141
end
64142
end
65143
return false # they are equal

test/basics.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,12 @@ end
221221
b1, b3, d1, d2 = get(b,1),get(b,3), get(d,1), get(d,2)
222222
@test repr(a + b3 + b1 + d2 + c) == "a + b[1] + b[3] + c + d[2]"
223223
@test repr(expand((c + b3 - d1)^3)) == "b[3]^3 + 3(b[3]^2)*c - 3(b[3]^2)*d[1] + 3b[3]*(c^2) - 6b[3]*c*d[1] + 3b[3]*(d[1]^2) + c^3 - 3(c^2)*d[1] + 3c*(d[1]^2) - (d[1]^3)"
224+
# test negative powers sorting
225+
@test repr((b3^2)^(-2) + a^(-3) + (c*d1)^(-2)) == "1 / (a^3) + 1 / (b[3]^4) + 1 / ((c^2)*(d[1]^2))"
226+
227+
# test that the "x^2 + y^-1 + sin(a)^3.5 + 2t + 1//1" expression from Symbolics.jl/build_targets.jl is properly sorted
228+
@syms x1 y1 a1 t1
229+
@test repr(x1^2 + y1^-1 + sin(a1)^3.5 + 2t1 + 1//1) == "(1//1) + 2t1 + 1 / y1 + x1^2 + sin(a1)^3.5"
224230
end
225231

226232
@testset "inspect" begin

0 commit comments

Comments
 (0)