|
17 | 17 | """
|
18 | 18 | $(SIGNATURES)
|
19 | 19 |
|
20 |
| -Internal function used for printing symbolic expressions. This function determines |
21 |
| -the degrees of symbols within a given expression, implementing a variation on |
22 |
| -degree lexicographic order. |
| 20 | +Get the degrees of symbols within a given expression. |
| 21 | +
|
| 22 | +This internal function is used to define the order of terms in a symbolic expression, |
| 23 | +which is a variation on degree lexicographic order. It is used for printing and |
| 24 | +by [`sorted_arguments`](@ref). |
| 25 | +
|
| 26 | +Returns a tuple of degree and lexicographically sorted *multiplier* ⇒ *power* pairs, |
| 27 | +where the *multiplier* is a tuple of the symbol optionally followed by its indices. |
| 28 | +For a sum expression, returns the `get_degrees()` result for term with the highest degree. |
| 29 | +
|
| 30 | +See also `monomial_lt` and `lexlt`. |
23 | 31 | """
|
24 | 32 | function get_degrees(expr)
|
| 33 | + degs_cache = Dict() |
| 34 | + res = _get_degrees(expr, degs_cache) |
| 35 | + if res isa AbstractVector |
| 36 | + return Tuple(res) |
| 37 | + else |
| 38 | + return res |
| 39 | + end |
| 40 | +end |
| 41 | + |
| 42 | +function _get_degrees(expr, degs_cache::AbstractDict) |
25 | 43 | if issym(expr)
|
26 |
| - ((Symbol(expr),) => 1,) |
| 44 | + return get!(() -> ((Symbol(expr),) => 1,), degs_cache, expr) |
27 | 45 | elseif iscall(expr)
|
28 |
| - op = operation(expr) |
29 |
| - args = sorted_arguments(expr) |
30 |
| - if op == (^) && args[2] isa Number |
31 |
| - return map(get_degrees(args[1])) do (base, pow) |
32 |
| - (base => pow * args[2]) |
33 |
| - end |
34 |
| - elseif op == (*) |
35 |
| - return mapreduce(get_degrees, |
36 |
| - (x,y)->(x...,y...,), args) |
37 |
| - elseif op == (+) |
38 |
| - ds = map(get_degrees, args) |
39 |
| - _, idx = findmax(x->sum(last.(x), init=0), ds) |
40 |
| - return ds[idx] |
41 |
| - elseif op == (getindex) |
42 |
| - return ((Symbol.(args)...,) => 1,) |
43 |
| - else |
44 |
| - return ((Symbol("zzzzzzz", hash(expr)),) => 1,) |
| 46 | + # operation-specific degree handling |
| 47 | + return _get_degrees(operation(expr), expr, degs_cache) |
| 48 | + else |
| 49 | + return () # skip numbers and unsupported expressions |
| 50 | + end |
| 51 | +end |
| 52 | + |
| 53 | +# fallback for unsupported operation |
| 54 | +_get_degrees(::Any, expr, degs_cache) = |
| 55 | + ((Symbol("zzzzzzz", hash(expr)),) => 1,) |
| 56 | + |
| 57 | +_getindex_symbol(arr, i) = Symbol(arr[i]) |
| 58 | + |
| 59 | +function _get_degrees(::typeof(getindex), expr, degs_cache) |
| 60 | + args = arguments(expr) |
| 61 | + @inbounds return get!(() -> (ntuple(Base.Fix1(_getindex_symbol, args), length(args)) => 1,), |
| 62 | + degs_cache, expr) |
| 63 | +end |
| 64 | + |
| 65 | +function _get_degrees(::typeof(*), expr, degs_cache) |
| 66 | + args = arguments(expr) |
| 67 | + ds = sizehint!(Vector{Any}(), length(args)) |
| 68 | + for arg in args |
| 69 | + degs = _get_degrees(arg, degs_cache) |
| 70 | + append!(ds, degs) |
| 71 | + end |
| 72 | + return sort!(ds) |
| 73 | +end |
| 74 | + |
| 75 | +function _get_degrees(::typeof(+), expr, degs_cache) |
| 76 | + # among the terms find the best in terms of monomial_lt |
| 77 | + sel_degs = () |
| 78 | + sel_degsum = 0 |
| 79 | + for arg in arguments(expr) |
| 80 | + degs = _get_degrees(arg, degs_cache) |
| 81 | + degsum = sum(last, degs, init=0) |
| 82 | + if (sel_degs == ()) || (degsum > sel_degsum) || |
| 83 | + (degsum == sel_degsum && lexlt(degs, sel_degs)) |
| 84 | + sel_degs, sel_degsum = degs, degsum |
| 85 | + end |
| 86 | + end |
| 87 | + return sel_degs |
| 88 | +end |
| 89 | + |
| 90 | +function _get_degrees(::typeof(^), expr, degs_cache) |
| 91 | + base_expr, pow_expr = arguments(expr) |
| 92 | + if pow_expr isa Number |
| 93 | + @inbounds degs = map(_get_degrees(base_expr, degs_cache)) do (base, pow) |
| 94 | + (base => pow * pow_expr) |
| 95 | + end |
| 96 | + if pow_expr < 0 && length(degs) > 1 |
| 97 | + # fix the order after the powers were negated |
| 98 | + isa(degs, AbstractVector) || (degs = collect(degs)) |
| 99 | + sort!(degs) |
| 100 | + end |
| 101 | + return degs |
| 102 | + else |
| 103 | + # expression in the power argument is not supported |
| 104 | + return _get_degrees(nothing, expr, degs_cache) |
| 105 | + end |
| 106 | +end |
| 107 | + |
| 108 | +function _get_degrees(::typeof(/), expr, degs_cache) |
| 109 | + nom_expr, denom_expr = arguments(expr) |
| 110 | + if denom_expr isa Number # constant denominator |
| 111 | + return _get_degrees(nom_expr, degs_cache) |
| 112 | + elseif nom_expr isa Number # constant nominator |
| 113 | + @inbounds degs = map(_get_degrees(denom_expr, degs_cache)) do (base, pow) |
| 114 | + (base => -pow) |
45 | 115 | end
|
| 116 | + isa(degs, AbstractVector) || (degs = collect(degs)) |
| 117 | + return sort!(degs) |
46 | 118 | else
|
47 |
| - return () |
| 119 | + # TODO expressions in both nom and denom are not yet supported |
| 120 | + return _get_degrees(nothing, expr, degs_cache) |
48 | 121 | end
|
49 | 122 | end
|
50 | 123 |
|
51 | 124 | function monomial_lt(degs1, degs2)
|
52 | 125 | d1 = sum(last, degs1, init=0)
|
53 | 126 | d2 = sum(last, degs2, init=0)
|
54 |
| - d1 != d2 ? d1 < d2 : lexlt(degs1, degs2) |
| 127 | + d1 != d2 ? |
| 128 | + # lower absolute degree first, or if equal, positive degree first |
| 129 | + (abs(d1) < abs(d2) || abs(d1) == abs(d2) && d1 > d2) : |
| 130 | + lexlt(degs1, degs2) |
55 | 131 | end
|
56 | 132 |
|
57 | 133 | function lexlt(degs1, degs2)
|
58 |
| - for (a, b) in zip(degs1, degs2) |
59 |
| - if a[1] == b[1] && a[2] != b[2] |
60 |
| - return a[2] > b[2] |
61 |
| - elseif a[1] != b[1] |
62 |
| - return a < b |
| 134 | + for ((a_base, a_deg), (b_base, b_deg)) in zip(degs1, degs2) |
| 135 | + if a_base == b_base && a_deg != b_deg |
| 136 | + # same base, higher absolute degree first, positive degree first |
| 137 | + return abs(a_deg) > abs(b_deg) || abs(a_deg) == abs(b_deg) && a_deg > b_deg |
| 138 | + elseif a_base != b_base |
| 139 | + # lexicographic order for the base |
| 140 | + return a_base < b_base |
63 | 141 | end
|
64 | 142 | end
|
65 | 143 | return false # they are equal
|
|
0 commit comments