Skip to content

Commit 315a5c4

Browse files
Alexey Stukalovalyst
authored andcommitted
get_degrees(): refactor
* replace if op = ... with get_degrees(op, expr) for extendability * avoid sorted_arguments() within get_degrees() * reduce allocations by iterating of degs in efficient way * support /: const/expr and expr/const
1 parent bf0c83b commit 315a5c4

File tree

2 files changed

+75
-19
lines changed

2 files changed

+75
-19
lines changed

src/ordering.jl

Lines changed: 73 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,28 +31,82 @@ See also `monomial_lt` and `lexlt`.
3131
"""
3232
function get_degrees(expr)
3333
if issym(expr)
34-
((Symbol(expr),) => 1,)
34+
return ((Symbol(expr),) => 1,)
3535
elseif iscall(expr)
36-
op = operation(expr)
37-
args = sorted_arguments(expr)
38-
if op == (^) && args[2] isa Number
39-
return map(get_degrees(args[1])) do (base, pow)
40-
(base => pow * args[2])
41-
end
42-
elseif op == (*)
43-
return mapreduce(get_degrees,
44-
(x,y)->(x...,y...,), args)
45-
elseif op == (+)
46-
ds = map(get_degrees, args)
47-
_, idx = findmax(x->sum(last, x, init=0), ds)
48-
return ds[idx]
49-
elseif op == (getindex)
50-
return (Tuple(map(Symbol, args)) => 1,)
51-
else
52-
return ((Symbol("zzzzzzz", hash(expr)),) => 1,)
36+
# operation-specific degree handling
37+
return get_degrees(operation(expr), expr)
38+
else
39+
return () # skip numbers and unsupported expressions
40+
end
41+
end
42+
43+
# fallback for unsupported operation
44+
get_degrees(::Any, expr) =
45+
((Symbol("zzzzzzz", hash(expr)),) => 1,)
46+
47+
_getindex_symbol(arr, i) = Symbol(arr[i])
48+
49+
function get_degrees(::typeof(getindex), expr)
50+
args = arguments(expr)
51+
@inbounds return (ntuple(Base.Fix1(_getindex_symbol, args), length(args)) => 1,)
52+
end
53+
54+
function get_degrees(::typeof(*), expr)
55+
args = arguments(expr)
56+
ds = sizehint!(Vector{Any}(), length(args))
57+
for arg in args
58+
degs = get_degrees(arg)
59+
append!(ds, degs)
60+
end
61+
return sort!(ds)
62+
end
63+
64+
function get_degrees(::typeof(+), expr)
65+
# among the terms find the best in terms of monomial_lt
66+
sel_degs = ()
67+
sel_degsum = 0
68+
for arg in arguments(expr)
69+
degs = get_degrees(arg)
70+
degsum = sum(last, degs, init=0)
71+
if (sel_degs == ()) || (degsum > sel_degsum) ||
72+
(degsum == sel_degsum && lexlt(degs, sel_degs))
73+
sel_degs, sel_degsum = degs, degsum
74+
end
75+
end
76+
return sel_degs
77+
end
78+
79+
function get_degrees(::typeof(^), expr)
80+
base_expr, pow_expr = arguments(expr)
81+
if pow_expr isa Number
82+
@inbounds degs = map(get_degrees(base_expr)) do (base, pow)
83+
(base => pow * pow_expr)
84+
end
85+
if pow_expr < 0 && length(degs) > 1
86+
# fix the order after the powers were negated
87+
isa(degs, AbstractVector) || (degs = collect(degs))
88+
sort!(degs)
89+
end
90+
return degs
91+
else
92+
# expression in the power argument is not supported
93+
return get_degrees(nothing, expr)
94+
end
95+
end
96+
97+
function get_degrees(::typeof(/), expr)
98+
nom_expr, denom_expr = arguments(expr)
99+
if denom_expr isa Number # constant denominator
100+
return get_degrees(nom_expr)
101+
elseif nom_expr isa Number # constant nominator
102+
@inbounds degs = map(get_degrees(denom_expr)) do (base, pow)
103+
(base => -pow)
53104
end
105+
isa(degs, AbstractVector) || (degs = collect(degs))
106+
return sort!(degs)
54107
else
55-
return ()
108+
# TODO expressions in both nom and denom are not yet supported
109+
return get_degrees(nothing, expr)
56110
end
57111
end
58112

test/basics.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ end
221221
b1, b3, d1, d2 = get(b,1),get(b,3), get(d,1), get(d,2)
222222
@test repr(a + b3 + b1 + d2 + c) == "a + b[1] + b[3] + c + d[2]"
223223
@test repr(expand((c + b3 - d1)^3)) == "b[3]^3 + 3(b[3]^2)*c - 3(b[3]^2)*d[1] + 3b[3]*(c^2) - 6b[3]*c*d[1] + 3b[3]*(d[1]^2) + c^3 - 3(c^2)*d[1] + 3c*(d[1]^2) - (d[1]^3)"
224+
# test negative powers sorting
225+
@test repr((b3^2)^(-2) + a^(-3) + (c*d1)^(-2)) == "1 / (b[3]^4) + 1 / ((c^2)*(d[1]^2)) + 1 / (a^3)"
224226
end
225227

226228
@testset "inspect" begin

0 commit comments

Comments
 (0)