diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c86d510cc..951d10cc6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,6 +20,7 @@ jobs: version: - 'min' - '1' + fail-fast: false steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/Project.toml b/Project.toml index 8cca44eaf..1be23be55 100644 --- a/Project.toml +++ b/Project.toml @@ -13,11 +13,16 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" +EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" ExproniconLite = "55351af7-c7e9-48d6-89ff-24e801d99491" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83" +ReadOnlyDicts = "795d4caa-f5a7-4580-b5d8-c01d53451803" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -26,7 +31,7 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34" TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415" +WeakCacheSets = "d30d5f5c-d141-4870-aa07-aabb0f5fe7d5" WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1" [weakdeps] @@ -47,11 +52,16 @@ ConstructionBase = "1.5.7" DataStructures = "0.18" DocStringExtensions = "0.8, 0.9" DynamicPolynomials = "0.5, 0.6" +EnumX = "1.0.5" ExproniconLite = "0.10.14" LabelledArrays = "1.5" +MacroTools = "0.5.16" +Moshi = "0.3.6" MultivariatePolynomials = "0.5" NaNMath = "0.3, 1.1.2" OhMyThreads = "0.7" +ReadOnlyArrays = "0.2.0" +ReadOnlyDicts = "1.0.0" ReverseDiff = "1" RuntimeGeneratedFunctions = "0.5.13" Setfield = "0.7, 0.8, 1" @@ -61,7 +71,6 @@ SymbolicIndexingInterface = "0.3" TaskLocalValues = "0.1.2" TermInterface = "2.0" TimerOutputs = "0.5" -Unityper = "0.1.2" WeakValueDicts = "0.1.0" julia = "1.10" @@ -82,3 +91,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "ReverseDiff", "SafeTestsets", "Test", "Zygote", "OhMyThreads", "RuntimeGeneratedFunctions"] + +[sources] +WeakCacheSets = {url="https://github.com/JuliaCollections/WeakCacheSets.jl"} +Moshi = {url="https://github.com/AayushSabharwal/Moshi.jl", rev="as/mutable-adt"} diff --git a/bench.jl b/bench.jl new file mode 100644 index 000000000..21488ca5d --- /dev/null +++ b/bench.jl @@ -0,0 +1,9 @@ +using SymbolicUtils, BenchmarkTools + +@syms a b c d e f g h i +ex = (f + ((((g*(c^2)*(e^2)) / d - e*h*(c^2)) / b + (-c*e*f*g) / d + c*e*i) / + (i + ((c*e*g) / d - c*h) / b + (-f*g) / d) - c*e) / b + + ((g*(f^2)) / d + ((-c*e*f*g) / d + c*f*h) / b - f*i) / + (i + ((c*e*g) / d - c*h) / b + (-f*g) / d)) / d + +@benchmark SymbolicUtils.fraction_iszero($ex) diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 81648c0b1..3bc0e616c 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -1 +1,5 @@ [deps] + +[sources] +WeakCacheSets = {url="https://github.com/JuliaCollections/WeakCacheSets.jl"} +Moshi = {url="https://github.com/AayushSabharwal/Moshi.jl", rev="as/mutable-adt"} diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index eae631897..18ba8bac0 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -94,6 +94,8 @@ let (-f*(g + (-d*g) / d)) / (i + (-c*(h + (-e*g) / d)) / b + (-f*g) / d)) / d pform["simplify_fractions"] = @benchmarkable simplify_fractions($ex) pform["iszero"] = @benchmarkable SymbolicUtils.fraction_iszero($ex) - pform["isone"] = @benchmarkable SymbolicUtils.fraction_isone($o) + pform["isone"] = @benchmarkable SymbolicUtils.fraction_isone($ex) + pform["isone:noop"] = @benchmarkable SymbolicUtils.fraction_isone($o) + pform["iszero:noop"] = @benchmarkable SymbolicUtils.fraction_iszero($o) pform["easy_iszero"] = @benchmarkable SymbolicUtils.fraction_iszero($((b*(h + (-e*g) / d)) / b + (e*g) / d - h)) end diff --git a/docs/src/manual/rewrite.md b/docs/src/manual/rewrite.md index 43ee94247..e525cf509 100644 --- a/docs/src/manual/rewrite.md +++ b/docs/src/manual/rewrite.md @@ -20,7 +20,7 @@ r1 = @rule sin(2(~x)) => 2sin(~x)*cos(~x) r1(sin(2z)) # output -2sin(z)*cos(z) +2cos(z)*sin(z) ``` The `@rule` macro takes a pair of patterns -- the _matcher_ and the _consequent_ (`@rule matcher => consequent`). If an expression matches the matcher pattern, it is rewritten to the consequent pattern. `@rule` returns a callable object that applies the rule to an expression. @@ -41,7 +41,7 @@ Slot variable (matcher) is not necessary a single variable r1(sin(2*(w-z))) # output -2cos(w - z)*sin(w - z) +2sin(w - z)*cos(w - z) ``` but it must be a single expression @@ -61,7 +61,7 @@ r2 = @rule sin(~x + ~y) => sin(~x)*cos(~y) + cos(~x)*sin(~y); r2(sin(α+β)) # output -sin(β)*cos(α) + cos(β)*sin(α) +cos(β)*sin(α) + sin(β)*cos(α) ``` If you want to match a variable number of subexpressions at once, you will need a **segment variable**. `~~xs` in the following example is a segment variable: @@ -71,10 +71,10 @@ If you want to match a variable number of subexpressions at once, you will need @rule(+(~~xs) => ~~xs)(x + y + z) # output -3-element view(::SymbolicUtils.SmallVec{Any, Vector{Any}}, 1:3) with eltype Any: - z - y +3-element view(::ReadOnlyArrays.ReadOnlyVector{Any, SymbolicUtils.SmallVec{Any, Vector{Any}}}, 1:3) with eltype Any: x + y + z ``` `~~xs` is a vector of subexpressions matched. You can use it to construct something more useful: diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 3c930f66a..c065c9336 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -7,7 +7,12 @@ using DocStringExtensions export @syms, term, showraw, hasmetadata, getmetadata, setmetadata -using Unityper +using Moshi.Data: @data +import Moshi.Data as MData +using Moshi.Match: @match +using ReadOnlyArrays +using ReadOnlyDicts +using EnumX: @enumx using TermInterface using DataStructures using Setfield @@ -23,6 +28,9 @@ import ArrayInterface import ExproniconLite as EL import TaskLocalValues: TaskLocalValue using WeakValueDicts: WeakValueDict +using WeakCacheSets: WeakCacheSet, getkey! +using Base: RefValue +import MacroTools # include("WeakCacheSets.jl") diff --git a/src/cache.jl b/src/cache.jl index d65ada029..b1f0b6c37 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -27,7 +27,7 @@ The key stored in the cache for a particular value. Returns a `SymbolicKey` for # can't dispatch because `BasicSymbolic` isn't defined here function get_cache_key(x) if x isa BasicSymbolic - id = x.id[] + id = x.id if id === nothing return CacheSentinel() end diff --git a/src/code.jl b/src/code.jl index 0550e3cde..0239861e3 100644 --- a/src/code.jl +++ b/src/code.jl @@ -784,7 +784,7 @@ struct CSEState """ A mapping of symbolic expression to the LHS in `sorted_exprs` that computes it. """ - visited::IdDict{Any, Any} + visited::IdDict{Union{SymbolicUtils.IDType, AbstractArray, Tuple}, BasicSymbolic} """ Integer counter, used to generate unique names for intermediate variables. """ @@ -870,8 +870,9 @@ function cse! end indextype(::AbstractSparseArray{Tv, Ti}) where {Tv, Ti} = Ti -function cse!(expr::Symbolic, state::CSEState) - get!(state.visited, expr) do + +function cse!(expr::BasicSymbolic, state::CSEState) + get!(state.visited, expr.id) do iscall(expr) || return expr op = operation(expr) diff --git a/src/inspect.jl b/src/inspect.jl index ab3951725..224e2d57b 100644 --- a/src/inspect.jl +++ b/src/inspect.jl @@ -5,19 +5,20 @@ function AbstractTrees.nodevalue(x::Symbolic) iscall(x) ? operation(x) : isexpr(x) ? head(x) : x end -function AbstractTrees.nodevalue(x::BasicSymbolic) +function AbstractTrees.nodevalue(x::BSImpl.Type) + T = nameof(MData.variant_type(x)) str = if !iscall(x) - string(exprtype(x), "(", x, ")") + string(T, "(", x, ")") elseif isadd(x) - string(exprtype(x), - (scalar=x.coeff, coeffs=Tuple(k=>v for (k,v) in x.dict))) + string(T, + (variant=string(x.variant), scalar=x.coeff, coeffs=Tuple(k=>v for (k,v) in x.dict))) elseif ismul(x) - string(exprtype(x), - (scalar=x.coeff, powers=Tuple(k=>v for (k,v) in x.dict))) + string(T, + (variant=string(x.variant), scalar=x.coeff, powers=Tuple(k=>v for (k,v) in x.dict))) elseif isdiv(x) || ispow(x) - string(exprtype(x)) + string(T) else - string(exprtype(x),"{", operation(x), "}") + string(T,"{", operation(x), "}") end if inspect_metadata[] && !isnothing(metadata(x)) diff --git a/src/matchers.jl b/src/matchers.jl index 99b76ea18..3371a993a 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -10,7 +10,8 @@ function matcher(val::Any) # if val is a call (like an operation) creates a term matcher or term matcher with defslot if iscall(val) # if has two arguments and one of them is a DefSlot, create a term matcher with defslot - if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val)) + args = parent(arguments(val)) + if length(args) == 2 && any(x -> isa(x, DefSlot), args) return defslot_term_matcher_constructor(val) # else return a normal term matcher else @@ -35,7 +36,9 @@ function matcher(slot::Slot) end # elseif the first element of data matches the slot predicate, add it to bindings and call next elseif slot.predicate(car(data)) - next(assoc(bindings, slot.name, car(data)), 1) + rest = car(data) + binds = assoc(bindings, slot.name, rest) + next(binds, 1) end end end @@ -104,32 +107,34 @@ function matcher(segment::Segment) end function term_matcher_constructor(term) - matchers = (matcher(operation(term)), map(matcher, arguments(term))...,) + matchers = vcat([matcher(operation(term))], map(matcher, parent(arguments(term)))) - function term_matcher(success, data, bindings) - !islist(data) && return nothing # if data is not a list, return nothing - !iscall(car(data)) && return nothing # if first element is not a call, return nothing + let matchers = matchers + function term_matcher(success, data, bindings) + !islist(data) && return nothing # if data is not a list, return nothing + !iscall(car(data)) && return nothing # if first element is not a call, return nothing - function loop(term, bindings′, matchers′) # Get it to compile faster - if !islist(matchers′) - if !islist(term) - return success(bindings′, 1) + function loop(term, bindings′, matchers′) # Get it to compile faster + if !islist(matchers′) + if !islist(term) + return success(bindings′, 1) + end + return nothing end - return nothing - end - car(matchers′)(term, bindings′) do b, n - loop(drop_n(term, n), b, cdr(matchers′)) + car(matchers′)(term, bindings′) do b, n + loop(drop_n(term, n), b, cdr(matchers′)) + end + # explenation of above 3 lines: + # car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′) + # <------ next(b,n) ----------------------------> + # car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list + # Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term + # Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses + # the length of the list, is considered empty end - # explenation of above 3 lines: - # car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′) - # <------ next(b,n) ----------------------------> - # car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list - # Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term - # Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses - # the length of the list, is considered empty - end - loop(car(data), bindings, matchers) # Try to eat exactly one term + loop(car(data), bindings, matchers) # Try to eat exactly one term + end end end @@ -146,7 +151,7 @@ end # calls the success function like term_matcher would do function defslot_term_matcher_constructor(term) - a = arguments(term) # lenght two bc defslot term matcher is allowed only with +,* and ^, that accept two arguments + a = parent(arguments(term)) # lenght two bc defslot term matcher is allowed only with +,* and ^, that accept two arguments matchers = (matcher(operation(term)), map(matcher, a)...) # create matchers for the operation and the two arguments of the term defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term diff --git a/src/methods.jl b/src/methods.jl index 462cc1552..70f9613f6 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -25,16 +25,6 @@ const previously_declared_for = Set([]) const basic_monadic = [-, +] const basic_diadic = [+, -, *, /, //, \, ^] -#################### SafeReal ######################### -export SafeReal, LiteralReal - -# ideally the relationship should be the other way around -abstract type SafeReal <: Real end - -################### LiteralReal ####################### - -abstract type LiteralReal <: Real end - ####################################################### assert_like(f, T) = nothing @@ -101,13 +91,25 @@ macro number_methods(T, rhs1, rhs2, options=nothing) end @number_methods(BasicSymbolic{<:Number}, term(f, a), term(f, a, b), skipbasics) -@number_methods(BasicSymbolic{<:LiteralReal}, term(f, a), term(f, a, b), onlybasics) +@number_methods(BasicSymbolic{LiteralReal}, term(f, a), term(f, a, b), onlybasics) for f in vcat(diadic, [+, -, *, \, /, ^]) @eval promote_symtype(::$(typeof(f)), T::Type{<:Number}, S::Type{<:Number}) = promote_type(T, S) - for R in [SafeReal, LiteralReal] + @eval promote_symtype(::$(typeof(f)), + T::Type{<:Rational}, + S::Type{Integer}) = Rational + @eval promote_symtype(::$(typeof(f)), + T::Type{Integer}, + S::Type{<:Rational}) = Rational + @eval promote_symtype(::$(typeof(f)), + T::Type{<:Complex{<:Rational}}, + S::Type{Integer}) = Complex{Rational} + @eval promote_symtype(::$(typeof(f)), + T::Type{Integer}, + S::Type{<:Complex{<:Rational}}) = Complex{Rational} + for R in [SafeRealImpl, LiteralRealImpl] @eval function promote_symtype(::$(typeof(f)), T::Type{<:$R}, S::Type{<:Real}) @@ -153,8 +155,8 @@ end promote_symtype(::Any, T) = promote_type(T, Real) for f in monadic @eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = promote_type(T, Real) - @eval promote_symtype(::$(typeof(f)), T::Type{<:SafeReal}) = SafeReal - @eval promote_symtype(::$(typeof(f)), T::Type{<:LiteralReal}) = LiteralReal + @eval promote_symtype(::$(typeof(f)), T::Type{<:SafeRealImpl}) = SafeReal + @eval promote_symtype(::$(typeof(f)), T::Type{<:LiteralRealImpl}) = LiteralReal end Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a) diff --git a/src/ordering.jl b/src/ordering.jl index 1b72d1cb2..8de132082 100644 --- a/src/ordering.jl +++ b/src/ordering.jl @@ -89,7 +89,7 @@ end function _get_degrees(::typeof(^), expr, degs_cache) base_expr, pow_expr = arguments(expr) - if pow_expr isa Number + if pow_expr isa Real @inbounds degs = map(_get_degrees(base_expr, degs_cache)) do (base, pow) (base => pow * pow_expr) end diff --git a/src/polyform.jl b/src/polyform.jl index 7d6bc906e..b095df043 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -26,19 +26,19 @@ PolyForm(sin((x+y)^2)) #=> sin((x+y)^2) PolyForm(sin((x+y)^2), recurse=true) #=> sin((x^2 + (2x)y + y^2)) ``` """ -struct PolyForm{T} <: Symbolic{T} +struct PolyForm <: Symbolic{Real} p::MP.AbstractPolynomialLike pvar2sym::Bijection{Any,Any} # @polyvar x --> @sym x etc. sym2term::Dict{BasicSymbolic,Any} # Symbol("sin-$hash(sin(x+y))") --> sin(x+y) => sin(PolyForm(...)) metadata - function (::Type{PolyForm{T}})(p, d1, d2, m=nothing) where {T} + function PolyForm(p, d1, d2, m=nothing) p isa Number && return p p isa MP.AbstractPolynomialLike && MP.isconstant(p) && return convert(Number, p) - new{T}(p, d1, d2, m) + new(p, d1, d2, m) end end -@number_methods(PolyForm{<:Number}, term(f, a), term(f, a, b)) +@number_methods(PolyForm, term(f, a), term(f, a, b)) Base.hash(p::PolyForm, u::UInt64) = xor(hash(p.p, u), trunc(UInt, 0xbabacacababacaca)) Base.isequal(x::PolyForm, y::PolyForm) = isequal(x.p, y.p) @@ -56,7 +56,7 @@ function get_pvar2sym() PVAR2SYM[] = WeakRef(d) return d else - return v + return v::Bijections.Bijection{Any, Any, Dict{Any, Any}, Dict{Any, Any}} end end @@ -67,7 +67,7 @@ function get_sym2term() SYM2TERM[] = WeakRef(d) return d else - return v + return v::Dict{BasicSymbolic, Any} end end @@ -80,7 +80,7 @@ end # forward gcd -PF = :(PolyForm{promote_symtype(/, symtype(x), symtype(y))}) +PF = :(PolyForm) const FriendlyCoeffType = Union{Integer, Rational} @eval begin Base.div(x::PolyForm, y::PolyForm) = $PF(div(x.p, y.p), mix_dicts(x, y)...) @@ -94,6 +94,9 @@ end _isone(p::PolyForm) = isone(p.p) +maybe_float(::Type{T}, x) where {T <: Integer} = x +maybe_float(::Type, x) = x isa Number && !(x isa Rational) ? float(x) : x + function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) if x isa Number return x @@ -103,15 +106,16 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) end op = operation(x) - args = arguments(x) - - local_polyize(y) = polyize(y, pvar2sym, sym2term, vtype, pow, Fs, recurse) + args = parent(arguments(x)) - if typeof(+) <: Fs && op == (+) + local_polyize = let pvar2sym = pvar2sym, sym2term = sym2term, vtype = vtype, pow = pow, Fs = Fs, recurse = recurse, T = symtype(x) + f(y) = maybe_float(T, polyize(y, pvar2sym, sym2term, vtype, pow, Fs, recurse)) + end + if (+) isa Fs && op === (+) return sum(local_polyize, args) - elseif typeof(*) <: Fs && op == (*) + elseif (*) isa Fs && op === (*) return prod(local_polyize, args) - elseif typeof(^) <: Fs && op == (^) && args[2] isa Integer && args[2] > 0 + elseif (^) isa Fs && op === (^) && args[2] isa Integer && args[2] > 0 @assert length(args) == 2 return local_polyize(args[1])^(args[2]) else @@ -120,7 +124,7 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) y = if recurse maketerm(typeof(x), op, - map(a->PolyForm(a, pvar2sym, sym2term, vtype; Fs, recurse), args), + map(a->PolyForm(a; pvar2sym, sym2term, vtype, Fs, recurse), args), metadata(x)) else x @@ -129,9 +133,9 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) name = Symbol(string(op), "_", hash(y)) @label lookup - sym = Sym{symtype(x)}(name) + sym = Sym{Number}(name) if haskey(sym2term, sym) - if isequal(sym2term[sym][1], x) + if isequal(sym2term[sym][1], x)::Bool return local_polyize(sym) else # hash collision name = Symbol(name, "_") @@ -153,10 +157,10 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) end end -function PolyForm(x, +function PolyForm(x; pvar2sym=get_pvar2sym(), sym2term=get_sym2term(), - vtype=DynamicPolynomials.Variable{ DynamicPolynomials.Commutative{DynamicPolynomials.CreationOrder},DynamicPolynomials.Graded{MP.LexOrder}}; + vtype=DynamicPolynomials.Variable{DynamicPolynomials.Commutative{DynamicPolynomials.CreationOrder}, DynamicPolynomials.Graded{MP.LexOrder}}, Fs = Union{typeof(+), typeof(*), typeof(^)}, recurse=false, metadata=metadata(x)) @@ -167,7 +171,7 @@ function PolyForm(x, # Polyize and return a PolyForm p = polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) - PolyForm{symtype(x)}(p, pvar2sym, sym2term, metadata) + PolyForm(p, pvar2sym, sym2term, metadata) end isexpr(x::Type{<:PolyForm}) = true @@ -186,7 +190,7 @@ end head(::PolyForm) = PolyForm operation(x::PolyForm) = MP.nterms(x.p) == 1 ? (*) : (+) -function TermInterface.arguments(x::PolyForm{T}) where {T} +function TermInterface.arguments(x::PolyForm) function is_var(v) MP.nterms(v) == 1 && @@ -213,10 +217,10 @@ function TermInterface.arguments(x::PolyForm{T}) where {T} m = MP.monomial(t) if !isone(c) - [c, (unstable_pow(resolve(v), pow) + [c, (^(resolve(v), pow) for (v, pow) in MP.powers(m) if !iszero(pow))...] else - [unstable_pow(resolve(v), pow) + [^(resolve(v), pow) for (v, pow) in MP.powers(m) if !iszero(pow)] end elseif MP.nterms(x.p) == 0 @@ -227,7 +231,7 @@ function TermInterface.arguments(x::PolyForm{T}) where {T} convert(Number, t) : (is_var(t) ? resolve(t) : - PolyForm{T}(t, x.pvar2sym, x.sym2term, nothing)) for t in ts] + PolyForm(t, x.pvar2sym, x.sym2term, nothing)) for t in ts] end end children(x::PolyForm) = arguments(x) @@ -245,7 +249,7 @@ Expand expressions by distributing multiplication over addition, e.g., multivariate polynomials implementation. `variable_type` can be any subtype of `MultivariatePolynomials.AbstractVariable`. """ -expand(expr) = unpolyize(PolyForm(expr, Fs=Union{typeof(+), typeof(*), typeof(^)}, recurse=true)) +expand(expr) = unpolyize(PolyForm(expr; Fs=Union{typeof(+), typeof(*), typeof(^)}, recurse=true)) function unpolyize(x) # we need a special maketerm here because the default one used in Postwalk will call @@ -265,9 +269,9 @@ function polyform_factors(d, pvar2sym, sym2term) if ispow(x) && x.exp isa Integer && x.exp > 0 # here we do want to recurse one level, that's why it's wrong to just # use Fs = Union{typeof(+), typeof(*)} here. - Pow(PolyForm(x.base, pvar2sym, sym2term), x.exp) + Pow(PolyForm(x.base; pvar2sym, sym2term), x.exp) else - PolyForm(x, pvar2sym, sym2term) + PolyForm(x; pvar2sym, sym2term) end end @@ -283,23 +287,7 @@ function simplify_div(d) if all(_isone, ds) return isempty(ns) ? 1 : simplify_fractions(_mul(ns)) else - Div(simplify_fractions(_mul(ns)), simplify_fractions(_mul(ds))) - end -end - -#add_divs(x::Div, y::Div) = (x.num * y.den + y.num * x.den) / (x.den * y.den) -#add_divs(x::Div, y) = (x.num + y * x.den) / x.den -#add_divs(x, y::Div) = (x * y.den + y.num) / y.den -#add_divs(x, y) = x + y -function add_divs(x, y) - if isdiv(x) && isdiv(y) - return (x.num * y.den + y.num * x.den) / (x.den * y.den) - elseif isdiv(x) - return (x.num + y * x.den) / x.den - elseif isdiv(y) - return (x * y.den + y.num) / y.den - else - x + y + Div(simplify_fractions(_mul(ns)), simplify_fractions(_mul(ds)), false) end end @@ -344,14 +332,42 @@ end function add_with_div(x, flatten=true) (!iscall(x) || operation(x) != (+)) && return x - aa = arguments(x) - !any(a->isdiv(a), aa) && return x # no rewrite necessary - - divs = filter(a->isdiv(a), aa) - nondivs = filter(a->!(isdiv(a)), aa) - nds = isempty(nondivs) ? 0 : +(nondivs...) - d = reduce(quick_cancel∘add_divs, divs) - flatten ? quick_cancel(add_divs(d, nds)) : d + nds + aa = parent(arguments(x)) + !any(isdiv, aa) && return x # no rewrite necessary + + # find and multiply all denominators + dens = ArgsT() + for a in aa + isdiv(a) || continue + push!(dens, a.den) + end + den = mul_worker(dens) + + # add all numerators + div_idx = 1 + nums = ArgsT() + for a in aa + # if it is a division, we don't want to multiply the numerator by + # its own denominator, so temporarily overwrite the index in `dens` + # that is the denominator of this term (tracked by `div_idx`), multiply + # and voila! numerator. Remember to reset `dens` at the end. + if isdiv(a) + _den = dens[div_idx] + dens[div_idx] = a.num + _num = mul_worker(dens) + dens[div_idx] = _den + div_idx += 1 + else + _num = den * a + end + push!(nums, _num) + end + num = add_worker(nums) + + if flatten + num, den = quick_cancel(num, den) + end + return num / den end """ flatten_fractions(x) @@ -369,7 +385,10 @@ end function fraction_iszero(x) !iscall(x) && return _iszero(x) + old_hc = ENABLE_HASHCONSING[] + ENABLE_HASHCONSING[] = false ff = flatten_fractions(x) + ENABLE_HASHCONSING[] = old_hc # fast path and then slow path any(_iszero, numerators(ff)) || any(_iszero∘expand, numerators(ff)) @@ -412,16 +431,23 @@ it wouldn't simplify `(x^2 + 15 - 8x) / (x - 5)` to `(x - 3)`. But it will simplify `(x - 5)^2*(x - 3) / (x - 5)` to `(x - 5)*(x - 3)`. Has optimized processes for `Mul` and `Pow` terms. """ -function quick_cancel(d) - if ispow(d) && isdiv(d.base) - return quick_cancel((d.base.num^d.exp) / (d.base.den^d.exp)) - elseif ismul(d) && any(isdiv, arguments(d)) - return prod(arguments(d)) - elseif isdiv(d) - num, den = quick_cancel(d.num, d.den) - return Div(num, den) - else - return d +quick_cancel(d) = d +function quick_cancel(d::BSImpl.Type{T}) where {T} + @match d begin + BSImpl.Pow(; base, exp) => begin + base isa BSImpl.Type || return d + MData.isa_variant(base, BSImpl.Div) || return d + n, d = quick_cancel((base.num ^ exp), (base.den ^ exp)) + return Div{T}(n, d, false) + end + BSImpl.AddOrMul(; variant) && if variant == AddMulVariant.MUL && any(isdiv, arguments(d)) end => begin + return mul_worker(arguments(d)) + end + BSImpl.Div(; num, den) => begin + num, den = quick_cancel(num, den) + return Div(num, den, false) + end + _ => return d end end @@ -480,7 +506,7 @@ function quick_mul(x, y) error("Can't reach") end - return Mul(symtype(x), x.coeff, d), 1 + return Mul{symtype(x)}(x.coeff, d), 1 else return x, y end @@ -501,7 +527,7 @@ function quick_mulpow(x, y) den = Pow{symtype(y)}(y.base, y.exp-d[y.base]) delete!(d, y.base) end - return Mul(symtype(x), x.coeff, d), den + return Mul{symtype(x)}(x.coeff, d), den else return x, y end @@ -510,7 +536,7 @@ end # Double mul case function quick_mulmul(x, y) num_dict, den_dict = _merge_div(x.dict, y.dict) - Mul(symtype(x), x.coeff, num_dict), Mul(symtype(y), y.coeff, den_dict) + Mul{symtype(x)}(x.coeff, num_dict), Mul{symtype(y)}(y.coeff, den_dict) end function _merge_div(ndict, ddict) diff --git a/src/rewriters.jl b/src/rewriters.jl index 78efc7ee8..bf16262ff 100644 --- a/src/rewriters.jl +++ b/src/rewriters.jl @@ -30,7 +30,6 @@ rewriters. """ module Rewriters -using SymbolicUtils: @timer using TermInterface import SymbolicUtils: iscall, operation, arguments, sorted_arguments, metadata, node_count, _promote_symtype @@ -61,15 +60,15 @@ end If(f, x) = IfElse(f, x, Empty()) -struct Chain - rws - stop_on_match::Bool +mutable struct Chain{Cs} + const rws::Cs + const stop_on_match::Bool end Chain(rws) = Chain(rws, false) function (rw::Chain)(x) for f in rw.rws - y = @timer cached_repr(f) f(x) + y = f(x) if rw.stop_on_match && !isnothing(y) && !isequal(y, x) return y end @@ -79,8 +78,24 @@ function (rw::Chain)(x) end end return x +end +@generated function (rw::Chain{<:NTuple{N, Any}})(x) where {N} + quote + Base.@nexprs $N i -> begin + f = rw.rws[i] + y = f(x) + if rw.stop_on_match && y !== nothing && !isequal(x, y) + return y + end + if y !== nothing + x = y + end + end + return x + end end + instrument(c::Chain, f) = Chain(map(x->instrument(x,f), c.rws)) struct RestartedChain{Cs} @@ -91,7 +106,7 @@ instrument(c::RestartedChain, f) = RestartedChain(map(x->instrument(x,f), c.rws) function (rw::RestartedChain)(x) for f in rw.rws - y = @timer cached_repr(f) f(x) + y = f(x) if y !== nothing return Chain(rw.rws)(y) end @@ -103,7 +118,7 @@ end quote Base.@nexprs $N i->begin let f = rw.rws[i] - y = @timer cached_repr(repr(f)) f(x) + y = f(x) if y !== nothing return Chain(rw.rws)(y) end @@ -122,11 +137,11 @@ instrument(x::Fixpoint, f) = Fixpoint(instrument(x.rw, f)) function (rw::Fixpoint)(x) f = rw.rw - y = @timer cached_repr(f) f(x) - while x !== y && !isequal(x, y) + y = f(x) + while (x !== y) && !isequal(x, y) y === nothing && return x x = y - y = @timer cached_repr(f) f(x) + y = f(x) end return x end @@ -150,7 +165,7 @@ instrument(x::FixpointNoCycle, f) = Fixpoint(instrument(x.rw, f)) function (rw::FixpointNoCycle)(x) f = rw.rw push!(rw.hist, hash(x)) - y = @timer cached_repr(f) f(x) + y = f(x) while x !== y && hash(x) ∉ rw.hist if y === nothing empty!(rw.hist) @@ -158,7 +173,7 @@ function (rw::FixpointNoCycle)(x) end push!(rw.hist, y) x = y - y = @timer cached_repr(f) f(x) + y = f(x) end empty!(rw.hist) return x @@ -205,7 +220,7 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F} if iscall(x) x = p.maketerm(typeof(x), operation(x), map(PassThrough(p), - arguments(x)), metadata(x)) + parent(arguments(x))), metadata(x)) end return ord === :post ? p.rw(x) : x @@ -221,14 +236,14 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F} x = p.rw(x) end if iscall(x) - _args = map(arguments(x)) do arg + _args = map(parent(arguments(x))) do arg if node_count(arg) > p.thread_cutoff Threads.@spawn p(arg) else p(arg) end end - args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x)) + args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, parent(arguments(x))) t = p.maketerm(typeof(x), operation(x), args, metadata(x)) end return ord === :post ? p.rw(t) : t diff --git a/src/rule.jl b/src/rule.jl index d20598b2c..ead70d95e 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -204,7 +204,7 @@ function (r::Rule)(term) try # n == 1 means that exactly one term of the input (term,) was matched - success(bindings, n) = n == 1 ? (@timer "RHS" rhs(assoc(bindings, :MATCH, term))) : nothing + success(bindings, n) = n == 1 ? (rhs(assoc(bindings, :MATCH, term))) : nothing return r.matcher(success, (term,), EMPTY_IMMUTABLE_DICT) catch err throw(RuleRewriteError(r, term)) @@ -465,15 +465,37 @@ function (acr::ACRule)(term) T = symtype(term) args = arguments(term) + is_full_perm = acr.arity == length(args) + if is_full_perm + args_buf = copy(parent(args)) + else + args_buf = ArgsT(@view args[1:acr.arity]) + end itr = acr.sets(eachindex(args), acr.arity) for inds in itr - result = r(Term{T}(f, @views args[inds])) + for (i, ind) in enumerate(inds) + args_buf[i] = args[ind] + end + # this is temporary and only constructed so the rule can + # try and match it - no need to hashcons it. + tempterm = BSImpl.Term{T}(f, args_buf; unsafe = true) + # this term will be hashconsed regardless + result = r(tempterm) if result !== nothing # Assumption: inds are unique - length(args) == length(inds) && return result - return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], metadata(term)) + is_full_perm && return result + inds_set = BitSet(inds) + full_args_buf = ArgsT(@view args[1:(length(args)-acr.arity+1)]) + idx = 1 + for i in eachindex(args) + i in inds_set && continue + full_args_buf[idx] = args[i] + idx += 1 + end + full_args_buf[idx] = result + return maketerm(typeof(term), f, full_args_buf, metadata(term)) end end end @@ -493,57 +515,4 @@ getdepth(::Any) = typemax(Int) print(io, msg) end -function timerewrite(f) - if !TIMER_OUTPUTS - error("timerewrite must be called after enabling " * - "TIMER_OUTPUTS in the main file of this package") - end - reset_timer!() - being_timed[] = true - x = f() - being_timed[] = false - print_timer() - println() - x -end - - -""" - @timerewrite expr - -If `expr` calls `simplify` or a `RuleSet` object, track the amount of time -it spent on applying each rule and pretty print the timing. - -This uses [TimerOutputs.jl](https://github.com/KristofferC/TimerOutputs.jl). - -## Example: - -```julia - -julia> expr = foldr(*, rand([a,b,c,d], 100)) -(a ^ 26) * (b ^ 30) * (c ^ 16) * (d ^ 28) - -julia> @timerewrite simplify(expr) - ──────────────────────────────────────────────────────────────────────────────────────────────── - Time Allocations - ────────────────────── ─────────────────────── - Tot / % measured: 340ms / 15.3% 92.2MiB / 10.8% - - Section ncalls time %tot avg alloc %tot avg - ──────────────────────────────────────────────────────────────────────────────────────────────── - ACRule((~y) ^ ~n * ~y => (~y) ^ (~n ... 667 11.1ms 21.3% 16.7μs 2.66MiB 26.8% 4.08KiB - RHS 92 277μs 0.53% 3.01μs 14.4KiB 0.14% 160B - ACRule((~x) ^ ~n * (~x) ^ ~m => (~x)... 575 7.63ms 14.6% 13.3μs 1.83MiB 18.4% 3.26KiB - (*)(~(~(x::!issortedₑ))) => sort_arg... 831 6.31ms 12.1% 7.59μs 738KiB 7.26% 910B - RHS 164 3.03ms 5.81% 18.5μs 250KiB 2.46% 1.52KiB - ... - ... - ──────────────────────────────────────────────────────────────────────────────────────────────── -(a ^ 26) * (b ^ 30) * (c ^ 16) * (d ^ 28) -``` -""" -macro timerewrite(expr) - :(timerewrite(()->$(esc(expr)))) -end - Base.@deprecate RuleSet(x) Postwalk(Chain(x)) diff --git a/src/simplify.jl b/src/simplify.jl index 68fe78f83..8a63d00ec 100644 --- a/src/simplify.jl +++ b/src/simplify.jl @@ -13,7 +13,7 @@ Simplify an expression (`x`) by applying `rewriter` until there are no changes. By default, simplify will assume denominators are not zero and allow cancellation in fractions. Pass `simplify_fractions=false` to prevent this. """ -function simplify(x; +@inline function simplify(x; expand=false, polynorm=nothing, threaded=false, diff --git a/src/simplify_rules.jl b/src/simplify_rules.jl index a612036cb..c5310c370 100644 --- a/src/simplify_rules.jl +++ b/src/simplify_rules.jl @@ -6,180 +6,167 @@ the argument to the predicate satisfies `iscall` and `operation(x) == f` """ is_operation(f) = @nospecialize(x) -> iscall(x) && (operation(x) == f) -let - CANONICALIZE_PLUS = [ - @rule(~x::isnotflat(+) => flatten_term(+, ~x)) - @rule(~x::needs_sorting(+) => sort_args(+, ~x)) - @ordered_acrule(~a::is_literal_number + ~b::is_literal_number => ~a + ~b) - - @acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...)) - - @acrule(~x + *(~β, ~x) => *(1 + ~β, ~x)) - @acrule(*(~α::is_literal_number, ~x) + ~x => *(~α + 1, ~x)) - @rule(+(~~x::hasrepeats) => +(merge_repeats(*, ~~x)...)) - - @ordered_acrule((~z::_iszero + ~x) => ~x) - @rule(+(~x) => ~x) - ] - - PLUS_DISTRIBUTE = [ - @acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...)) - @acrule(*(~~x, ~α) + *(~~x, ~β) => *(~α + ~β, (~~x)...)) - ] - - CANONICALIZE_TIMES = [ - @rule(~x::isnotflat(*) => flatten_term(*, ~x)) - @rule(~x::needs_sorting(*) => sort_args(*, ~x)) - - @ordered_acrule(~a::is_literal_number * ~b::is_literal_number => ~a * ~b) - @rule(*(~~x::hasrepeats) => *(merge_repeats(^, ~~x)...)) - - @acrule((~y)^(~n) * ~y => (~y)^(~n+1)) - - @ordered_acrule((~z::_isone * ~x) => ~x) - @ordered_acrule((~z::_iszero * ~x) => ~z) - @rule(*(~x) => ~x) - ] - - MUL_DISTRIBUTE = @ordered_acrule((~x)^(~n) * (~x)^(~m) => (~x)^(~n + ~m)) - - CANONICALIZE_POW = [ - @rule(^(*(~~x), ~y::_isinteger) => *(map(a->pow(a, ~y), ~~x)...)) - @rule((((~x)^(~p::_isinteger))^(~q::_isinteger)) => (~x)^((~p)*(~q))) - @rule(^(~x, ~z::_iszero) => 1) - @rule(^(~x, ~z::_isone) => ~x) - @rule(inv(~x) => 1/(~x)) - ] - - POW_RULES = [ - @rule(^(~x::_isone, ~z) => 1) - ] - - ASSORTED_RULES = [ - @rule(identity(~x) => ~x) - @rule(-(~x) => -1*~x) - @rule(-(~x, ~y) => ~x + -1(~y)) - @rule(~x::_isone \ ~y => ~y) - @rule(~x \ ~y => ~y / (~x)) - @rule(one(~x) => one(symtype(~x))) - @rule(zero(~x) => zero(symtype(~x))) - @rule(conj(~x::_isreal) => ~x) - @rule(real(~x::_isreal) => ~x) - @rule(imag(~x::_isreal) => zero(symtype(~x))) - @rule(ifelse(~x::is_literal_number, ~y, ~z) => ~x ? ~y : ~z) - @rule(ifelse(~x, ~y, ~y) => ~y) - ] - - TRIG_EXP_RULES = [ - @acrule(~r*~x::has_trig_exp + ~r*~y => ~r*(~x + ~y)) - @acrule(~r*~x::has_trig_exp + -1*~r*~y => ~r*(~x - ~y)) - @acrule(sin(~x)^2 + cos(~x)^2 => one(~x)) - @acrule(sin(~x)^2 + -1 => -1*cos(~x)^2) - @acrule(cos(~x)^2 + -1 => -1*sin(~x)^2) - - @acrule(cos(~x)^2 + -1*sin(~x)^2 => cos(2 * ~x)) - @acrule(sin(~x)^2 + -1*cos(~x)^2 => -cos(2 * ~x)) - @acrule(cos(~x) * sin(~x) => sin(2 * ~x)/2) - - @acrule(tan(~x)^2 + -1*sec(~x)^2 => one(~x)) - @acrule(-1*tan(~x)^2 + sec(~x)^2 => one(~x)) - @acrule(tan(~x)^2 + 1 => sec(~x)^2) - @acrule(sec(~x)^2 + -1 => tan(~x)^2) - - @acrule(cot(~x)^2 + -1*csc(~x)^2 => one(~x)) - @acrule(cot(~x)^2 + 1 => csc(~x)^2) - @acrule(csc(~x)^2 + -1 => cot(~x)^2) - - @acrule(cosh(~x)^2 + -1*sinh(~x)^2 => one(~x)) - @acrule(cosh(~x)^2 + -1 => sinh(~x)^2) - @acrule(sinh(~x)^2 + 1 => cosh(~x)^2) - - @acrule(cosh(~x)^2 + sinh(~x)^2 => cosh(2 * ~x)) - @acrule(cosh(~x) * sinh(~x) => sinh(2 * ~x)/2) - - @acrule(exp(~x) * exp(~y) => _iszero(~x + ~y) ? 1 : exp(~x + ~y)) - @rule(exp(~x)^(~y) => exp(~x * ~y)) - ] - - BOOLEAN_RULES = [ - @rule((true | (~x)) => true) - @rule(((~x) | true) => true) - @rule((false | (~x)) => ~x) - @rule(((~x) | false) => ~x) - @rule((true & (~x)) => ~x) - @rule(((~x) & true) => ~x) - @rule((false & (~x)) => false) - @rule(((~x) & false) => false) - - @rule(!(~x) & ~x => false) - @rule(~x & !(~x) => false) - @rule(!(~x) | ~x => true) - @rule(~x | !(~x) => true) - @rule(xor(~x, !(~x)) => true) - @rule(xor(~x, ~x) => false) - - @rule(~x == ~x => true) - @rule(~x != ~x => false) - @rule(~x < ~x => false) - @rule(~x > ~x => false) - - # simplify terms with no symbolic arguments - # e.g. this simplifies term(isodd, 3, type=Bool) - # or term(!, false) - @rule((~f)(~x::is_literal_number) => (~f)(~x)) - # and this simplifies any binary comparison operator - @rule((~f)(~x::is_literal_number, ~y::is_literal_number) => (~f)(~x, ~y)) - ] - - function number_simplifier() - rule_tree = [If(iscall, Chain(ASSORTED_RULES)), - If(x -> !isadd(x) && is_operation(+)(x), - Chain(CANONICALIZE_PLUS)), - If(is_operation(+), Chain(PLUS_DISTRIBUTE)), # This would be useful even if isadd - If(x -> !ismul(x) && is_operation(*)(x), - Chain(CANONICALIZE_TIMES)), - If(is_operation(*), MUL_DISTRIBUTE), - If(x -> !ispow(x) && is_operation(^)(x), - Chain(CANONICALIZE_POW)), - If(is_operation(^), Chain(POW_RULES)), - ] |> RestartedChain - - rule_tree - end - - trig_exp_simplifier(;kw...) = Chain(TRIG_EXP_RULES) - - bool_simplifier() = Chain(BOOLEAN_RULES) - - global default_simplifier - global serial_simplifier - global threaded_simplifier - global serial_simplifier - global serial_expand_simplifier - - function default_simplifier(; kw...) - IfElse(has_trig_exp, - Postwalk(IfElse(x->symtype(x) <: Number, - Chain((number_simplifier(), - trig_exp_simplifier())), - If(x->symtype(x) <: Bool, - bool_simplifier())) - ; kw...), - Postwalk(Chain((If(x->symtype(x) <: Number, - number_simplifier()), - If(x->symtype(x) <: Bool, - bool_simplifier()))) - ; kw...)) - end - - # reduce overhead of simplify by defining these as constant - serial_simplifier = If(iscall, Fixpoint(default_simplifier())) - - threaded_simplifier(cutoff) = Fixpoint(default_simplifier(threaded=true, - thread_cutoff=cutoff)) - - serial_expand_simplifier = If(iscall, - Fixpoint(Chain((expand, - Fixpoint(default_simplifier()))))) - +const CANONICALIZE_PLUS = ( + @rule(~x::isnotflat(+) => flatten_term(+, ~x)), + @rule(~x::needs_sorting(+) => sort_args(+, ~x)), + @ordered_acrule(~a::is_literal_number + ~b::is_literal_number => ~a + ~b), + + @acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...)), + + @acrule(~x + *(~β, ~x) => *(1 + ~β, ~x)), + @acrule(*(~α::is_literal_number, ~x) + ~x => *(~α + 1, ~x)), + @rule(+(~~x::hasrepeats) => +(merge_repeats(*, ~~x)...)), + + @ordered_acrule((~z::_iszero + ~x) => ~x), + @rule(+(~x) => ~x), +) + +const PLUS_DISTRIBUTE = ( + @acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...)), + @acrule(*(~~x, ~α) + *(~~x, ~β) => *(~α + ~β, (~~x)...)), +) + +const CANONICALIZE_TIMES = ( + @rule(~x::isnotflat(*) => flatten_term(*, ~x)), + @rule(~x::needs_sorting(*) => sort_args(*, ~x)), + + @ordered_acrule(~a::is_literal_number * ~b::is_literal_number => ~a * ~b), + @rule(*(~~x::hasrepeats) => *(merge_repeats(^, ~~x)...)), + + @acrule((~y)^(~n) * ~y => (~y)^(~n+1)), + + @ordered_acrule((~z::_isone * ~x) => ~x), + @ordered_acrule((~z::_iszero * ~x) => ~z), + @rule(*(~x) => ~x), +) + +const MUL_DISTRIBUTE = @ordered_acrule((~x)^(~n) * (~x)^(~m) => (~x)^(~n + ~m)) + +const CANONICALIZE_POW = ( + @rule(^(*(~~x), ~y::_isinteger) => *(map(a->pow(a, ~y), ~~x)...)), + @rule((((~x)^(~p::_isinteger))^(~q::_isinteger)) => (~x)^((~p)*(~q))), + @rule(^(~x, ~z::_iszero) => 1), + @rule(^(~x, ~z::_isone) => ~x), + @rule(inv(~x) => 1/(~x)), +) + +const POW_RULES = ( + @rule(^(~x::_isone, ~z) => 1), +) + +const ASSORTED_RULES = ( + @rule(identity(~x) => ~x), + @rule(-(~x) => -1*~x), + @rule(-(~x, ~y) => ~x + -1(~y)), + @rule(~x::_isone \ ~y => ~y), + @rule(~x \ ~y => ~y / (~x)), + @rule(one(~x) => one(symtype(~x))), + @rule(zero(~x) => zero(symtype(~x))), + @rule(conj(~x::_isreal) => ~x), + @rule(real(~x::_isreal) => ~x), + @rule(imag(~x::_isreal) => zero(symtype(~x))), + @rule(ifelse(~x::is_literal_number, ~y, ~z) => ~x ? ~y : ~z), + @rule(ifelse(~x, ~y, ~y) => ~y), +) + +const TRIG_EXP_RULES = ( + @acrule(~r*~x::has_trig_exp + ~r*~y => ~r*(~x + ~y)), + @acrule(~r*~x::has_trig_exp + -1*~r*~y => ~r*(~x - ~y)), + @acrule(sin(~x)^2 + cos(~x)^2 => one(~x)), + @acrule(sin(~x)^2 + -1 => -1*cos(~x)^2), + @acrule(cos(~x)^2 + -1 => -1*sin(~x)^2), + + @acrule(cos(~x)^2 + -1*sin(~x)^2 => cos(2 * ~x)), + @acrule(sin(~x)^2 + -1*cos(~x)^2 => -cos(2 * ~x)), + @acrule(cos(~x) * sin(~x) => sin(2 * ~x)/2), + + @acrule(tan(~x)^2 + -1*sec(~x)^2 => one(~x)), + @acrule(-1*tan(~x)^2 + sec(~x)^2 => one(~x)), + @acrule(tan(~x)^2 + 1 => sec(~x)^2), + @acrule(sec(~x)^2 + -1 => tan(~x)^2), + + @acrule(cot(~x)^2 + -1*csc(~x)^2 => one(~x)), + @acrule(cot(~x)^2 + 1 => csc(~x)^2), + @acrule(csc(~x)^2 + -1 => cot(~x)^2), + + @acrule(cosh(~x)^2 + -1*sinh(~x)^2 => one(~x)), + @acrule(cosh(~x)^2 + -1 => sinh(~x)^2), + @acrule(sinh(~x)^2 + 1 => cosh(~x)^2), + + @acrule(cosh(~x)^2 + sinh(~x)^2 => cosh(2 * ~x)), + @acrule(cosh(~x) * sinh(~x) => sinh(2 * ~x)/2), + + @acrule(exp(~x) * exp(~y) => _iszero(~x + ~y) ? 1 : exp(~x + ~y)), + @rule(exp(~x)^(~y) => exp(~x * ~y)), +) + +const BOOLEAN_RULES = ( + @rule((true | (~x)) => true), + @rule(((~x) | true) => true), + @rule((false | (~x)) => ~x), + @rule(((~x) | false) => ~x), + @rule((true & (~x)) => ~x), + @rule(((~x) & true) => ~x), + @rule((false & (~x)) => false), + @rule(((~x) & false) => false), + + @rule(!(~x) & ~x => false), + @rule(~x & !(~x) => false), + @rule(!(~x) | ~x => true), + @rule(~x | !(~x) => true), + @rule(xor(~x, !(~x)) => true), + @rule(xor(~x, ~x) => false), + + @rule(~x == ~x => true), + @rule(~x != ~x => false), + @rule(~x < ~x => false), + @rule(~x > ~x => false), + + # simplify terms with no symbolic arguments + # e.g. this simplifies term(isodd, 3, type=Bool) + # or term(!, false) + @rule((~f)(~x::is_literal_number) => (~f)(~x)), + # and this simplifies any binary comparison operator + @rule((~f)(~x::is_literal_number, ~y::is_literal_number) => (~f)(~x, ~y)), +) + +const NUMBER_SIMPLIFIER = RestartedChain(( + If(iscall, Chain(ASSORTED_RULES)), + If(x -> !isadd(x) && is_operation(+)(x), + Chain(CANONICALIZE_PLUS)), + If(is_operation(+), Chain(PLUS_DISTRIBUTE)), # This would be useful even if isadd + If(x -> !ismul(x) && is_operation(*)(x), + Chain(CANONICALIZE_TIMES)), + If(is_operation(*), MUL_DISTRIBUTE), + If(x -> !ispow(x) && is_operation(^)(x), + Chain(CANONICALIZE_POW)), + If(is_operation(^), Chain(POW_RULES)), +)) + +const TRIG_EXP_SIMPLIFIER = Chain(TRIG_EXP_RULES) + +const BOOLEAN_SIMPLIFIER = Chain(BOOLEAN_RULES) + + +function get_default_simplifier(; kw...) + IfElse(has_trig_exp, + Postwalk(IfElse(x->symtype(x) <: Number, + Chain((NUMBER_SIMPLIFIER, TRIG_EXP_SIMPLIFIER)), + If(x->symtype(x) <: Bool, BOOLEAN_SIMPLIFIER)) + ; kw...), + Postwalk(Chain((If(x->symtype(x) <: Number, + NUMBER_SIMPLIFIER), + If(x->symtype(x) <: Bool, + BOOLEAN_SIMPLIFIER))) + ; kw...)) end + +# reduce overhead of simplify by defining these as constant +const serial_simplifier = If(iscall, Fixpoint(get_default_simplifier())) + +threaded_simplifier(cutoff) = Fixpoint(get_default_simplifier(threaded=true, + thread_cutoff=cutoff)) + +const serial_expand_simplifier = If(iscall, + Fixpoint(Chain((expand, + Fixpoint(get_default_simplifier()))))) diff --git a/src/small_array.jl b/src/small_array.jl index 1adaa371c..469af54eb 100644 --- a/src/small_array.jl +++ b/src/small_array.jl @@ -183,3 +183,4 @@ end Base.any(f::Function, x::SmallVec) = any(f, x.data) Base.all(f::Function, x::SmallVec) = all(f, x.data) +Base.map(f, x::SmallVec{T, V}) where {T, V} = SmallVec{T,V}(map(f, x.data)) diff --git a/src/types.jl b/src/types.jl index bf141d34a..fd5987dbc 100644 --- a/src/types.jl +++ b/src/types.jl @@ -4,72 +4,144 @@ #-------------------- abstract type Symbolic{T} end +#################### SafeReal ######################### +export SafeReal, LiteralReal + +# ideally the relationship should be the other way around +abstract type SafeRealImpl <: Number end +const SafeReal = Union{SafeRealImpl, Real} +Base.one(::Type{SafeReal}) = true +Base.zero(::Type{SafeReal}) = false +Base.convert(::Type{<:SafeRealImpl}, x::Number) = convert(Real, x) + +################### LiteralReal ####################### + +abstract type LiteralRealImpl <: Number end +const LiteralReal = Union{LiteralRealImpl, Real} +Base.one(::Type{LiteralReal}) = true +Base.zero(::Type{LiteralReal}) = false +Base.convert(::Type{<:LiteralRealImpl}, x::Number) = convert(Real, x) + ### ### Uni-type design ### -@enum ExprType::UInt8 SYM TERM ADD MUL POW DIV +struct Unknown end -const Metadata = Union{Nothing,Base.ImmutableDict{DataType,Any}} -const NO_METADATA = nothing +const MetadataT = Union{Base.ImmutableDict{DataType, Any}, Nothing} +const SmallV{T} = SmallVec{T, Vector{T}} +const ArgsT = SmallV{Any} +const ROArgsT = ReadOnlyVector{Any, ArgsT} +const ACDict{K, V} = Dict{K, V} +const ShapeVecT = SmallV{UnitRange{Int}} +const ShapeT = Union{Unknown, ShapeVecT} +const IdentT = Union{IDType, Nothing} -sdict(kv...) = Dict{Any, Any}(kv...) +""" + Enum used to differentiate between variants of `BasicSymbolicImpl.ACTerm`. +""" +@enumx AddMulVariant::Bool begin + ADD = false + MUL = true +end -using Base: RefValue -const SmallV{T} = SmallVec{T, Vector{T}} -const EMPTY_ARGS = SmallV{Any}() -const EMPTY_HASH = RefValue(UInt(0)) -const EMPTY_DICT = sdict() -const EMPTY_DICT_T = typeof(EMPTY_DICT) -const ENABLE_HASHCONSING = Ref(true) -const TID = Union{IDType, Nothing} -const DID = nothing - -@compactify show_methods=false begin - @abstract mutable struct BasicSymbolic{T} <: Symbolic{T} - metadata::Metadata = NO_METADATA - id::RefValue{TID} = Ref{TID}(DID) - end - mutable struct Sym{T} <: BasicSymbolic{T} - name::Symbol = :OOF - end - mutable struct Term{T} <: BasicSymbolic{T} - f::Any = identity # base/num if Pow; issorted if Add/Dict - arguments::SmallV{Any} = EMPTY_ARGS - hash::RefValue{UInt} = EMPTY_HASH - hash2::RefValue{UInt} = EMPTY_HASH - end - mutable struct Mul{T} <: BasicSymbolic{T} - coeff::Any = 0 # exp/den if Pow - dict::EMPTY_DICT_T = EMPTY_DICT - hash::RefValue{UInt} = EMPTY_HASH - hash2::RefValue{UInt} = EMPTY_HASH - arguments::SmallV{Any} = EMPTY_ARGS - end - mutable struct Add{T} <: BasicSymbolic{T} - coeff::Any = 0 # exp/den if Pow - dict::EMPTY_DICT_T = EMPTY_DICT - hash::RefValue{UInt} = EMPTY_HASH - hash2::RefValue{UInt} = EMPTY_HASH - arguments::SmallV{Any} = EMPTY_ARGS - end - mutable struct Div{T} <: BasicSymbolic{T} - num::Any = 1 - den::Any = 1 - simplified::Bool = false - arguments::SmallV{Any} = EMPTY_ARGS - end - mutable struct Pow{T} <: BasicSymbolic{T} - base::Any = 1 - exp::Any = 1 - arguments::SmallV{Any} = EMPTY_ARGS +""" + $(TYPEDSIGNATURES) + +Check if `coeff` is the identity element for `ACTerm` variant `v`. +""" +function is_identity_coeff(v::AddMulVariant.T, coeff) + @match v begin + AddMulVariant.ADD => iszero(coeff) + AddMulVariant.MUL => isone(coeff) end end +""" + $(TYPEDSIGNATURES) + +Get the identity coefficient for `ACTerm` variant `v` as type `T`. +""" +function identity_coeff(v::AddMulVariant.T, T = Bool) + @match v begin + AddMulVariant.ADD => zero(T) + AddMulVariant.MUL => one(T) + end +end + +""" + $(TYPEDEF) + +Core ADT for `BasicSymbolic`. `hash` and `isequal` compare metadata. +""" +@data mutable BasicSymbolicImpl{T} <: Symbolic{T} begin + # struct Const{T} + # val::T + # id::RefValue{IdentT} + # end + struct Sym + const name::Symbol + const metadata::MetadataT + const shape::ShapeT + hash2::UInt + id::IdentT + end + struct Term + const f::Any + const args::ArgsT + const metadata::MetadataT + const shape::ShapeT + hash::UInt + hash2::UInt + id::IdentT + end + struct AddOrMul + const variant::AddMulVariant.T + const coeff::T + const dict::ACDict{Symbolic, T} + const metadata::MetadataT + const shape::ShapeT + const args::ArgsT + hash::UInt + hash2::UInt + id::IdentT + end + struct Div + const num::Any + const den::Any + # TODO: Keep or remove? + # Flag for whether this div is in the most simplified form we can compute. + # This being false doesn't mean no elimination is performed. Trivials such as + # constant factors can be eliminated. However, polynomial elimination may not + # have been performed yet. Typically used as an early-exit for simplification + # algorithms to not try to eliminate more. + const simplified::Bool + const metadata::MetadataT + const shape::ShapeT + hash2::UInt + id::IdentT + end + struct Pow + const base::Any + const exp::Any + const metadata::MetadataT + const shape::ShapeT + hash2::UInt + id::IdentT + end +end + +const BSImpl = BasicSymbolicImpl +const BasicSymbolic = BSImpl.Type + function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) ScalarSymbolic() end +function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic{<:AbstractArray}}) + ArraySymbolic() +end + """ $(TYPEDSIGNATURES) @@ -78,43 +150,67 @@ returning the input as-is. """ unwrap(x) = x -function exprtype(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => TERM - Add => ADD - Mul => MUL - Div => DIV - Pow => POW - Sym => SYM - _ => error_on_type() - end +struct UnimplementedForVariantError <: Exception + method + variant end -# Same but different error messages -@noinline error_on_type() = error("Internal error: unreachable reached!") -@noinline error_sym() = error("Sym doesn't have an operation or any arguments!") -@noinline error_property(E, s) = error("$E doesn't have field $s") +function Base.showerror(io::IO, err::UnimplementedForVariantError) + print(io, """ + $(err.method) is not implemented for variant $(err.variant) of `BasicSymbolicImpl`. + """) +end + +""" + $(TYPEDSIGNATURES) + +Properties of `obj` that override any explicitly provided values in +`ConstructionBase.setproperties`. +""" +override_properties(obj::BSImpl.Type) = override_properties(MData.variant_type(obj)) + +function override_properties(obj::Type{<:BSImpl.Variant}) + @match obj begin + ::Type{<:BSImpl.Sym} => (; id = nothing, hash2 = 0) + ::Type{<:BSImpl.Term} => (; id = nothing, hash = 0, hash2 = 0) + ::Type{<:BSImpl.AddOrMul} => (; id = nothing, hash = 0, hash2 = 0) + ::Type{<:BSImpl.Div} => (; id = nothing, hash2 = 0) + ::Type{<:BSImpl.Pow} => (; id = nothing, hash2 = 0) + _ => throw(UnimplementedForVariantError(override_properties, obj)) + end +end -# We can think about bits later -# flags -const SIMPLIFIED = 0x01 << 0 +function ordered_override_properties(obj::Type{<:BSImpl.Variant}) + @match obj begin + ::Type{<:BSImpl.Sym} => (0, nothing) + ::Type{<:BSImpl.Term} => (0, 0, nothing) + ::Type{<:BSImpl.AddOrMul} => (ArgsT(), 0, 0, nothing) + ::Type{<:BSImpl.Div} => (0, nothing) + ::Type{<:BSImpl.Pow} => (0, nothing) + _ => throw(UnimplementedForVariantError(override_properties, obj)) + end +end -#@inline is_of_type(x::BasicSymbolic, type::UInt8) = (x.bitflags & type) != 0x00 -#@inline issimplified(x::BasicSymbolic) = is_of_type(x, SIMPLIFIED) +function ConstructionBase.getproperties(obj::BSImpl.Type) + @match obj begin + BSImpl.Sym(; name, metadata, hash2, shape, id) => (; name, metadata, hash2, shape, id) + BSImpl.Term(; f, args, metadata, hash, hash2, shape, id) => (; f, args, metadata, hash, hash2, shape, id) + BSImpl.AddOrMul(; variant, coeff, dict, args, metadata, hash, hash2, shape, id) => (; variant, coeff, dict, args, metadata, hash, hash2, shape, id) + BSImpl.Div(; num, den, simplified, metadata, hash2, shape, id) => (; num, den, simplified, metadata, hash2, shape, id) + BSImpl.Pow(; base, exp, metadata, hash2, shape, id) => (; base, exp, metadata, hash2, shape, id) + end +end -function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple)::BasicSymbolic{T} where T - nt = getproperties(obj) - nt_new = merge(nt, patch) - # Call outer constructor because hash consing cannot be applied in inner constructor - @compactified obj::BasicSymbolic begin - Sym => Sym{T}(nt_new.name; nt_new...) - Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{TID}(DID)) - Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{TID}(DID)) - Mul => Mul(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{TID}(DID)) - Div => Div{T}(nt_new.num, nt_new.den, nt_new.simplified; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{TID}(DID)) - Pow => Pow{T}(nt_new.base, nt_new.exp; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{TID}(DID)) - _ => Unityper.rt_constructor(obj){T}(;nt_new...) +function ConstructionBase.setproperties(obj::BSImpl.Type{T}, patch::NamedTuple)::BSImpl.Type{T} where {T} + props = getproperties(obj) + overrides = override_properties(obj) + # We only want to invalidate `args` if we're updating `coeff` or `dict`. + if isaddmul(obj) && (haskey(patch, :coeff) || haskey(patch, :dict)) + extras = (; args = ArgsT()) + else + extras = (;) end + hashcons(MData.variant_type(obj)(; props..., patch..., overrides..., extras...)) end ### @@ -133,107 +229,95 @@ rules that may be implemented in the future. symtype(x) = typeof(x) @inline symtype(::Symbolic{T}) where T = T @inline symtype(::Type{<:Symbolic{T}}) where T = T +@inline symtype(::BSImpl.Type{T}) where T = T +@inline symtype(::Type{<:BSImpl.Type{T}}) where T = T # We're returning a function pointer -@inline function operation(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => x.f - Add => (+) - Mul => (*) - Div => (/) - Pow => (^) - Sym => error_sym() - _ => error_on_type() - end -end - -@inline head(x::BasicSymbolic) = operation(x) - -@cache function TermInterface.sorted_arguments(x::BasicSymbolic)::Vector{Any} - args = copy(arguments(x)) - @compactified x::BasicSymbolic begin - Add => @goto ADD - Mul => @goto MUL - _ => return args - end - @label MUL - sort!(args, by=get_degrees) - return args - - @label ADD - sort!(args, lt = monomial_lt, by=get_degrees) - return args -end - -@deprecate unsorted_arguments(x) arguments(x) - -TermInterface.children(x::BasicSymbolic) = arguments(x) -TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x) -function TermInterface.arguments(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => return x.arguments - Add => @goto ADDMUL - Mul => @goto ADDMUL - Div => @goto DIV - Pow => @goto POW - Sym => error_sym() - _ => error_on_type() - end - - @label ADDMUL - E = exprtype(x) - args = x.arguments - isempty(args) || return args - siz = length(x.dict) - idcoeff = E === ADD ? iszero(x.coeff) : isone(x.coeff) - sizehint!(args, idcoeff ? siz : siz + 1) - idcoeff || push!(args, x.coeff) - if isadd(x) - for (k, v) in x.dict - push!(args, applicable(*,k,v) ? k*v : - maketerm(k, *, [k, v], nothing)) - end - else # MUL - for (k, v) in x.dict - push!(args, unstable_pow(k, v)) - end +@inline function TermInterface.operation(x::BSImpl.Type) + @match x begin + # BSImpl.Const(_) => throw(ArgumentError("`Const` does not have an operation.")) + BSImpl.Sym(_) => throw(ArgumentError("`Sym` does not have an operation.")) + BSImpl.Term(; f) => f + BSImpl.AddOrMul(; variant) => @match variant begin + AddMulVariant.ADD => (+) + AddMulVariant.MUL => (*) + end + BSImpl.Div(_) => (/) + BSImpl.Pow(_) => (^) + _ => throw(UnimplementedForVariantError(operation, MData.variant_type(x))) end - return args - - @label DIV - args = x.arguments - isempty(args) || return args - sizehint!(args, 2) - push!(args, x.num) - push!(args, x.den) - return args - - @label POW - args = x.arguments - isempty(args) || return args - sizehint!(args, 2) - push!(args, x.base) - push!(args, x.exp) - return args end -isexpr(s::BasicSymbolic) = !issym(s) -iscall(s::BasicSymbolic) = isexpr(s) +@cache function TermInterface.sorted_arguments(x::BSImpl.Type)::ROArgsT + @match x begin + BSImpl.AddOrMul(; variant) => begin + args = copy(parent(arguments(x))) + @match variant begin + AddMulVariant.ADD => sort!(args, by = get_degrees, lt = monomial_lt) + AddMulVariant.MUL => sort!(args, by = get_degrees) + end + return ROArgsT(ArgsT(args)) + end + _ => return arguments(x) + end +end + +function TermInterface.arguments(x::BSImpl.Type)::ROArgsT + @match x begin + # BSImpl.Const(_) => throw(ArgumentError("`Const` does not have arguments.")) + BSImpl.Sym(_) => throw(ArgumentError("`Sym` does not have arguments.")) + BSImpl.Term(; args) => ROArgsT(args) + BSImpl.AddOrMul(; args, coeff, dict, variant) => begin + isempty(args) || return ROArgsT(args) + sz = length(dict) + idcoeff = is_identity_coeff(variant, coeff) + sizehint!(args, sz + !idcoeff) + idcoeff || push!(args, coeff) + @match variant begin + AddMulVariant.ADD => begin + for (k, v) in dict + var = if isone(v) + k + elseif applicable(*, k, v) + k * v + else + maketerm(k, *, [k, v], nothing) + end + push!(args, var) + end + end + AddMulVariant.MUL => begin + for (k, v) in dict + push!(args, isone(v) ? k : (k ^ v)) + end + end + end + return ROArgsT(args) + end + BSImpl.Div(num, den) => ROArgsT(ArgsT((num, den))) + BSImpl.Pow(base, exp) => ROArgsT(ArgsT((base, exp))) + _ => throw(UnimplementedForVariantError(arguments, MData.variant_type(x))) + end +end -@inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false +function isexpr(s::BSImpl.Type) + !MData.isa_variant(s, BSImpl.Sym) # && !MData.isa_variant(s.inner, BSImpl.Const) +end +iscall(s::BSImpl.Type) = isexpr(s) -""" - issym(x) +# isconst(x::BSImpl.Type) = MData.isa_variant(x, BSImpl.Const) +isconst(x) = false +issym(x::BSImpl.Type) = MData.isa_variant(x, BSImpl.Sym) +isterm(x::BSImpl.Type) = MData.isa_variant(x, BSImpl.Term) +isaddmul(x::BSImpl.Type) = MData.isa_variant(x, BSImpl.AddOrMul) +isadd(x::BSImpl.Type) = MData.isa_variant(x, BSImpl.AddOrMul) && x.variant == AddMulVariant.ADD +ismul(x::BSImpl.Type) = MData.isa_variant(x, BSImpl.AddOrMul) && x.variant == AddMulVariant.MUL +isdiv(x::BSImpl.Type) = MData.isa_variant(x, BSImpl.Div) +ispow(x::BSImpl.Type) = MData.isa_variant(x, BSImpl.Pow) -Returns `true` if `x` is a `Sym`. If true, `nameof` must be defined -on `x` and must return a `Symbol`. -""" -issym(x) = isa_SymType(Val(:Sym), x) -isterm(x) = isa_SymType(Val(:Term), x) -ismul(x) = isa_SymType(Val(:Mul), x) -isadd(x) = isa_SymType(Val(:Add), x) -ispow(x) = isa_SymType(Val(:Pow), x) -isdiv(x) = isa_SymType(Val(:Div), x) +for fname in [:issym, :isterm, :isaddmul, :isadd, :ismul, :isdiv, :ispow] + @eval $fname(x) = false +end ### ### Base interface @@ -244,284 +328,284 @@ Base.isequal(x, ::Symbolic) = false Base.isequal(::Symbolic, ::Missing) = false Base.isequal(::Missing, ::Symbolic) = false Base.isequal(::Symbolic, ::Symbolic) = false -coeff_isequal(a, b; comparator = isequal) = comparator(a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a==b)) -function _allarequal(xs, ys; comparator = isequal)::Bool - N = length(xs) - length(ys) == N || return false - for n = 1:N - comparator(xs[n], ys[n]) || return false - end - return true -end - -function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S} - a === b && return true - a.id == b.id && a.id != 0 && return true - - E = exprtype(a) - E === exprtype(b) || return false - - T === S || return false - return _isequal(a, b, E)::Bool -end -function _isequal(a, b, E; comparator = isequal) - if E === SYM - nameof(a) === nameof(b) - elseif E === ADD || E === MUL - coeff_isequal(a.coeff, b.coeff; comparator) && comparator(a.dict, b.dict) - elseif E === DIV - comparator(a.num, b.num) && comparator(a.den, b.den) - elseif E === POW - comparator(a.exp, b.exp) && comparator(a.base, b.base) - elseif E === TERM - a1 = arguments(a) - a2 = arguments(b) - comparator(operation(a), operation(b)) && _allarequal(a1, a2; comparator) - else - error_on_type() - end -end - -""" -$(TYPEDSIGNATURES) - -Checks for equality between two `BasicSymbolic` objects, considering both their -values and metadata. -The default `Base.isequal` function for `BasicSymbolic` only compares their expressions -and ignores metadata. This does not help deal with hash collisions when metadata is -relevant for distinguishing expressions, particularly in hashing contexts. This function -provides a stricter equality check that includes metadata comparison, preventing -such collisions. - -Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` and -downstream packages like `ModelingToolkit.jl`, hence the need for this separate -function. -""" -function isequal_with_metadata(a::BasicSymbolic{T}, b::BasicSymbolic{S})::Bool where {T, S} - a === b && return true - a.id == b.id && a.id != 0 && return true - - E = exprtype(a) - E === exprtype(b) || return false - - T === S || return false - _isequal(a, b, E; comparator = isequal_with_metadata)::Bool && isequal_with_metadata(metadata(a), metadata(b)) || return false -end - -function isequal_with_metadata(a::Symbolic, b::Symbolic)::Bool - a === b && return true - typeof(a) == typeof(b) || return false - - ma = metadata(a) - mb = metadata(b) - if iscall(a) && iscall(b) - return isequal_with_metadata(operation(a), operation(b)) && isequal_with_metadata(arguments(a), arguments(b)) && isequal_with_metadata(ma, mb) - elseif iscall(a) || iscall(b) - return false +Base.@nospecializeinfer function isequal_maybe_scal(a, b, full::Bool) + @nospecialize a b + if a isa BasicSymbolic{Number} && b isa BasicSymbolic{Number} + isequal_bsimpl(a, b, full) + elseif a isa Int && b isa Int + isequal(a, b) + elseif a isa Float64 && b isa Float64 + isequal(a, b) + elseif a isa Rational{Int} && b isa Rational{Int} + isequal(a, b) else - return isequal_with_metadata(ma, mb) + isequal(a, b)::Bool end end -""" - $(TYPEDSIGNATURES) +const COMPARE_FULL = TaskLocalValue{Bool}(Returns(false)) -Compare the metadata of two `BasicSymbolic`s to ensure it is equal, recursively calling -`isequal_with_metadata` to ensure symbolic variables in the metadata also have equal -metadata. -""" -function isequal_with_metadata(a::NamedTuple, b::NamedTuple) - a === b && return true - typeof(a) == typeof(b) || return false +macro manually_scope(val, expr, is_forced = false) + @assert Meta.isexpr(val, :call) + @assert val.args[1] == :(=>) - # same type, so same keys and value types - # either everything works or it fails and early exits - for (av, bv) in zip(values(a), values(b)) - isequal_with_metadata(av, bv) || return false + var_name = val.args[2] + new_val = val.args[3] + old_name = gensym(:old_val) + cur_name = gensym(:cur_val) + retval_name = gensym(:retval) + close_expr = :($var_name[] = $old_name) + interpolated_expr = MacroTools.postwalk(expr) do ex + if Meta.isexpr(ex, :return) + return Expr(:block, close_expr, ex) + elseif Meta.isexpr(ex, :$) && length(ex.args) == 1 && ex.args[1] == :$ + return cur_name + else + return ex + end end - - return true + basic_result = quote + $cur_name = $var_name[] = $new_val + $retval_name = begin + $interpolated_expr + end + $close_expr + $retval_name + end + is_forced && return quote + $old_name = $var_name[] + $basic_result + end |> esc + + return quote + $old_name = $var_name[] + if $iszero($old_name) + $basic_result + else + $cur_name = $old_name + $retval_name = begin + $interpolated_expr + end + end + $retval_name + end |> esc end -function isequal_with_metadata(a::AbstractDict, b::AbstractDict) - a === b && return true - typeof(a) == typeof(b) || return false +function isequal_symdict(a::Dict, b::Dict, full) + full || return isequal(a, b) length(a) == length(b) || return false - - # they have same length, so either `b` has all the same keys - # or this will fail. Can't use `get(b, k, nothing)` because if - # `a[k] === nothing` it will result in a false positive. for (k, v) in a - k2 = getkey(b, k, nothing) - isequal_with_metadata(k, k2) || return false - isequal_with_metadata(v, b[k2]) || return false + k2 = nothing + v2 = nothing + @manually_scope COMPARE_FULL => false begin + k2 = getkey(b, k, nothing) + k2 === nothing && return false + v2 = b[k2] + end true + v == v2 && isequal_bsimpl(k, k2, true) || return false end return true end -function isequal_with_metadata(a::Base.ImmutableDict, b::Base.ImmutableDict) +function isequal_bsimpl(a::BSImpl.Type, b::BSImpl.Type, full) a === b && return true - typeof(a) == typeof(b) || return false - length(a) == length(b) || return false - - for (k, v) in a - match = false - for (k2, v2) in b - match |= isequal_with_metadata(k, k2) && isequal_with_metadata(v, v2) - end - match || return false - end - return true -end - -""" - $(TYPEDSIGNATURES) - -Fallback method which uses `isequal`. -""" -isequal_with_metadata(a, b) = isequal(a, b) + ida = a.id + idb = b.id + ida === idb && ida !== nothing && return true + typeof(a) === typeof(b) || return false -""" - $(TYPEDSIGNATURES) + Ta = MData.variant_type(a) + Tb = MData.variant_type(b) + Ta === Tb || return false -Specialized methods to check if two ranges are equal without comparing each element. -""" -isequal_with_metadata(a::AbstractRange, b::AbstractRange) = isequal(a, b) -""" - $(TYPEDSIGNATURES) + if full && ida !== idb && ida !== nothing && idb !== nothing + return false + end -Check if two arrays/tuples are equal by calling `isequal_with_metadata` on each element. -This is to ensure true equality of any symbolic elements, if present. -""" -function isequal_with_metadata(a::Union{AbstractArray, Tuple}, b::Union{AbstractArray, Tuple}) - a === b && return true - typeof(a) == typeof(b) || return false - if a isa AbstractArray - size(a) == size(b) || return false - end # otherwise they're tuples and type equality also checks length equality - for (x, y) in zip(a, b) - isequal_with_metadata(x, y) || return false + partial = @match (a, b) begin + (BSImpl.Sym(; name = n1, shape = s1), BSImpl.Sym(; name = n2, shape = s2)) => begin + n1 === n2 && s1 == s2 + end + (BSImpl.Term(; f = f1, args = args1, shape = s1), BSImpl.Term(; f = f2, args = args2, shape = s2)) => begin + isequal(f1, f2)::Bool && isequal(args1, args2) && s1 == s2 + end + (BSImpl.AddOrMul(; variant = v1, dict = d1, coeff = c1), BSImpl.AddOrMul(; variant = v2, dict = d2, coeff = c2)) => begin + v1 == v2 && isequal_symdict(d1, d2, full) && isequal_maybe_scal(c1, c2, full) + end + (BSImpl.Div(; num = n1, den = d1), BSImpl.Div(; num = n2, den = d2)) => begin + isequal_maybe_scal(n1, n2, full) && isequal_maybe_scal(d1, d2, full) + end + (BSImpl.Pow(; base = n1, exp = d1), BSImpl.Pow(; base = n2, exp = d2)) => begin + isequal_maybe_scal(n1, n2, full) && isequal_maybe_scal(d1, d2, full) + end end - return true + if full && partial + partial = isequal(metadata(a), metadata(b)) + end + return partial end -isequal_with_metadata(a::Number, b::Number) = typeof(a) == typeof(b) && isequal(a, b) - -Base.one( s::Symbolic) = one( symtype(s)) -Base.zero(s::Symbolic) = zero(symtype(s)) - -Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymbolic doesn't have a name") +function Base.isequal(a::BSImpl.Type, b::BSImpl.Type) + isequal_bsimpl(a, b, COMPARE_FULL[]) +end -## This is much faster than hash of an array of Any -hashvec(xs, z) = foldr(hash, xs, init=z) -hashvec2(xs, z) = foldr(hash2, xs, init=z) +# const CONST_SALT = 0x194813feb8a8c83d % UInt const SYM_SALT = 0x4de7d7c66d41da43 % UInt const ADD_SALT = 0xaddaddaddaddadda % UInt -const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt +const MUL_SALT = 0xaaaaaaaaaaaaaaaa % UInt const DIV_SALT = 0x334b218e73bbba53 % UInt const POW_SALT = 0x2b55b97a6efb080c % UInt -function Base.hash(s::BasicSymbolic, salt::UInt)::UInt - E = exprtype(s) - if E === SYM - hash(nameof(s), salt ⊻ SYM_SALT) - elseif E === ADD || E === MUL - !iszero(salt) && return hash(hash(s, zero(UInt)), salt) - h = s.hash[] - !iszero(h) && return h - hashoffset = isadd(s) ? ADD_SALT : SUB_SALT - h′ = hash(hashoffset, hash(s.coeff, hash(s.dict, salt))) - s.hash[] = h′ - return h′ - elseif E === DIV - return hash(s.num, hash(s.den, salt ⊻ DIV_SALT)) - elseif E === POW - hash(s.exp, hash(s.base, salt ⊻ POW_SALT)) - elseif E === TERM - !iszero(salt) && return hash(hash(s, zero(UInt)), salt) - h = s.hash[] - !iszero(h) && return h - op = operation(s) - oph = op isa Function ? nameof(op) : op - h′ = hashvec(arguments(s), hash(oph, salt)) - s.hash[] = h′ - return h′ + +const SCALAR_SYMTYPE_VARIANTS = [Number, Real, SafeReal, LiteralReal] +const ARR_VARIANTS = [Vector, Matrix] +const SYMTYPE_VARIANTS = [SCALAR_SYMTYPE_VARIANTS; [A{T} for A in ARR_VARIANTS for T in SCALAR_SYMTYPE_VARIANTS]] + +Base.@nospecializeinfer function hash_coeff(x::Number, h::UInt) + @nospecialize x + if x isa Int + hash(x, h) + elseif x isa Float64 + hash(x, h) + elseif x isa Rational{Int} + hash(x, h) + elseif x isa UInt + hash(x, h) + elseif x isa Bool + hash(x, h) + else + hash(x, h)::UInt + end +end + +Base.@nospecializeinfer function hash_anyscalar(x::Any, h::UInt, full::Bool) + @nospecialize x + if x isa Int + hash(x, h) + elseif x isa Float64 + hash(x, h) + elseif x isa Rational{Int} + hash(x, h) + elseif x isa UInt + hash(x, h) + elseif x isa Bool + hash(x, h) + elseif x isa BasicSymbolic{Number} + hash_bsimpl(x, h, full) else - error_on_type() + hash(x, h)::UInt end end -""" -$(TYPEDSIGNATURES) +function custom_dicthash(x::Dict{Symbolic, Number}, h::UInt, full) + hv = Base.hasha_seed + for (k, v) in x + h1 = hash_anyscalar(v, zero(UInt), full) + h1 = hash_anyscalar(k, h1, full) + hv ⊻= h1 + end + return hash(hv, h) +end -Calculates a hash value for a `BasicSymbolic` object, incorporating both its metadata and -symtype. +Base.@nospecializeinfer function hash_addmuldict(x::Dict, h::UInt, full) + @nospecialize x + if x isa Dict{Symbolic, Number} + custom_dicthash(x, h, full) + else + hash(x, h)::UInt + end +end -This function provides an alternative hashing strategy to `Base.hash` for `BasicSymbolic` -objects. Unlike `Base.hash`, which only considers the expression structure, `hash2` also -includes the metadata and symtype in the hash calculation. This can be beneficial for hash -consing, allowing for more effective deduplication of symbolically equivalent expressions -with different metadata or symtypes. +function hashargs(x::ArgsT, h::UInt, full) + h += Base.hash_abstractarray_seed + h = hash(length(x), h) + for val in x + h = hash_anyscalar(val, h, full) + end + return h +end -Equivalent numbers of different types, such as `0.5::Float64` and -`(1 // 2)::Rational{Int64}`, have the same default `Base.hash` value. The `hash2` function -distinguishes these by including their numeric types in the hash calculation to ensure that -symbolically equivalent expressions with different numeric types are treated as distinct -objects. -""" -hash2(s, salt::UInt) = hash(s, salt) -function hash2(n::T, salt::UInt) where {T <: Number} - hash(T, hash(n, salt)) -end -hash2(s::BasicSymbolic) = hash2(s, zero(UInt)) -function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T} - E = exprtype(s) - h::UInt = 0 - if E === SYM - h = hash(nameof(s), salt ⊻ SYM_SALT) - elseif E === ADD || E === MUL - if !iszero(s.hash2[]) - return s.hash2[] - end - hashoffset = isadd(s) ? ADD_SALT : SUB_SALT - hv = Base.hasha_seed - for (k, v) in s.dict - hv ⊻= hash2(k, hash(v)) +function hash_bsimpl(s::BSImpl.Type, h::UInt, full) + if full + cache = s.hash2 + !iszero(cache) && return cache + end + + partial::UInt = @match s begin + BSImpl.Sym(; name, shape) => begin + h = Base.hash(name, h) + h = Base.hash(shape, h) + h ⊻ SYM_SALT + end + BSImpl.Term(; f, args, shape, hash) => begin + # use/update cached hash + # error() + cache = hash + if iszero(cache) + s.hash = Base.hash(f, hashargs(args, Base.hash(shape, h), full))::UInt + else + cache + end + end + BSImpl.AddOrMul(; variant, dict, coeff, shape, hash) => begin + cache = hash + if iszero(cache) + inner = hash_addmuldict(dict, h, full) + inner = Base.hash(shape, hash_coeff(coeff, inner)) + inner = Base.hash((variant == AddMulVariant.ADD ? ADD_SALT : MUL_SALT), inner) + s.hash = inner + else + cache + end + + end + BSImpl.Div(; num, den) => begin + hash_anyscalar(num, hash_anyscalar(den, h, full), full) ⊻ DIV_SALT + end + BSImpl.Pow(; base, exp) => begin + hash_anyscalar(base, hash_anyscalar(exp, h, full), full) ⊻ POW_SALT end - h = hash(hv, salt) - h = hash(hashoffset, hash2(s.coeff, h)) - elseif E === DIV - h = hash2(s.num, hash2(s.den, salt ⊻ DIV_SALT)) - elseif E === POW - h = hash2(s.exp, hash2(s.base, salt ⊻ POW_SALT)) - elseif E === TERM - if !iszero(s.hash2[]) - return s.hash2[] - end - op = operation(s) - oph = op isa Function ? nameof(op) : op - h = hashvec2(arguments(s), hash(oph, salt)) - else - error_on_type() end - h = hash(metadata(s), hash(T, h)) - if hasproperty(s, :hash2) - s.hash2[] = h + + if full + partial = s.hash2 = Base.hash(metadata(s), partial)::UInt end - return h + return partial +end + +function Base.hash(s::BSImpl.Type, h::UInt) + if !iszero(h) + return hash(hash(s, zero(h)), h)::UInt + end + hash_bsimpl(s, h, COMPARE_FULL[]) end +Base.one( s::Union{Symbolic, BSImpl.Type}) = one( symtype(s)) +Base.zero(s::Union{Symbolic, BSImpl.Type}) = zero(symtype(s)) + + +Base.nameof(s::Union{BasicSymbolic, BSImpl.Type}) = issym(s) ? s.name : error("Non-Sym BasicSymbolic doesn't have a name") + ### ### Constructors ### -const wvd = TaskLocalValue{WeakValueDict{UInt, BasicSymbolic}}(WeakValueDict{UInt, BasicSymbolic}) +const ENABLE_HASHCONSING = Ref(true) +# const WKD = TaskLocalValue{WeakKeyDict{HashconsingWrapper, Nothing}}(WeakKeyDict{HashconsingWrapper, Nothing}) +const WKD = TaskLocalValue{WeakKeyDict{BSImpl.Type, Nothing}}(WeakKeyDict{BSImpl.Type, Nothing}) +const WVD = TaskLocalValue{WeakValueDict{UInt, BSImpl.Type}}(WeakValueDict{UInt, BSImpl.Type}) +const WCS = TaskLocalValue{WeakCacheSet{BSImpl.Type}}(WeakCacheSet{BSImpl.Type}) function generate_id() return IDType() end +const TOTAL = TaskLocalValue{Int}(Returns(0)) +const HITS = TaskLocalValue{Int}(Returns(0)) +const MISSES = TaskLocalValue{Int}(Returns(0)) +const COLLISIONS = TaskLocalValue{Int}(Returns(0)) + """ $(TYPEDSIGNATURES) @@ -543,41 +627,97 @@ Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.h `Base.isequal` to accommodate metadata without disrupting existing tests reliant on the original behavior of those functions. """ -function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic + +const collides = TaskLocalValue{Any}(Returns(Dict())) + +function hashcons(s::BSImpl.Type{T})::BSImpl.Type{T} where {T} if !ENABLE_HASHCONSING[] return s end - - cache = wvd[] - h = hash2(s) - k = get!(cache, h, s) - if isequal_with_metadata(k, s) - if isnothing(k.id[]) - k.id[] = generate_id() + @manually_scope COMPARE_FULL => true begin + cache = WCS[] + k = getkey!(cache, s) + # cache = WVD[] + # h = hash(s) + # k = get(cache, h, nothing) + + # if k === nothing || !isequal(k, s) + # if k !== nothing + # buffer = collides[] + # buffer2 = get!(() -> [], buffer, h) + # push!(buffer2, k => s) + # end + + # cache[h] = s + # k = s + # end + if k.id === nothing + k.id = generate_id() end return k - else - if isnothing(s.id[]) - s.id[] = generate_id() - end - return s + end true +end + +# function BSImpl.Const{T}(val::T) where {T} +# hashcons(BSImpl.Const{T}(; val, override_properties(BSImpl.Const{T})...)) +# end + +parse_metadata(x::MetadataT) = x +parse_metadata(::Nothing) = nothing +function parse_metadata(x) + meta = MetadataT() + for kvp in x + meta = Base.ImmutableDict(meta, kvp) end + return meta end -function Sym{T}(name::Symbol; kw...) where {T} - s = Sym{T}(; name, kw..., id = Ref{TID}(DID)) - BasicSymbolic(s) +default_shape(::Type{<:AbstractArray}) = Unknown() +default_shape(_) = ShapeVecT() + +""" + $(METHODLIST) + +If `x` is a rational with denominator 1, turn it into an integer. +""" +function maybe_integer(x) + x = unwrap(x) + x isa Real || return x + isinteger(x) || return x + if typemin(Int) <= x <= typemax(Int) + return Int(x) + else + return x + end end -function Term{T}(f, args; kw...) where T - args = SmallV{Any}(args) +function parse_args(args::AbstractVector) + if args isa ROArgsT + args = parent(args) + elseif !(args isa ArgsT) + args = ArgsT(args) + end + return args::ArgsT +end - s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw..., id = Ref{TID}(DID)) - BasicSymbolic(s) +function parse_dict(::Type{T}, x::AbstractDict) where {T} + if !(x isa ACDict{Symbolic, T}) + x = ACDict{Symbolic, T}(x) + end + map!(maybe_integer, values(x)) + return x::ACDict{Symbolic, T} end -function Term(f, args; metadata=NO_METADATA) - Term{_promote_symtype(f, args)}(f, args, metadata=metadata) +parse_maybe_symbolic(x::Symbolic) = x +parse_maybe_symbolic(x) = x +# parse_maybe_symbolic(x) = Const{typeof(x)}(x) + +function unwrap_args(args) + if any(x -> unwrap(x) !== x, args) + map(unwrap, args) + else + args + end end function unwrap_dict(dict) @@ -587,46 +727,144 @@ function unwrap_dict(dict) return dict end -function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T - coeff = unwrap(coeff) +@inline function BSImpl.Sym{T}(name::Symbol; metadata = nothing, shape = default_shape(T), unsafe = false) where {T} + metadata = parse_metadata(metadata) + props = ordered_override_properties(BSImpl.Sym) + var = BSImpl.Sym{T}(name, metadata, shape, props...) + if !unsafe + var = hashcons(var) + end + return var +end + +@inline function BSImpl.Term{T}(f, args; metadata = nothing, shape = default_shape(T), unsafe = false) where {T} + metadata = parse_metadata(metadata) + args = parse_args(args) + props = ordered_override_properties(BSImpl.Term) + var = BSImpl.Term{T}(f, args, metadata, shape, props...) + if !unsafe + var = hashcons(var) + end + return var +end + +@inline function BSImpl.AddOrMul{T}(variant::AddMulVariant.T, coeff::T, dict::AbstractDict; metadata = nothing, shape = default_shape(T), unsafe = false) where {T} + metadata = parse_metadata(metadata) + dict = parse_dict(T, dict) + props = ordered_override_properties(BSImpl.AddOrMul) + coeff = maybe_integer(coeff) + var = BSImpl.AddOrMul{T}(variant, coeff, dict, metadata, shape, props...) + if !unsafe + var = hashcons(var) + end + return var +end + +@inline function BSImpl.Div{T}(num, den, simplified::Bool; metadata = nothing, shape = default_shape(T), unsafe = false) where {T} + metadata = parse_metadata(metadata) + num = maybe_integer(parse_maybe_symbolic(num)) + den = maybe_integer(parse_maybe_symbolic(den)) + props = ordered_override_properties(BSImpl.Div) + var = BSImpl.Div{T}(num, den, simplified, metadata, shape, props...) + if !unsafe + var = hashcons(var) + end + return var +end + +@inline function BSImpl.Pow{T}(base, exp; metadata = nothing, shape = default_shape(T), unsafe = false) where {T} + metadata = parse_metadata(metadata) + base = maybe_integer(parse_maybe_symbolic(base)) + exp = maybe_integer(parse_maybe_symbolic(exp)) + props = ordered_override_properties(BSImpl.Pow) + var = BSImpl.Pow{T}(base, exp, metadata, shape, props...) + if !unsafe + var = hashcons(var) + end + return var +end + +# struct Const{T} end +struct Sym{T} end +struct Term{T} end +struct Add{T} end +struct Mul{T} end +struct Div{T} end +struct Pow{T} end + +# function Const{T}(val)::Symbolic where {T} +# val = unwrap(val) +# val isa Symbolic && return val +# BasicSymbolic(BSImpl.Const{T}(convert(T, val))) +# end + +# Const(val) = Const{typeof(val)}(val) + +Sym{T}(name; kw...) where {T} = BSImpl.Sym{T}(name; kw...) + +function Term{T}(f, args; kw...) where {T} + args = unwrap_args(args) + BSImpl.Term{T}(f, args; kw...) +end + +function Term(f, args; kw...) + Term{_promote_symtype(f, args)}(f, args; kw...) +end + +# assumes associative commutative addition +function Add{T}(coeff, dict; kw...) where {T} + coeff = convert(T, maybe_integer(unwrap(coeff))) dict = unwrap_dict(dict) - if isempty(dict) - return coeff - elseif _iszero(coeff) && length(dict) == 1 - k,v = first(dict) - if _isone(v) - return k - else - coeff, dict = makemul(v, k) - return Mul(T, coeff, dict) - end + isempty(dict) && return coeff + if _iszero(coeff) && length(dict) == 1 + k, v = first(dict) + _isone(v) && return k + return k * v end - s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=SmallV{Any}(), kw..., id = Ref{TID}(DID)) - BasicSymbolic(s) + variant = AddMulVariant.ADD + BSImpl.AddOrMul{T}(variant, coeff, dict, kw...) end -function Mul(T, a, b; metadata=NO_METADATA, kw...) - a = unwrap(a) - b = unwrap_dict(b) - isempty(b) && return a - if _isone(a) && length(b) == 1 - pair = first(b) - if _isone(last(pair)) # first value - return first(pair) - else - return unstable_pow(first(pair), last(pair)) - end - else - coeff = a - dict = b - s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=SmallV{Any}(), kw..., id = Ref{TID}(DID)) - BasicSymbolic(s) +function Mul{T}(coeff, dict; kw...) where {T} + coeff = convert(T, maybe_integer(unwrap(coeff))) + dict = unwrap_dict(dict) + isempty(dict) && return coeff + if _isone(coeff) && length(dict) == 1 + k, v = first(dict) + _isone(v) && return k + return k ^ v end + + variant = AddMulVariant.MUL + BSImpl.AddOrMul{T}(variant, coeff, dict; kw...) +end + +""" + $(TYPEDSIGNATURES) + +Create a generic division term. Does not assume anything about the division algebra beyond +the ability to check for zero and one elements (via [`_iszero`](@ref) and [`_isone`](@ref)). + +If the numerator is zero or denominator is one, the numerator is returned. +""" +function Div{T}(n, d, simplified; kw...) where {T} + n = unwrap(n) + d = unwrap(d) + # TODO: This used to return `zero(typeof(n))`, maybe there was a reason? + _iszero(n) && return n + _isone(d) && return n + return BSImpl.Div{T}(n, d, simplified; kw...) end const Rat = Union{Rational, Integer} +""" + $(TYPEDSIGNATURES) + +Return a tuple containing a boolean indicating whether `x` has a rational/integer factor +and the rational/integer factor (or `NaN` otherwise). +""" function ratcoeff(x) if ismul(x) ratcoeff(x.coeff) @@ -636,164 +874,175 @@ function ratcoeff(x) (false, NaN) end end -ratio(x::Integer,y::Integer) = iszero(rem(x,y)) ? div(x,y) : x//y -ratio(x::Rat,y::Rat) = x//y -function maybe_intcoeff(x) - if ismul(x) - if x.coeff isa Rational && isone(x.coeff.den) - Mul{symtype(x)}(; coeff=x.coeff.num, dict=x.dict, x.metadata, arguments=SmallV{Any}()) - else - x - end - elseif x isa Rational - isone(x.den) ? x.num : x - else - x - end + +""" + $(TYPEDSIGNATURES) + +Simplify the coefficients of `n` and `d` (numerator and denominator). +""" +function simplify_coefficients(n, d) + nrat, nc = ratcoeff(n) + drat, dc = ratcoeff(d) + nrat && drat || return n, d + g = gcd(nc, dc) * sign(dc) # make denominator positive + invdc = isone(g) ? g : (1 // g) + n = maybe_integer(invdc * n) + d = maybe_integer(invdc * d) + + return n, d end -function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T} +""" + $(TYPEDSIGNATURES) + +Create a division term specifically for the real or complex algebra. Performs additional +simplification and cancellation. +""" +function Div{T}(n, d, simplified; kw...) where {T <: Number} n = unwrap(n) d = unwrap(d) - if T<:Number && !(T<:SafeReal) + + if !(T == SafeReal) n, d = quick_cancel(n, d) end + _iszero(n) && return zero(typeof(n)) _isone(d) && return n if isdiv(n) && isdiv(d) - return Div{T}(n.num * d.den, n.den * d.num) + return Div{T}(n.num * d.den, n.den * d.num, simplified; kw...) elseif isdiv(n) - return Div{T}(n.num, n.den * d) + return Div{T}(n.num, n.den * d, simplified; kw...) elseif isdiv(d) - return Div{T}(n * d.den, d.num) - end - - d isa Number && _isone(-d) && return -1 * n - n isa Rat && d isa Rat && return n // d # maybe called by oblivious code in simplify - - # GCD coefficient upon construction - rat, nc = ratcoeff(n) - if rat - rat, dc = ratcoeff(d) - if rat - g = gcd(nc, dc) * sign(dc) # make denominator positive - invdc = ratio(1, g) - n = maybe_intcoeff(invdc * n) - d = maybe_intcoeff(invdc * d) - if d isa Number - _isone(d) && return n - _isone(-d) && return -1 * n - end - end + return Div{T}(n * d.den, d.num, simplified; kw...) end - s = Div{T}(; num=n, den=d, simplified, arguments=SmallV{Any}(), metadata, id = Ref{TID}(DID)) - BasicSymbolic(s) + d isa Number && _isone(-d) && return -n + n isa Rat && d isa Rat && return n // d + + n, d = simplify_coefficients(n, d) + + _isone(d) && return n + _isone(-d) && return -n + + BSImpl.Div{T}(n, d, simplified; kw...) end -function Div(n,d, simplified=false; kw...) +function Div(n, d, simplified; kw...) Div{promote_symtype((/), symtype(n), symtype(d))}(n, d, simplified; kw...) end +""" + $(TYPEDSIGNATURES) + +Return the numerator of expression `x` as an array of multiplied terms. +""" @inline function numerators(x) isdiv(x) && return numerators(x.num) - iscall(x) && operation(x) === (*) ? arguments(x) : Any[x] + iscall(x) && operation(x) === (*) && return arguments(x) + return SmallV{Any}((x,)) end -@inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1] +""" + $(TYPEDSIGNATURES) -function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T} - a = unwrap(a) - b = unwrap(b) - _iszero(b) && return 1 - _isone(b) && return a - s = Pow{T}(; base=a, exp=b, arguments=SmallV{Any}(), metadata, id = Ref{TID}(DID)) - BasicSymbolic(s) +Return the denominator of expression `x` as an array of multiplied terms. +""" +@inline denominators(x) = isdiv(x) ? numerators(x.den) : SmallV{Any}((1,)) + +function Pow{T}(base, exp; kw...) where {T} + base = unwrap(base) + exp = unwrap(exp) + # TODO: Returning 1 isn't valid for matrix algebra + # This should use a `_one` function + _iszero(exp) && return 1 + _isone(exp) && return base + return BSImpl.Pow{T}(base, exp; kw...) end -function Pow(a, b; metadata = NO_METADATA, kwargs...) - Pow{promote_symtype(^, symtype(a), symtype(b))}(makepow(a, b)...; metadata, kwargs...) +function Pow(a, b; kw...) + Pow{promote_symtype(^, symtype(a), symtype(b))}(makepow(a, b)...; kw...) end -function toterm(t::BasicSymbolic{T}) where T - E = exprtype(t) - if E === SYM || E === TERM - return t - elseif E === ADD || E === MUL - args = Any[] - push!(args, t.coeff) - for (k, coeff) in t.dict - push!(args, coeff == 1 ? k : Term{T}(E === MUL ? (^) : (*), SmallV{Any}((coeff, k)))) - end - Term{T}(operation(t), args) - elseif E === DIV - Term{T}(/, SmallV{Any}((t.num, t.den))) - elseif E === POW - Term{T}(^, SmallV{Any}((t.base, t.exp))) - else - error_on_type() +function _mergedict!(dict::AbstractDict, other::AbstractDict) + for (k, v) in other + vv = get(dict, k, 0) + dict[k] = v + vv end end +function unwrap_const(x) + x + # isconst(x) ? x.val : x +end + """ - makeadd(sign, coeff::Number, xs...) + $(TYPEDSIGNATURES) -Any Muls inside an Add should always have a coeff of 1 -and the key (in Add) should instead be used to store the actual coefficient +Return the `coeff` and `dict` for adding `xs...` into a symbolic of symtype `T`. """ -function makeadd(sign, coeff, xs...) - d = sdict() +function makeadd(::Type{T}, xs...)::Tuple{T, Dict{Symbolic, T}} where {T} + dict = Dict{Symbolic, T}() + coeff = zero(T) + for x in xs - if isadd(x) - coeff += x.coeff - _merge!(+, d, x.dict, filter=_iszero) - continue - end + x = unwrap_const(unwrap(x)) if x isa Number - coeff += x + coeff += convert(T, x) + continue + elseif isadd(x) + coeff += x.coeff + _mergedict!(dict, x.dict) continue end if ismul(x) - k = Mul(symtype(x), 1, x.dict) - v = sign * x.coeff + get(d, k, 0) + v = x.coeff + k = Mul{T}(1, x.dict; metadata = metadata(x)) else k = x - v = sign + get(d, x, 0) - end - if iszero(v) - delete!(d, k) - else - d[k] = v + v = 1 end + dict[k] = get(dict, k, zero(T)) + v end - coeff, d + filter!(!iszero ∘ last, dict) + return coeff, dict end -function makemul(coeff, xs...; d=sdict()) +""" + $(TYPEDSIGNATURES) + +Return the `coeff` and `dict` for multiplying `xs...` into a symbolic of symtype `T`. +""" +function makemul(::Type{T}, xs...) where {T} + dict = Dict{Symbolic, T}() + coeff = one(T) for x in xs - if ispow(x) && x.exp isa Number - d[x.base] = x.exp + get(d, x.base, 0) + x = unwrap_const(unwrap(x)) + if ispow(x) && x.exp isa T + # if ispow(x) && isconst(x.exp) + dict[x.base] = x.exp + get(dict, x.base, 0) elseif x isa Number - coeff *= x + coeff *= convert(T, x) elseif ismul(x) coeff *= x.coeff - _merge!(+, d, x.dict, filter=_iszero) + _mergedict!(dict, x.dict) else - v = 1 + get(d, x, 0) - if _iszero(v) - delete!(d, x) - else - d[x] = v - end + dict[x] = get(dict, x, 0) + 1 end end - (coeff, d) + + filter!(!iszero ∘ last, dict) + return (coeff, dict) end -unstable_pow(a, b) = a isa Integer && b isa Integer ? (a//1) ^ b : a ^ b +""" + $(TYPEDSIGNATURES) +Return the base and exponent for representing `a^b`. +""" function makepow(a, b) + a = unwrap(a) + b = unwrap(b) base = a exp = b if ispow(a) @@ -803,6 +1052,9 @@ function makepow(a, b) return (base, exp) end +""" + $(TYPEDSIGNATURES) +""" function term(f, args...; type = nothing) args = SmallV{Any}(args) if type === nothing @@ -813,24 +1065,8 @@ function term(f, args...; type = nothing) Term{T}(f, args) end -""" - unflatten(t::Symbolic{T}) -Binarizes `Term`s with n-ary operations -""" -function unflatten(t::Symbolic{T}) where{T} - if iscall(t) - f = operation(t) - if f == (+) || f == (*) # TODO check out for other n-ary --> binary ops - a = arguments(t) - return foldl((x,y) -> Term{T}(f, SmallV{Any}((x, y))), a) - end - end - return t -end - -unflatten(t) = t - function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata) + args = unwrap_args(args) st = symtype(T) pst = _promote_symtype(head, args) # Use promoted symtype only if not a subtype of the existing symtype of T. @@ -854,35 +1090,36 @@ function basicsymbolic(f, args, stype, metadata) if f isa Symbol error("$f must not be a Symbol") end + args = unwrap_args(args) T = stype if T === nothing T = _promote_symtype(f, args) end - if T <: LiteralReal + if T == LiteralReal @goto FALLBACK elseif all(x->symtype(x) <: Number, args) if f === (+) - res = +(args...) - if isadd(res) || (isterm(res) && operation(res) == (+)) + res = add_worker(args) + if metadata !== nothing && (isadd(res) || (isterm(res) && operation(res) == (+))) @set! res.metadata = metadata end res elseif f == (*) - res = *(args...) - if ismul(res) || (isterm(res) && operation(res) == (*)) + res = mul_worker(args) + if metadata !== nothing && (ismul(res) || (isterm(res) && operation(res) == (*))) @set! res.metadata = metadata end res elseif f == (/) @assert length(args) == 2 res = args[1] / args[2] - if isdiv(res) + if metadata !== nothing && isdiv(res) @set! res.metadata = metadata end res elseif f == (^) && length(args) == 2 res = args[1] ^ args[2] - if ispow(res) + if metadata !== nothing && ispow(res) @set! res.metadata = metadata end res @@ -898,6 +1135,8 @@ end ### ### Metadata ### +metadata(s::BSImpl.Type) = s.metadata +# metadata(s::HashconsingWrapper) = throw(MethodError(metadata, (s,))) metadata(s::Symbolic) = s.metadata metadata(s::Symbolic, meta) = Setfield.@set! s.metadata = meta @@ -907,7 +1146,7 @@ end issafecanon(f, s) = true function issafecanon(f, s::Symbolic) - if isnothing(metadata(s)) || issym(s) + if metadata(s) === nothing || isempty(metadata(s)) || issym(s) return true else _issafecanon(f, s) @@ -965,11 +1204,6 @@ function setmetadata(s::Symbolic, ctx::DataType, val) end end - -function to_symbolic(x) - x -end - ### ### Pretty printing ### @@ -1123,7 +1357,7 @@ function show_term(io::IO, t) f = operation(t) args = sorted_arguments(t) - if symtype(t) <: LiteralReal + if symtype(t) == LiteralReal show_call(io, f, args) elseif f === (+) show_add(io, args) @@ -1145,9 +1379,11 @@ end showraw(io, t) = Base.show(IOContext(io, :simplify=>false), t) showraw(t) = showraw(stdout, t) -function Base.show(io::IO, v::BasicSymbolic) +function Base.show(io::IO, v::BSImpl.Type) if issym(v) Base.show_unquoted(io, v.name) + elseif isconst(v) + printstyled(io, v.val; color = :blue) else show_term(io, v) end @@ -1363,15 +1599,46 @@ sub_t(a) = promote_symtype(-, symtype(a)) import Base: (+), (-), (*), (//), (/), (\), (^) +function safe_add!(dict, coeff, b) + if isadd(b) + coeff += b.coeff + for (k, v) in b.dict + dict[k] = get(dict, k, 0) + v + end + elseif ismul(b) + v = b.coeff + metadata = b.metadata + if metadata === nothing + b′ = Mul{symtype(b)}(1, b.dict) + else + b′ = Mul{symtype(b)}(1, b.dict; metadata) + end + dict[b′] = get(dict, b′, 0) + v + elseif b isa Number + coeff += b + else + dict[b] = get(dict, b, 0) + 1 + end + return coeff +end + function +(a::SN, bs::SN...) + add_worker((a, bs...)) +end + +function add_worker(terms) + a, bs = Iterators.peel(terms) isempty(bs) && return a + T = symtype(a) + for b in bs + T = promote_symtype(+, T, symtype(b)) + end # entries where `!issafecanon` unsafes = SmallV{Any}() # coeff and dict of the `Add` coeff = 0 - dict = sdict() + dict = Dict{Symbolic, T}() # type of the `Add` - T = symtype(a) # handle `a` separately if issafecanon(+, a) @@ -1380,8 +1647,10 @@ function +(a::SN, bs::SN...) dict = copy(a.dict) elseif ismul(a) v = a.coeff - a′ = Mul(symtype(a), 1, copy(a.dict); metadata = a.metadata) + a′ = Mul{symtype(a)}(1, a.dict; metadata = a.metadata) dict[a′] = v + elseif a isa Number + coeff = a else dict[a] = 1 end @@ -1390,30 +1659,24 @@ function +(a::SN, bs::SN...) end for b in bs - T = promote_symtype(+, T, symtype(b)) if !issafecanon(+, b) push!(unsafes, b) continue end - if isadd(b) - coeff += b.coeff - for (k, v) in b.dict - dict[k] = get(dict, k, 0) + v - end - elseif ismul(b) - v = b.coeff - b′ = Mul(symtype(b), 1, copy(b.dict); metadata = b.metadata) - dict[b′] = get(dict, b′, 0) + v - else - dict[b] = get(dict, b, 0) + 1 - end + coeff = safe_add!(dict, coeff, b) end # remove entries multiplied by zero filter!(dict) do kvp !iszero(kvp[2]) end - - result = isempty(dict) ? coeff : Add(T, coeff, dict) + if isempty(dict) + result = coeff + elseif iszero(coeff) && length(dict) == 1 + expr, coeff = first(dict) + result = coeff * expr + else + result = Add{T}(coeff, dict) + end if !isempty(unsafes) push!(unsafes, result) result = Term{T}(+, unsafes) @@ -1425,10 +1688,11 @@ function +(a::Number, b::SN, bs::SN...) b = +(b, bs...) issafecanon(+, b) || return term(+, a, b) iszero(a) && return b + T = add_t(a, b) if isadd(b) - Add(add_t(a, b), a + b.coeff, b.dict) + Add{T}(a + b.coeff, b.dict) else - Add(add_t(a, b), makeadd(1, a, b)...) + Add{T}(makeadd(T, a, b)...) end end @@ -1438,13 +1702,12 @@ end function -(a::SN) !issafecanon(*, a) && return term(-, a) - isadd(a) ? Add(sub_t(a), -a.coeff, mapvalues((_,v) -> -v, a.dict)) : - Add(sub_t(a), makeadd(-1, 0, a)...) + isadd(a) ? Add{sub_t(a)}(-a.coeff, mapvalues((_,v) -> -v, a.dict)) : (-1 * a) end function -(a::SN, b::SN) (!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b) - isadd(a) && isadd(b) ? Add(sub_t(a,b), + isadd(a) && isadd(b) ? Add{sub_t(a,b)}( a.coeff - b.coeff, _merge(-, a.dict, b.dict, @@ -1460,31 +1723,145 @@ mul_t(a) = promote_symtype(*, symtype(a)) *(a::SN) = a +# should not be called with a `div` +function get_mul_coeff_dict(::Type{T}, term; safe = false) where {T} + @match term begin + x::Number => (x, Dict{Symbolic, T}()) + BSImpl.AddOrMul(; variant = AddMulVariant.MUL, coeff, dict) => (coeff, copy(dict)) + BSImpl.Pow(; base, exp) && if exp isa Number end => (1, Dict{Symbolic, T}(base => exp)) + _ => (1, Dict{Symbolic, T}(term => 1)) + end +end + +function mul_worker(terms) + length(terms) == 1 && return only(terms) + a, bs = Iterators.peel(terms) + a = unwrap(a) + T = symtype(a) + for b in bs + T = promote_symtype(*, T, symtype(b)) + end + if isdiv(a) + num_coeff, num_dict = get_mul_coeff_dict(T, a.num) + den_coeff, den_dict = get_mul_coeff_dict(T, a.den) + else + num_coeff, num_dict = get_mul_coeff_dict(T, a) + den_coeff = 1 + den_dict = nothing + end + unsafes = SmallV{Any}() + + for b in bs + b = unwrap(b) + if !issafecanon(*, b) + push!(unsafes, b) + continue + end + @match b begin + x::Number => (num_coeff *= x) + BSImpl.AddOrMul(; variant = AddMulVariant.MUL, coeff, dict) => begin + num_coeff *= coeff + for (k, v) in dict + num_dict[k] = get(num_dict, k, 0) + v + end + end + BSImpl.Pow(; base, exp) && if exp isa Number end => begin + num_dict[base] = get(num_dict, base, 0) + exp + end + BSImpl.Div(; num, den) => begin + if den_dict === nothing && !(den isa Number) + den_dict = Dict{Symbolic, T}() + end + @match num begin + x::Number => (num_coeff *= x) + BSImpl.AddOrMul(; variant = AddMullVariant.MUL, coeff, dict) => begin + num_coeff *= coeff + for (k, v) in dict + num_dict[k] = get(num_dict, k, 0) + v + end + end + BSImpl.Pow(; base, exp) && if exp isa number end => begin + num_dict[base] = get(num_dict, base, 0) + exp + end + _ => (num_dict[num] = get(num_dict, num, 0) + 1) + end + @match den begin + x::Number => (den_coeff *= x) + BSImpl.AddOrMul(; variant = AddMulVariant.MUL, coeff, dict) => begin + den_coeff *= coeff + for (k, v) in dict + den_dict[k] = get(den_dict, k, 0) + v + end + end + BSImpl.Pow(; base, exp) && if exp isa Number end => begin + den_dict[base] = get(den_dict, base, 0) + exp + end + _ => (den_dict[den] = get(den_dict, k, 0)) + end + end + _ => (num_dict[b] = get(num_dict, b, 0) + 1) + end + end + + if iszero(num_coeff) + return num_coeff + end + filter!(kvp -> !iszero(kvp[2]), num_dict) + if isempty(num_dict) + num = num_coeff + elseif isone(num_coeff) && length(num_dict) == 1 + base, exp = first(num_dict) + num = Pow{T}(base, exp) + else + num = Mul{T}(num_coeff, num_dict) + end + + if !isempty(unsafes) + push!(unsafes, num) + num = Term{T}(*, unsafes) + end + + if den_dict !== nothing + filter!(kvp -> !iszero(kvp[2]), den_dict) + end + + if den_dict === nothing || isempty(den_dict) + den = den_coeff + elseif isone(den_coeff) && length(den_dict) == 1 + base, exp = first(den_dict) + den = Pow{T}(base, exp) + else + den = Mul{T}(den_coeff, den_dict) + end + + return Div{T}(num, den, false) +end + function *(a::SN, b::SN) # Always make sure Div wraps Mul !issafecanon(*, a, b) && return term(*, a, b) if isdiv(a) && isdiv(b) - Div(a.num * b.num, a.den * b.den) + Div(a.num * b.num, a.den * b.den, false) elseif isdiv(a) - Div(a.num * b, a.den) + Div(a.num * b, a.den, false) elseif isdiv(b) - Div(a * b.num, b.den) + Div(a * b.num, b.den, false) elseif ismul(a) && ismul(b) - Mul(mul_t(a, b), + Mul{mul_t(a, b)}( a.coeff * b.coeff, _merge(+, a.dict, b.dict, filter=_iszero)) elseif ismul(a) && ispow(b) if b.exp isa Number - Mul(mul_t(a, b), + Mul{mul_t(a, b)}( a.coeff, _merge(+, a.dict, Base.ImmutableDict(b.base=>b.exp), filter=_iszero)) else - Mul(mul_t(a, b), + Mul{mul_t(a, b)}( a.coeff, _merge(+, a.dict, Base.ImmutableDict(b=>1), filter=_iszero)) end elseif ispow(a) && ismul(b) b * a else - Mul(mul_t(a,b), makemul(1, a, b)...) + Mul{mul_t(a,b)}(makemul(mul_t(a, b), a, b)...) end end @@ -1499,13 +1876,13 @@ function *(a::Number, b::SN) elseif isone(a) b elseif isdiv(b) - Div(a*b.num, b.den) + Div(a*b.num, b.den, false) elseif isone(-a) && isadd(b) # -1(a+b) -> -a - b T = promote_symtype(+, typeof(a), symtype(b)) - Add(T, b.coeff * a, Dict{Any,Any}(k=>v*a for (k, v) in b.dict)) + Add{T}(b.coeff * a, Dict{Any,Any}(k=>v*a for (k, v) in b.dict)) else - Mul(mul_t(a, b), makemul(a, b)...) + Mul{mul_t(a, b)}(makemul(mul_t(a, b), a, b)...) end end @@ -1513,7 +1890,7 @@ end ### Div ### -/(a::Union{SN,Number}, b::SN) = Div(a, b) +/(a::Union{SN,Number}, b::SN) = Div(a, b, false) *(a::SN, b::Number) = b * a @@ -1539,10 +1916,10 @@ function ^(a::SN, b) # fast path 1 elseif b isa Real && b < 0 - Div(1, a ^ (-b)) + Div(1, a ^ (-b), false) elseif ismul(a) && b isa Number - coeff = unstable_pow(a.coeff, b) - Mul(promote_symtype(^, symtype(a), symtype(b)), + coeff = ^(a.coeff, b) + Mul{promote_symtype(^, symtype(a), symtype(b))}( coeff, mapvalues((k, v) -> b*v, a.dict)) else Pow(a, b) diff --git a/src/utils.jl b/src/utils.jl index 8dcf23b8f..a7a13423d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,30 +1,3 @@ -const TIMER_OUTPUTS = true -const being_timed = Ref{Bool}(false) - -if TIMER_OUTPUTS - using TimerOutputs - - macro timer(name, expr) - :(if being_timed[] - @timeit $(esc(name)) $(esc(expr)) - else - $(esc(expr)) - end) - end - - macro iftimer(expr) - esc(expr) - end - -else - macro timer(name, expr) - esc(expr) - end - - macro iftimer(expr) - end -end - using Base: ImmutableDict @@ -42,13 +15,13 @@ function has_trig_exp(term) if Base.@nany 9 i->fns[i] === op return true else - return any(has_trig_exp, arguments(term)) + return any(has_trig_exp, parent(arguments(term))) end end function fold(t) if iscall(t) - tt = map(fold, arguments(t)) + tt = map(fold, parent(arguments(t))) if !any(x->x isa Symbolic, tt) # evaluate it return operation(t)(tt...) @@ -68,8 +41,18 @@ isliteral(::Type{T}) where {T} = x -> x isa T is_literal_number(x) = isliteral(Number)(x) # checking the type directly is faster than dynamic dispatch in type unstable code -_iszero(x) = x isa Number && iszero(x) -_isone(x) = x isa Number && isone(x) +function _iszero(x) + x = unwrap(x) + x isa Number && return iszero(x) + x isa Array && return iszero(x) + return false +end +function _isone(x) + x = unwrap(x) + x isa Number && return isone(x) + x isa Array && return isone(x) + return false +end _isinteger(x) = (x isa Number && isinteger(x)) || (x isa Symbolic && symtype(x) <: Integer) _isreal(x) = (x isa Number && isreal(x)) || (x isa Symbolic && symtype(x) <: Real) diff --git a/test/basics.jl b/test/basics.jl index 5d097aa14..b37073a98 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -1,4 +1,4 @@ -using SymbolicUtils: Symbolic, Sym, FnType, Term, Add, Mul, Pow, symtype, operation, arguments, issym, isterm, BasicSymbolic, term, isequal_with_metadata +using SymbolicUtils: Symbolic, Sym, FnType, Term, Add, Mul, Pow, symtype, operation, arguments, issym, isterm, BasicSymbolic, term using SymbolicUtils using ConstructionBase: setproperties using Setfield @@ -117,11 +117,11 @@ end @testset "Base methods" begin @syms w::Complex z::Complex a::Real b::Real x - @test isequal(w + z, Add(Complex, 0, Dict(w=>1, z=>1))) - @test isequal(z + a, Add(Number, 0, Dict(z=>1, a=>1))) - @test isequal(a + b, Add(Real, 0, Dict(a=>1, b=>1))) - @test isequal(a + x, Add(Number, 0, Dict(a=>1, x=>1))) - @test isequal(a + z, Add(Number, 0, Dict(a=>1, z=>1))) + @test isequal(w + z, Add{Complex}(0, Dict(w=>1, z=>1))) + @test isequal(z + a, Add{Number}(0, Dict(z=>1, a=>1))) + @test isequal(a + b, Add{Real}(0, Dict(a=>1, b=>1))) + @test isequal(a + x, Add{Number}(0, Dict(a=>1, x=>1))) + @test isequal(a + z, Add{Number}(0, Dict(a=>1, z=>1))) foo(w, z, a, b) = 1.0 SymbolicUtils.promote_symtype(::typeof(foo), args...) = Real @@ -226,7 +226,7 @@ end # test that the "x^2 + y^-1 + sin(a)^3.5 + 2t + 1//1" expression from Symbolics.jl/build_targets.jl is properly sorted @syms x1 y1 a1 t1 - @test repr(x1^2 + y1^-1 + sin(a1)^3.5 + 2t1 + 1//1) == "(1//1) + 2t1 + 1 / y1 + x1^2 + sin(a1)^3.5" + @test repr(x1^2 + y1^-1 + sin(a1)^3.5 + 2t1 + 1//1) == "1 + 2t1 + 1 / y1 + x1^2 + sin(a1)^3.5" end @testset "inspect" begin @@ -303,8 +303,8 @@ toterm(t) = Term{symtype(t)}(operation(t), arguments(t)) @testset "diffs" begin @syms a b c @test isequal(toterm(-1c), Term{Number}(*, [-1, c])) - @test isequal(toterm(-1(a+b)), Term{Number}(+, [-1a, -b])) - @test isequal(toterm((a + b) - (b + c)), Term{Number}(+, [a, -1c])) + @test isequal(toterm(-1(a+b)), Term{Number}(+, [-b, -a])) + @test isequal(toterm((a + b) - (b + c)), Term{Number}(+, [a, -c])) end @testset "hash" begin @@ -346,9 +346,11 @@ end a1 = setmetadata(a, Ctx1, "meta_1") a2 = setmetadata(a, Ctx1, "meta_1") a3 = setmetadata(a, Ctx2, "meta_2") - @test !isequal_with_metadata(a, a1) - @test isequal_with_metadata(a1, a2) - @test !isequal_with_metadata(a1, a3) + SymbolicUtils.@manually_scope SymbolicUtils.COMPARE_FULL => true begin + @test !isequal(a, a1) + @test isequal(a1, a2) + @test !isequal(a1, a3) + end end @testset "subtyping" begin @@ -437,8 +439,12 @@ end x = setmetadata(x(t), Int, 3) ex = x * y res = substitute(ex, Dict(y => 1)) - @test SymbolicUtils.isequal_with_metadata(res, x) + SymbolicUtils.@manually_scope SymbolicUtils.COMPARE_FULL => true begin + @test isequal(res, x) + end ex = x + y res = substitute(ex, Dict(y => 0)) - @test SymbolicUtils.isequal_with_metadata(res, x) + SymbolicUtils.@manually_scope SymbolicUtils.COMPARE_FULL => true begin + @test isequal(res, x) + end end diff --git a/test/code.jl b/test/code.jl index f30700660..24f055d32 100644 --- a/test/code.jl +++ b/test/code.jl @@ -20,11 +20,11 @@ nanmath_st.rewrites[:nanmath] = true @test toexpr(a*b*c*d*e) == :($(*)($(*)($(*)($(*)(a, b), c), d), e)) @test toexpr(a+b+c+d+e) == :($(+)($(+)($(+)($(+)(a, b), c), d), e)) @test toexpr(a+b) == :($(+)(a, b)) - @test toexpr(x(t)+y(t)) == :($(+)(x(t), y(t))) - @test toexpr(x(t)+y(t)+x(t+1)) == :($(+)($(+)(x(t), y(t)), x($(+)(1, t)))) + @test toexpr(x(t)+y(t)) == :($(+)(y(t), x(t))) + @test toexpr(x(t)+y(t)+x(t+1)) == :($(+)($(+)(x($(+)(1, t)), y(t)), x(t))) s = LazyState() Code.union_rewrites!(s.rewrites, [x(t), y(t)]) - @test toexpr(x(t)+y(t)+x(t+1), s) == :($(+)($(+)(var"x(t)", var"y(t)"), x($(+)(1, t)))) + @test toexpr(x(t)+y(t)+x(t+1), s) == :($(+)($(+)(x($(+)(1, t)), var"y(t)"), var"x(t)")) ex = :(let a = 3, b = $(+)(1,a) $(+)(a, b) @@ -38,7 +38,7 @@ nanmath_st.rewrites[:nanmath] = true test_repr(toexpr(Func([x(t), x],[b ← a+2, y(t) ← b], x(t)+x(t+1)+b+y(t))), :(function (var"x(t)", x; b = $(+)(2, a), var"y(t)" = b) - $(+)($(+)($(+)(b, var"x(t)"), var"y(t)"), x($(+)(1, t))) + $(+)($(+)($(+)(b, x($(+)(1, t))), var"y(t)"), var"x(t)") end)) test_repr(toexpr(Func([DestructuredArgs([x, x(t)], :state), DestructuredArgs((a, b), :params)], [], @@ -49,7 +49,7 @@ nanmath_st.rewrites[:nanmath] = true var"x(t)" = state[2] a = params[1] b = params[2] - $(+)($(+)($(+)(a, b), var"x(t)"), x($(+)(1, t))) + $(+)($(+)($(+)(a, b), x($(+)(1, t))), var"x(t)") end end)) @@ -58,7 +58,7 @@ nanmath_st.rewrites[:nanmath] = true x(t+1) + x(t) + a + b)), :(function (state, params) begin - $(+)($(+)($(+)(params[1], params[2]), state[2]), state[1]($(+)(1, t))) + $(+)($(+)($(+)(params[1], params[2]), state[1]($(+)(1, t))), state[2]) end end)) diff --git a/test/cse.jl b/test/cse.jl index d450cae8b..6bfd9914e 100644 --- a/test/cse.jl +++ b/test/cse.jl @@ -19,28 +19,28 @@ end expr = sin(a + b) * (a + b) sorted_nodes = topological_sort(expr) @test length(sorted_nodes) == 3 - @test isequal(sorted_nodes[1].rhs, term(+, a, b)) + @test isequal(sorted_nodes[1].rhs, term(+, b, a)) @test isequal(sin(sorted_nodes[1].lhs), sorted_nodes[2].rhs) expr = (a + b)^(a + b) sorted_nodes = topological_sort(expr) @test length(sorted_nodes) == 2 - @test isequal(sorted_nodes[1].rhs, term(+, a, b)) + @test isequal(sorted_nodes[1].rhs, term(+, b, a)) ab_node = sorted_nodes[1].lhs @test isequal(term(^, ab_node, ab_node), sorted_nodes[2].rhs) let_expr = cse(expr) @test length(let_expr.pairs) == 2 - @test isequal(let_expr.pairs[1].rhs, term(+, a, b)) + @test isequal(let_expr.pairs[1].rhs, term(+, b, a)) corresponding_sym = let_expr.pairs[1].lhs @test isequal(let_expr.pairs[end].rhs, term(^, corresponding_sym, corresponding_sym)) expr = a + b sorted_nodes = topological_sort(expr) @test length(sorted_nodes) == 1 - @test isequal(sorted_nodes[1].rhs, term(+, a, b)) + @test isequal(sorted_nodes[1].rhs, term(+, b, a)) let_expr = cse(expr) @test length(let_expr.pairs) == 1 - @test isequal(let_expr.pairs[end].rhs, term(+, a, b)) + @test isequal(let_expr.pairs[end].rhs, term(+, b, a)) expr = a sorted_nodes = topological_sort(expr) diff --git a/test/hash_consing.jl b/test/hash_consing.jl index 919f19b91..80d154d32 100644 --- a/test/hash_consing.jl +++ b/test/hash_consing.jl @@ -1,31 +1,33 @@ using SymbolicUtils, Test -using SymbolicUtils: Term, Add, Mul, Div, Pow, hash2, metadata, BasicSymbolic, Symbolic, - isequal_with_metadata +using SymbolicUtils: Term, Add, Mul, Div, Pow, metadata, BasicSymbolic, Symbolic import TermInterface +hash2(a) = SymbolicUtils.@manually_scope SymbolicUtils.COMPARE_FULL => true hash(a) +isequal2(a, b) = SymbolicUtils.@manually_scope SymbolicUtils.COMPARE_FULL => true isequal(a, b) + struct Ctx1 end struct Ctx2 end @testset "Sym" begin x1 = only(@syms x) x2 = only(@syms x) - @test x1 === x2 + @test x1.id === x2.id x3 = only(@syms x::Float64) - @test x1 !== x3 + @test x1.id !== x3.id x4 = only(@syms x::Float64) - @test x1 !== x4 - @test x3 === x4 + @test x1.id !== x4.id + @test x3.id === x4.id x5 = only(@syms x::Int) x6 = only(@syms x::Int) - @test x1 !== x5 - @test x3 !== x5 - @test x5 === x6 + @test x1.id !== x5.id + @test x3.id !== x5.id + @test x5.id === x6.id xm1 = setmetadata(x1, Ctx1, "meta_1") xm2 = setmetadata(x1, Ctx1, "meta_1") - @test xm1 === xm2 + @test xm1.id === xm2.id xm3 = setmetadata(x1, Ctx2, "meta_2") - @test xm1 !== xm3 + @test xm1.id !== xm3.id end @syms a b c @@ -33,82 +35,73 @@ end @testset "Term" begin t1 = sin(a) t2 = sin(a) - @test t1 === t2 + @test t1.id === t2.id t3 = Term(identity,[a]) t4 = Term(identity,[a]) - @test t3 === t4 + @test t3.id === t4.id t5 = Term{Int}(identity,[a]) - @test t3 !== t5 + @test t3.id !== t5.id tm1 = setmetadata(t1, Ctx1, "meta_1") - @test t1 !== tm1 + @test t1.id !== tm1.id end @testset "Add" begin d1 = a + b d2 = b + a - @test d1 === d2 + @test d1.id === d2.id d3 = b - 2 + a d4 = a + b - 2 - @test d3 === d4 - d5 = Add(Int, 0, Dict(a => 1, b => 1)) - @test d5 !== d1 + @test d3.id === d4.id + d5 = Add{Int}(0, Dict(a => 1, b => 1)) + @test d5.id !== d1.id dm1 = setmetadata(d1,Ctx1,"meta_1") - @test d1 !== dm1 + @test d1.id !== dm1.id end @testset "Mul" begin m1 = a*b m2 = b*a - @test m1 === m2 + @test m1.id === m2.id m3 = 6*a*b m4 = 3*a*2*b - @test m3 === m4 - m5 = Mul(Int, 1, Dict(a => 1, b => 1)) - @test m5 !== m1 + @test m3.id === m4.id + m5 = Mul{Int}(1, Dict(a => 1, b => 1)) + @test m5.id !== m1.id mm1 = setmetadata(m1, Ctx1, "meta_1") - @test m1 !== mm1 + @test m1.id !== mm1.id end @testset "Div" begin v1 = a/b v2 = a/b - @test v1 === v2 + @test v1.id === v2.id v3 = -1/a v4 = -1/a - @test v3 === v4 + @test v3.id === v4.id v5 = 3a/6 v6 = 2a/4 - @test v5 === v6 - v7 = Div{Float64}(-1,a) - @test v7 !== v3 + @test v5.id === v6.id + v7 = Div{Float64}(-1,a, false) + @test v7.id !== v3.id vm1 = setmetadata(v1,Ctx1, "meta_1") - @test vm1 !== v1 + @test vm1.id !== v1.id end @testset "Pow" begin p1 = a^b p2 = a^b - @test p1 === p2 + @test p1.id === p2.id p3 = a^(2^-b) p4 = a^(2^-b) - @test p3 === p4 + @test p3.id === p4.id p5 = Pow{Float64}(a,b) - @test p1 !== p5 + @test p1.id !== p5.id pm1 = setmetadata(p1,Ctx1, "meta_1") - @test pm1 !== p1 -end - -@testset "Equivalent numbers" begin - f = 0.5 - r = 1 // 2 - @test hash(f) == hash(r) - u0 = zero(UInt) - @test hash2(f, u0) != hash2(r, u0) - @test f + a !== r + a + @test pm1.id !== p1.id end @testset "Symbolics in metadata" begin @@ -116,8 +109,8 @@ end a1 = setmetadata(a, Int, b) b1 = setmetadata(b, Int, 3) a2 = setmetadata(a, Int, b1) - @test a1 !== a2 - @test !SymbolicUtils.isequal_with_metadata(a1, a2) + @test a1.id !== a2.id + @test !isequal2(a1, a2) @test metadata(metadata(a1)[Int]) === nothing @test metadata(metadata(a2)[Int])[Int] == 3 end @@ -125,10 +118,10 @@ end @testset "Compare metadata of expression tree" begin @syms a b aa = setmetadata(a, Int, b) - @test aa !== a + @test aa.id !== a.id @test isequal(a, aa) - @test !SymbolicUtils.isequal_with_metadata(a, aa) - @test !SymbolicUtils.isequal_with_metadata(2a, 2aa) + @test !isequal2(a, aa) + @test !isequal2(2a, 2aa) end @testset "Hashconsing can be toggled" begin @@ -136,14 +129,14 @@ end @syms a b x1 = a + b x2 = a + b - @test x1 !== x2 + @test x1.id === nothing === x2.id SymbolicUtils.ENABLE_HASHCONSING[] = true end @testset "`hash2` is cached" begin @syms a b f(..) for ex in [a + b, a * b, f(a)] - h = SymbolicUtils.hash2(ex) + h = hash2(ex) @test h == ex.hash2[] ex2 = setmetadata(ex, Int, 3) @test ex2.hash2[] != h @@ -164,41 +157,27 @@ Base.isequal(a::MySymbolic, b::MySymbolic) = isequal(a.sym, b.sym) @syms x::Real xx = setmetadata(x, Int, 3) @test isequal(x, xx) - @test !isequal_with_metadata(x, xx) + @test !isequal2(x, xx) myx = MySymbolic(x) myxx = MySymbolic(xx) @test isequal(myx, myxx) - @test !isequal_with_metadata(myx, myxx) + @test !isequal2(myx, myxx) ex = 2x exx = 2xx myex = MySymbolic(ex) myexx = MySymbolic(exx) @test isequal(ex, exx) - @test !isequal_with_metadata(ex, exx) + @test !isequal2(ex, exx) @test isequal(myex, myexx) - @test !isequal_with_metadata(myex, myexx) + @test !isequal2(myex, myexx) t = Term{Real}(max, Any[x, myex]) tt = Term{Real}(max, Any[xx, myexx]) @test isequal(t, tt) - @test !isequal_with_metadata(t, tt) + @test !isequal2(t, tt) myt = MySymbolic(t) mytt = MySymbolic(tt) @test isequal(myt, mytt) - @test !isequal_with_metadata(myt, mytt) -end - -@testset "`isequal_with_metadata` ensures numbers have the same type" begin - @syms x - tmp1 = x ^ 3.0 - tmp2 = x ^ 3 - @test !SymbolicUtils.isequal_with_metadata(tmp1, tmp2) - @test arguments(tmp1)[2] isa Float64 - @test arguments(tmp2)[2] isa Int - tmp1 = 2tmp1 - tmp2 = 2tmp2 - @test !SymbolicUtils.isequal_with_metadata(tmp1, tmp2) - @test arguments(arguments(tmp1)[2])[2] isa Float64 - @test arguments(arguments(tmp2)[2])[2] isa Int + @test !isequal2(myt, mytt) end diff --git a/test/inspect_output/ex-md.txt b/test/inspect_output/ex-md.txt index e61a807d2..52bb27ded 100644 --- a/test/inspect_output/ex-md.txt +++ b/test/inspect_output/ex-md.txt @@ -1,20 +1,20 @@ - 1 DIV - 2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2)) - 3 │ ├─ POW - 4 │ │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2)) + 1 Div + 2 ├─ AddOrMul(variant = "MUL", scalar = 1, powers = (1 + 2x + 3y => 2, z => 1)) + 3 │ ├─ Pow + 4 │ │ ├─ AddOrMul(variant = "ADD", scalar = 1, coeffs = (x => 2, y => 3)) 5 │ │ │ ├─ 1 - 6 │ │ │ ├─ MUL(scalar = 2, powers = (x => 1,)) + 6 │ │ │ ├─ AddOrMul(variant = "MUL", scalar = 2, powers = (x => 1,)) 7 │ │ │ │ ├─ 2 - 8 │ │ │ │ └─ SYM(x) - 9 │ │ │ └─ MUL(scalar = 3, powers = (y => 1,)) + 8 │ │ │ │ └─ Sym(x) + 9 │ │ │ └─ AddOrMul(variant = "MUL", scalar = 3, powers = (y => 1,)) 10 │ │ │ ├─ 3 -11 │ │ │ └─ SYM(y) metadata=(Integer => 42,) +11 │ │ │ └─ Sym(y) metadata=(Integer => 42,) 12 │ │ └─ 2 -13 │ └─ SYM(z) -14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2)) -15 ├─ MUL(scalar = 2, powers = (x => 1,)) +13 │ └─ Sym(z) +14 └─ AddOrMul(variant = "ADD", scalar = 0, coeffs = (x => 2, z => 1)) +15 ├─ AddOrMul(variant = "MUL", scalar = 2, powers = (x => 1,)) 16 │ ├─ 2 -17 │ └─ SYM(x) -18 └─ SYM(z) +17 │ └─ Sym(x) +18 └─ Sym(z) Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number \ No newline at end of file diff --git a/test/inspect_output/ex-nohint.txt b/test/inspect_output/ex-nohint.txt index d1a782902..87515e6a8 100644 --- a/test/inspect_output/ex-nohint.txt +++ b/test/inspect_output/ex-nohint.txt @@ -1,18 +1,18 @@ - 1 DIV - 2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2)) - 3 │ ├─ POW - 4 │ │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2)) + 1 Div + 2 ├─ AddOrMul(variant = "MUL", scalar = 1, powers = (1 + 2x + 3y => 2, z => 1)) + 3 │ ├─ Pow + 4 │ │ ├─ AddOrMul(variant = "ADD", scalar = 1, coeffs = (x => 2, y => 3)) 5 │ │ │ ├─ 1 - 6 │ │ │ ├─ MUL(scalar = 2, powers = (x => 1,)) + 6 │ │ │ ├─ AddOrMul(variant = "MUL", scalar = 2, powers = (x => 1,)) 7 │ │ │ │ ├─ 2 - 8 │ │ │ │ └─ SYM(x) - 9 │ │ │ └─ MUL(scalar = 3, powers = (y => 1,)) + 8 │ │ │ │ └─ Sym(x) + 9 │ │ │ └─ AddOrMul(variant = "MUL", scalar = 3, powers = (y => 1,)) 10 │ │ │ ├─ 3 -11 │ │ │ └─ SYM(y) +11 │ │ │ └─ Sym(y) 12 │ │ └─ 2 -13 │ └─ SYM(z) -14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2)) -15 ├─ MUL(scalar = 2, powers = (x => 1,)) +13 │ └─ Sym(z) +14 └─ AddOrMul(variant = "ADD", scalar = 0, coeffs = (x => 2, z => 1)) +15 ├─ AddOrMul(variant = "MUL", scalar = 2, powers = (x => 1,)) 16 │ ├─ 2 -17 │ └─ SYM(x) -18 └─ SYM(z) \ No newline at end of file +17 │ └─ Sym(x) +18 └─ Sym(z) \ No newline at end of file diff --git a/test/inspect_output/ex.txt b/test/inspect_output/ex.txt index 2762a84e8..e55fd00a7 100644 --- a/test/inspect_output/ex.txt +++ b/test/inspect_output/ex.txt @@ -1,20 +1,20 @@ - 1 DIV - 2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2)) - 3 │ ├─ POW - 4 │ │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2)) + 1 Div + 2 ├─ AddOrMul(variant = "MUL", scalar = 1, powers = (1 + 2x + 3y => 2, z => 1)) + 3 │ ├─ Pow + 4 │ │ ├─ AddOrMul(variant = "ADD", scalar = 1, coeffs = (x => 2, y => 3)) 5 │ │ │ ├─ 1 - 6 │ │ │ ├─ MUL(scalar = 2, powers = (x => 1,)) + 6 │ │ │ ├─ AddOrMul(variant = "MUL", scalar = 2, powers = (x => 1,)) 7 │ │ │ │ ├─ 2 - 8 │ │ │ │ └─ SYM(x) - 9 │ │ │ └─ MUL(scalar = 3, powers = (y => 1,)) + 8 │ │ │ │ └─ Sym(x) + 9 │ │ │ └─ AddOrMul(variant = "MUL", scalar = 3, powers = (y => 1,)) 10 │ │ │ ├─ 3 -11 │ │ │ └─ SYM(y) +11 │ │ │ └─ Sym(y) 12 │ │ └─ 2 -13 │ └─ SYM(z) -14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2)) -15 ├─ MUL(scalar = 2, powers = (x => 1,)) +13 │ └─ Sym(z) +14 └─ AddOrMul(variant = "ADD", scalar = 0, coeffs = (x => 2, z => 1)) +15 ├─ AddOrMul(variant = "MUL", scalar = 2, powers = (x => 1,)) 16 │ ├─ 2 -17 │ └─ SYM(x) -18 └─ SYM(z) +17 │ └─ Sym(x) +18 └─ Sym(z) Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number \ No newline at end of file diff --git a/test/inspect_output/sub10.txt b/test/inspect_output/sub10.txt index f651d4312..6e167ca1a 100644 --- a/test/inspect_output/sub10.txt +++ b/test/inspect_output/sub10.txt @@ -1,5 +1,5 @@ -1 MUL(scalar = 3, powers = (y => 1,)) +1 AddOrMul(variant = "MUL", scalar = 3, powers = (y => 1,)) 2 ├─ 3 -3 └─ SYM(y) +3 └─ Sym(y) Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number \ No newline at end of file diff --git a/test/inspect_output/sub14.txt b/test/inspect_output/sub14.txt index f49f4fceb..12eaa25db 100644 --- a/test/inspect_output/sub14.txt +++ b/test/inspect_output/sub14.txt @@ -1,7 +1,7 @@ -1 ADD(scalar = 0, coeffs = (z => 1, x => 2)) -2 ├─ MUL(scalar = 2, powers = (x => 1,)) +1 AddOrMul(variant = "ADD", scalar = 0, coeffs = (x => 2, z => 1)) +2 ├─ AddOrMul(variant = "MUL", scalar = 2, powers = (x => 1,)) 3 │ ├─ 2 -4 │ └─ SYM(x) -5 └─ SYM(z) +4 │ └─ Sym(x) +5 └─ Sym(z) Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number \ No newline at end of file diff --git a/test/order.jl b/test/order.jl index c6e78f2cb..c7c3a134c 100644 --- a/test/order.jl +++ b/test/order.jl @@ -59,9 +59,9 @@ end @test istotal(ρ(), -1z()) - @syms a(t) b(t) t - @test a(t) <ₑ b(t) - @test !(b(t) <ₑ a(t)) + @syms b(t) a(t) t + @test b(t) <ₑ a(t) + @test !(a(t) <ₑ b(t)) @syms y() x() @test x() <ₑ y() diff --git a/test/rewrite.jl b/test/rewrite.jl index b8996f79e..8860cb223 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -1,5 +1,5 @@ using SymbolicUtils - +using Test include("utils.jl") @syms a b c @@ -38,9 +38,9 @@ end @test @rule((~x)^(~x) => ~x)(b^a) === nothing @test @rule((~x)^(~x) => ~x)(a+a) === nothing @eqtest @rule((~x)^(~x) => ~x)(sin(a)^sin(a)) == sin(a) - @eqtest @rule((~x*~y + ~x*~z) => ~x * (~y+~z))(a*b + a*c) == a*(b+c) + @eqtest @rule((~x*~y + ~z*~x) => ~x * (~y+~z))(a*b + a*c) == a*(b+c) - @eqtest @rule(+(~~x) => ~~x)(a + b) == [a,b] + @test issetequal(@rule(+(~~x) => ~~x)(a + b), [a,b]) @eqtest @rule(+(~~x) => ~~x)(term(+, a, b, c)) == [a,b,c] @eqtest @rule(+(~~x,~y, ~~x) => (~~x, ~y))(term(+,9,8,9,type=Any)) == ([9,],8) @eqtest @rule(+(~~x,~y, ~~x) => (~~x, ~y, ~~x))(term(+,9,8,9,9,8,type=Any)) == ([9,8], 9, [9,8]) @@ -49,15 +49,16 @@ end @testset "Slot matcher with default value" begin r_sum = @rule (~x + ~!y)^2 => ~y - @test r_sum((a + b)^2) === b + @test r_sum((a + b)^2) in Set([a, b]) @test r_sum(b^2) === 0 r_mult = @rule ~x * ~!y => ~y - @test r_mult(a * b) === b + @test r_mult(a * b) in Set([a, b]) @test r_mult(a) === 1 r_mult2 = @rule (~x * ~!y + ~z) => ~y - @test r_mult2(c + a*b) === b + # can match either `a` or `b` or coefficient of `c` + @test r_mult2(c + a*b) in Set([1, a, b]) @test r_mult2(c + b) === 1 # here the "normal part" in the defslot_term_matcher is not a symbol but a tree @@ -67,7 +68,7 @@ end @test r_mult3(c+2) === 1 r_pow = @rule (~x)^(~!m) => ~m - @test r_pow(a^(b+1)) === b+1 + @test isequal(r_pow(a^(b+1)), b+1) @test r_pow(a) === 1 @test r_pow(a+1) === 1 @@ -77,8 +78,8 @@ end @test r_pow2(a+b) === 1 r_mix = @rule (~x + (~y)*(~!c))^(~!m) => ~m + ~c - @test r_mix((a + b*c)^2) === 2 + c - @test r_mix((a + b*c)) === 1 + c + @test r_mix((a + b*c)^2) in Set([2 + b, 3, 2 + c]) + @test r_mix((a + b*c)) in Set([1 + b, 1 + c]) @test r_mix((a + b)) === 2 #1+1 end @@ -129,7 +130,7 @@ end ex = setmetadata(ex, MetaData, :metadata) ex1 = ex + b - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata ex = a * b ex = setmetadata(ex, MetaData, :metadata) @@ -142,5 +143,5 @@ end ex = setmetadata(ex, MetaData, :metadata) ex1 = ex * b - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata -end \ No newline at end of file + @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata +end diff --git a/test/rulesets.jl b/test/rulesets.jl index cfbc0143e..03de4067b 100644 --- a/test/rulesets.jl +++ b/test/rulesets.jl @@ -163,22 +163,15 @@ pred(x) = error("Fail") @test sprint(io -> Base.showerror(io, err)) == "Failed to apply rule ~x + ~(y::pred) => ~x on expression a + b" end -@testset "Threading" begin - @syms a b c d - ex = (((0.6666666666666666 / (c / 1)) + ((1 * a) / (c / 1))) + - (1.0 / (((1 * d) / (1 + b)) * (1 / b)))) + - ((((1 * a) + (1 * a)) / ((2.0 * (d + 1)) / 1.0)) + - ((((d * 1) / (1 + c)) * 2.0) / ((1 / d) + (1 / c)))) - @eqtest simplify(ex) == simplify(ex, threaded=true, thread_subtree_cutoff=3) - @test SymbolicUtils.node_count(a + b * c / d) == 7 -end - -@testset "timerwrite" begin - @syms a b c d - expr1 = foldr((x, y) -> rand([*, /])(x, y), rand([a, b, c, d], 100)) - SymbolicUtils.@timerewrite simplify(expr1) -end - +# @testset "Threading" begin +# @syms a b c d +# ex = (((0.6666666666666666 / (c / 1)) + ((1 * a) / (c / 1))) + +# (1.0 / (((1 * d) / (1 + b)) * (1 / b)))) + +# ((((1 * a) + (1 * a)) / ((2.0 * (d + 1)) / 1.0)) + +# ((((d * 1) / (1 + c)) * 2.0) / ((1 / d) + (1 / c)))) +# @eqtest simplify(ex) == simplify(ex, threaded=true, thread_subtree_cutoff=3) +# @test SymbolicUtils.node_count(a + b * c / d) == 7 +# end _g(y) = sin @testset "interpolation" begin