diff --git a/Project.toml b/Project.toml index 8f971f133..a9d033ee4 100644 --- a/Project.toml +++ b/Project.toml @@ -13,9 +13,11 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" +Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" @@ -25,7 +27,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415" [weakdeps] LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" @@ -45,8 +46,10 @@ ConstructionBase = "1.5.7" DataStructures = "0.18" DocStringExtensions = "0.8, 0.9" DynamicPolynomials = "0.5, 0.6" +Expronicon = "~0.8" IfElse = "0.1" LabelledArrays = "1.5" +MLStyle = "0.4" MultivariatePolynomials = "0.5" NaNMath = "0.3, 1" ReverseDiff = "1" @@ -56,7 +59,6 @@ StaticArrays = "0.12, 1.0" SymbolicIndexingInterface = "0.3" TermInterface = "2.0" TimerOutputs = "0.5" -Unityper = "0.1.2" julia = "1.3" [extras] diff --git a/docs/src/manual/rewrite.md b/docs/src/manual/rewrite.md index 047bef71a..feac062ec 100644 --- a/docs/src/manual/rewrite.md +++ b/docs/src/manual/rewrite.md @@ -71,7 +71,7 @@ If you want to match a variable number of subexpressions at once, you will need @rule(+(~~xs) => ~~xs)(x + y + z) # output -3-element view(::Vector{Any}, 1:3) with eltype Any: +3-element view(::Vector{SymbolicUtils.BasicSymbolic}, 1:3) with eltype SymbolicUtils.BasicSymbolic: z y x diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 2bf52507e..49fb84a7d 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -7,7 +7,6 @@ using DocStringExtensions export @syms, term, showraw, hasmetadata, getmetadata, setmetadata -using Unityper using TermInterface using DataStructures using Setfield @@ -23,6 +22,10 @@ import ArrayInterface Base.@deprecate istree iscall export istree, operation, arguments, sorted_arguments, similarterm, iscall + +using Base: RefValue +using Expronicon.ADT: @adt +using MLStyle: @match # Sym, Term, # Add, Mul and Pow include("types.jl") diff --git a/src/code.jl b/src/code.jl index 4128a39fd..b9c8ec9d1 100644 --- a/src/code.jl +++ b/src/code.jl @@ -8,8 +8,8 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, import ..SymbolicUtils import ..SymbolicUtils.Rewriters -import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, - symtype, sorted_arguments, metadata, isterm, term, maketerm +import SymbolicUtils: @matchable, BasicSymbolic, _Sym, Term, iscall, operation, arguments, issym, + isconst, symtype, sorted_arguments, metadata, isterm, term, maketerm, get_val import SymbolicIndexingInterface: symbolic_type, NotSymbolic ##== state management ==## @@ -182,6 +182,8 @@ function toexpr(O, st) if issym(O) O = substitute_name(O, st) return issym(O) ? nameof(O) : toexpr(O, st) + elseif isconst(O) + return toexpr(get_val(O), st) end O = substitute_name(O, st) @@ -681,7 +683,7 @@ end ### Common subexprssion evaluation -@inline newsym(::Type{T}) where T = Sym{T}(gensym("cse")) +@inline newsym(::Type{T}) where T = _Sym(T, gensym("cse")) function _cse!(mem, expr) iscall(expr) || return expr @@ -745,7 +747,7 @@ function cse_block!(assignments, counter, names, name, state, x) if haskey(names, x) return names[x] else - sym = Sym{symtype(x)}(Symbol(name, counter[])) + sym = _Sym(symtype(x), Symbol(name, counter[])) names[x] = sym push!(assignments, sym ← x) counter[] += 1 diff --git a/src/inspect.jl b/src/inspect.jl index ab3951725..715082496 100644 --- a/src/inspect.jl +++ b/src/inspect.jl @@ -6,22 +6,23 @@ function AbstractTrees.nodevalue(x::Symbolic) end function AbstractTrees.nodevalue(x::BasicSymbolic) - str = if !iscall(x) + str = if issym(x) string(exprtype(x), "(", x, ")") + elseif isconst(x) + string(get_val(x)) elseif isadd(x) - string(exprtype(x), - (scalar=x.coeff, coeffs=Tuple(k=>v for (k,v) in x.dict))) + string(exprtype(x), + (scalar = get_coeff(x), coeffs = Tuple(k => v for (k, v) in get_dict(x)))) elseif ismul(x) string(exprtype(x), - (scalar=x.coeff, powers=Tuple(k=>v for (k,v) in x.dict))) + (scalar = get_coeff(x), powers = Tuple(k => v for (k, v) in get_dict(x)))) elseif isdiv(x) || ispow(x) string(exprtype(x)) else - string(exprtype(x),"{", operation(x), "}") + string(exprtype(x), "{", operation(x), "}") end - if inspect_metadata[] && !isnothing(metadata(x)) - str *= string(" metadata=", Tuple(k=>v for (k, v) in metadata(x))) + str *= string(" metadata=", Tuple(k => v for (k, v) in metadata(x))) end Text(str) end diff --git a/src/matchers.jl b/src/matchers.jl index 7f4dea537..03d7b33f3 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -6,9 +6,23 @@ # 3. Callback: takes arguments Dictionary × Number of elements matched # function matcher(val::Any) - iscall(val) && return term_matcher(val) + if isconst(val) + slot = get_val(val) + return matcher(slot) + elseif iscall(val) + return term_matcher(val) + end function literal_matcher(next, data, bindings) - islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing + if islist(data) + cd = car(data) + if isconst(cd) + cd = get_val(cd) + end + if isequal(cd, val) + return next(bindings, 1) + end + end + nothing end end diff --git a/src/methods.jl b/src/methods.jl index 2baef6424..f6c90ae52 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -179,16 +179,22 @@ for (f, Domain) in [(==) => Number, (!=) => Number, xor => Bool] @eval begin promote_symtype(::$(typeof(f)), ::Type{<:$Domain}, ::Type{<:$Domain}) = Bool - (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::$Domain) = term($f, a, b, type=Bool) - (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::Symbolic{<:$Domain}) = term($f, a, b, type=Bool) - (::$(typeof(f)))(a::$Domain, b::Symbolic{<:$Domain}) = term($f, a, b, type=Bool) + (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::$Domain) = term($f, a, b; type = Bool) + (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::Symbolic{<:$Domain}) = term($f, a, b; type = Bool) + (::$(typeof(f)))(a::$Domain, b::Symbolic{<:$Domain}) = term($f, a, b; type = Bool) end end for f in [!, ~] @eval begin promote_symtype(::$(typeof(f)), ::Type{<:Bool}) = Bool - (::$(typeof(f)))(s::Symbolic{Bool}) = Term{Bool}(!, [s]) + function (::$(typeof(f)))(s::Symbolic{Bool}) + if isconst(s) + s = get_val(s) + return !s + end + _Term(Bool, !, [s]) + end end end @@ -196,7 +202,7 @@ end # An ifelse node, ifelse is a built-in unfortunately # So this uses IfElse.jl's ifelse that we imported function ifelse(_if::Symbolic{Bool}, _then, _else) - Term{Union{symtype(_then), symtype(_else)}}(ifelse, Any[_if, _then, _else]) + _Term(Union{symtype(_then), symtype(_else)}, ifelse, Any[_if, _then, _else]) end promote_symtype(::typeof(ifelse), _, ::Type{T}, ::Type{S}) where {T,S} = Union{T, S} diff --git a/src/ordering.jl b/src/ordering.jl index 332f11cf8..623b0cecd 100644 --- a/src/ordering.jl +++ b/src/ordering.jl @@ -27,7 +27,7 @@ function get_degrees(expr) elseif iscall(expr) op = operation(expr) args = sorted_arguments(expr) - if op == (^) && args[2] isa Number + if op == (^) && (args[2] isa Number || (isconst(args[2]) && get_val(args[2]) isa Number)) return map(get_degrees(args[1])) do (base, pow) (base => pow * args[2]) end @@ -79,12 +79,23 @@ function <ₑ(a::Tuple, b::Tuple) end function <ₑ(a::BasicSymbolic, b::BasicSymbolic) + aisconst = isconst(a) + if aisconst + a = get_val(a) + end + bisconst = isconst(b) + if bisconst + b = get_val(b) + end + if aisconst || bisconst + return a <ₑ b + end da, db = get_degrees(a), get_degrees(b) fw = monomial_lt(da, db) bw = monomial_lt(db, da) if fw === bw && !isequal(a, b) if _arglen(a) == _arglen(b) - return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,) + return (operation(a), arguments(a)...) <ₑ (operation(b), arguments(b)...) else return _arglen(a) < _arglen(b) end diff --git a/src/polyform.jl b/src/polyform.jl index 7d6bc906e..b18d6bd47 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -95,6 +95,9 @@ end _isone(p::PolyForm) = isone(p.p) function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) + if isconst(x) + x = get_val(x) + end if x isa Number return x elseif iscall(x) @@ -129,7 +132,7 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) name = Symbol(string(op), "_", hash(y)) @label lookup - sym = Sym{symtype(x)}(name) + sym = _Sym(symtype(x), name) if haskey(sym2term, sym) if isequal(sym2term[sym][1], x) return local_polyize(sym) @@ -262,10 +265,10 @@ end function polyform_factors(d, pvar2sym, sym2term) make(xs) = map(xs) do x - if ispow(x) && x.exp isa Integer && x.exp > 0 + if ispow(x) && get_exp(x) isa Integer && get_exp(x) > 0 # here we do want to recurse one level, that's why it's wrong to just # use Fs = Union{typeof(+), typeof(*)} here. - Pow(PolyForm(x.base, pvar2sym, sym2term), x.exp) + _Pow(PolyForm(get_base(x), pvar2sym, sym2term), get_exp(x)) else PolyForm(x, pvar2sym, sym2term) end @@ -277,13 +280,13 @@ end _mul(xs...) = all(isempty, xs) ? 1 : *(Iterators.flatten(xs)...) function simplify_div(d) - d.simplified && return d + d.impl.simplified[] && return d ns, ds = polyform_factors(d, get_pvar2sym(), get_sym2term()) ns, ds = rm_gcds(ns, ds) if all(_isone, ds) return isempty(ns) ? 1 : simplify_fractions(_mul(ns)) else - Div(simplify_fractions(_mul(ns)), simplify_fractions(_mul(ds))) + _Div(simplify_fractions(_mul(ns)), simplify_fractions(_mul(ds))) end end @@ -293,11 +296,11 @@ end #add_divs(x, y) = x + y function add_divs(x, y) if isdiv(x) && isdiv(y) - return (x.num * y.den + y.num * x.den) / (x.den * y.den) + return (get_num(x) * get_den(y) + get_num(y) * get_den(x)) / (get_den(x) * get_den(y)) elseif isdiv(x) - return (x.num + y * x.den) / x.den + return (get_num(x) + y * get_den(x)) / get_den(x) elseif isdiv(y) - return (x * y.den + y.num) / y.den + return (x * get_den(y) + get_num(y)) / get_den(y) else x + y end @@ -381,7 +384,7 @@ function fraction_isone(x) end function needs_div_rules(x) - (isdiv(x) && !(x.num isa Number) && !(x.den isa Number)) || + (isdiv(x) && !(get_num(x) isa Number) && !(get_den(x) isa Number)) || (iscall(x) && operation(x) === (+) && count(has_div, arguments(x)) > 1) || (iscall(x) && any(needs_div_rules, arguments(x))) end @@ -413,13 +416,13 @@ But it will simplify `(x - 5)^2*(x - 3) / (x - 5)` to `(x - 5)*(x - 3)`. Has optimized processes for `Mul` and `Pow` terms. """ function quick_cancel(d) - if ispow(d) && isdiv(d.base) - return quick_cancel((d.base.num^d.exp) / (d.base.den^d.exp)) + if ispow(d) && isdiv(get_base(d)) + return quick_cancel((get_num(get_base(d))^get_exp(d)) / (get_den(get_base(d))^get_exp(d))) elseif ismul(d) && any(isdiv, arguments(d)) return prod(arguments(d)) elseif isdiv(d) - num, den = quick_cancel(d.num, d.den) - return Div(num, den) + num, den = quick_cancel(get_num(d), get_den(d)) + return _Div(num, den) else return d end @@ -449,20 +452,29 @@ end # ispow(x) case function quick_pow(x, y) - x.exp isa Number || return (x, y) - isequal(x.base, y) && x.exp >= 1 ? (Pow{symtype(x)}(x.base, x.exp - 1),1) : (x, y) + ximpl = x.impl + if !isa(ximpl.exp, Number) + x, y + elseif isequal(ximpl.base, y) && ximpl.exp >= 1 + _Pow(symtype(x), ximpl.base, ximpl.exp - 1), 1 + else + x, y + end end # Double Pow case function quick_powpow(x, y) - if isequal(x.base, y.base) - !(x.exp isa Number && y.exp isa Number) && return (x, y) - if x.exp > y.exp - return Pow{symtype(x)}(x.base, x.exp-y.exp), 1 - elseif x.exp == y.exp + ximpl = x.impl + yimpl = y.impl + if isequal(ximpl.base, yimpl.base) + if !(ximpl.exp isa Number && yimpl.exp isa Number) + return x, y + elseif ximpl.exp > yimpl.exp + return _Pow(symtype(x), ximpl.base, ximpl.exp - yimpl.exp), 1 + elseif ximpl.exp == yimpl.exp return 1, 1 else # x.exp < y.exp - return 1, Pow{symtype(y)}(y.base, y.exp-x.exp) + return 1, _Pow(symtype(y), yimpl.base, yimpl.exp - ximpl.exp) end end return x, y @@ -470,8 +482,10 @@ end # ismul(x) function quick_mul(x, y) - if haskey(x.dict, y) && x.dict[y] >= 1 - d = copy(x.dict) + ximpl = x.impl + xdict = ximpl.dict + if haskey(xdict, y) && xdict[y] >= 1 + d = copy(xdict) if d[y] > 1 d[y] -= 1 elseif d[y] == 1 @@ -479,8 +493,7 @@ function quick_mul(x, y) else error("Can't reach") end - - return Mul(symtype(x), x.coeff, d), 1 + return _Mul(symtype(x), ximpl.coeff, d), 1 else return x, y end @@ -488,20 +501,20 @@ end # mul, pow case function quick_mulpow(x, y) - y.exp isa Number || return (x, y) - if haskey(x.dict, y.base) - d = copy(x.dict) - if x.dict[y.base] > y.exp - d[y.base] -= y.exp + get_exp(y) isa Number || return (x, y) + if haskey(get_dict(x), get_base(y)) + d = copy(get_dict(x)) + if get_dict(x)[get_base(y)] > get_exp(y) + d[get_base(y)] -= get_exp(y) den = 1 - elseif x.dict[y.base] == y.exp - delete!(d, y.base) + elseif get_dict(x)[get_base(y)] == get_exp(y) + delete!(d, get_base(y)) den = 1 else - den = Pow{symtype(y)}(y.base, y.exp-d[y.base]) - delete!(d, y.base) + den = _Pow(symtype(y), get_base(y), get_exp(y)-d[get_base(y)]) + delete!(d, get_base(y)) end - return Mul(symtype(x), x.coeff, d), den + return _Mul(symtype(x), get_coeff(x), d), den else return x, y end @@ -509,8 +522,10 @@ end # Double mul case function quick_mulmul(x, y) - num_dict, den_dict = _merge_div(x.dict, y.dict) - Mul(symtype(x), x.coeff, num_dict), Mul(symtype(y), y.coeff, den_dict) + ximpl = x.impl + yimpl = y.impl + num_dict, den_dict = _merge_div(ximpl.dict, yimpl.dict) + _Mul(symtype(x), ximpl.coeff, num_dict), _Mul(symtype(y), yimpl.coeff, den_dict) end function _merge_div(ndict, ddict) diff --git a/src/rule.jl b/src/rule.jl index 5de0aa79c..7a6353a6b 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -67,10 +67,10 @@ function makepattern(expr, keys) makeslot(expr.args[2], keys) end else - :(term($(map(x->makepattern(x, keys), expr.args)...); type=Any)) + :(term($(map(x -> makepattern(x, keys), expr.args)...); type = Any)) end elseif expr.head === :ref - :(term(getindex, $(map(x->makepattern(x, keys), expr.args)...); type=Any)) + :(term(getindex, $(map(x -> makepattern(x, keys), expr.args)...); type = Any)) elseif expr.head === :$ return esc(expr.args[1]) else @@ -404,7 +404,7 @@ function (acr::ACRule)(term) itr = acr.sets(eachindex(args), acr.arity) for inds in itr - result = r(Term{T}(f, @views args[inds])) + result = r(_Term(T, f, @views args[inds])) if result !== nothing # Assumption: inds are unique length(args) == length(inds) && return result diff --git a/src/substitute.jl b/src/substitute.jl index 828f88b14..3df62cf37 100644 --- a/src/substitute.jl +++ b/src/substitute.jl @@ -22,6 +22,9 @@ function substitute(expr, dict; fold=true) canfold = !(op isa Symbolic) args = map(arguments(expr)) do x x′ = substitute(x, dict; fold=fold) + if isconst(x) + x′ = get_val(x′) + end canfold = canfold && !(x′ isa Symbolic) x′ end @@ -54,10 +57,13 @@ function _occursin(needle, haystack) if iscall(haystack) args = arguments(haystack) for arg in args + if isconst(arg) + arg = get_val(arg) + end if needle isa Integer || needle isa AbstractFloat isequal(needle, arg) && return true else - occursin(needle, arg) && return true + occursin(needle, arg) && return true end end end diff --git a/src/types.jl b/src/types.jl index 683f58d44..779a259a4 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,85 +1,105 @@ -#------------------- -#-------------------- -#### Symbolic -#-------------------- abstract type Symbolic{T} end -### -### Uni-type design -### - -@enum ExprType::UInt8 SYM TERM ADD MUL POW DIV +@enum ExprType::UInt8 SYM TERM ADD MUL POW DIV CONST -const Metadata = Union{Nothing,Base.ImmutableDict{DataType,Any}} +const Metadata = Union{Nothing, Base.ImmutableDict{DataType, Any}} const NO_METADATA = nothing +const EMPTY_HASH = UInt(0) -sdict(kv...) = Dict{Any, Any}(kv...) - -using Base: RefValue -const EMPTY_ARGS = [] -const EMPTY_HASH = RefValue(UInt(0)) -const NOT_SORTED = RefValue(false) -const EMPTY_DICT = sdict() -const EMPTY_DICT_T = typeof(EMPTY_DICT) - -@compactify show_methods=false begin - @abstract struct BasicSymbolic{T} <: Symbolic{T} - metadata::Metadata = NO_METADATA +@adt BasicSymbolicImpl begin + struct Sym + name::Symbol end - struct Sym{T} <: BasicSymbolic{T} - name::Symbol = :OOF + struct Term + f::Any + arguments::Vector{Symbolic} end - struct Term{T} <: BasicSymbolic{T} - f::Any = identity # base/num if Pow; issorted if Add/Dict - arguments::Vector{Any} = EMPTY_ARGS - hash::RefValue{UInt} = EMPTY_HASH + struct Add + coeff::Any + dict::Dict{BasicSymbolic, Any} + arguments::Vector{BasicSymbolic} = BasicSymbolic[] + issorted::RefValue{Bool} = Ref(false) end - struct Mul{T} <: BasicSymbolic{T} - coeff::Any = 0 # exp/den if Pow - dict::EMPTY_DICT_T = EMPTY_DICT - hash::RefValue{UInt} = EMPTY_HASH - arguments::Vector{Any} = EMPTY_ARGS - issorted::RefValue{Bool} = NOT_SORTED + struct Mul + coeff::Any + dict::Dict{BasicSymbolic, Any} + arguments::Vector{BasicSymbolic} = BasicSymbolic[] + issorted::RefValue{Bool} = Ref(false) end - struct Add{T} <: BasicSymbolic{T} - coeff::Any = 0 # exp/den if Pow - dict::EMPTY_DICT_T = EMPTY_DICT - hash::RefValue{UInt} = EMPTY_HASH - arguments::Vector{Any} = EMPTY_ARGS - issorted::RefValue{Bool} = NOT_SORTED + struct Div + num::Any + den::Any + simplified::RefValue{Bool} = Ref(false) + arguments::Vector{Any} = [num, den] end - struct Div{T} <: BasicSymbolic{T} - num::Any = 1 - den::Any = 1 - simplified::Bool = false - arguments::Vector{Any} = EMPTY_ARGS + struct Pow + base::Any + exp::Any + arguments::Vector{Any} = [base, exp] end - struct Pow{T} <: BasicSymbolic{T} - base::Any = 1 - exp::Any = 1 - arguments::Vector{Any} = EMPTY_ARGS + struct Const + val::Any end end +@kwdef struct BasicSymbolic{T} <: Symbolic{T} + impl::BasicSymbolicImpl + metadata::Metadata = NO_METADATA + hash::RefValue{UInt} = Ref(EMPTY_HASH) +end + function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) ScalarSymbolic() end function exprtype(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => TERM - Add => ADD - Mul => MUL - Div => DIV - Pow => POW - Sym => SYM - _ => error_on_type() + @match x.impl begin + Sym(_...) => SYM + Term(_...) => TERM + Add(_...) => ADD + Mul(_...) => MUL + Div(_...) => DIV + Pow(_...) => POW + Const(_...) => CONST end end +function get_name(x::BasicSymbolic) + x.impl.name +end + +function get_coeff(x::BasicSymbolic) + x.impl.coeff +end + +function get_dict(x::BasicSymbolic) + x.impl.dict +end + +function get_num(x::BasicSymbolic) + x.impl.num +end + +function get_den(x::BasicSymbolic) + x.impl.den +end + +function get_base(x::BasicSymbolic) + x.impl.base +end + +function get_exp(x::BasicSymbolic) + x.impl.exp +end + +function get_val(x::BasicSymbolic) + x.impl.val +end + # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") +@noinline error_const() = error("Const doesn't have a operation or arguments!") @noinline error_property(E, s) = error("$E doesn't have field $s") # We can think about bits later @@ -89,10 +109,22 @@ const SIMPLIFIED = 0x01 << 0 #@inline is_of_type(x::BasicSymbolic, type::UInt8) = (x.bitflags & type) != 0x00 #@inline issimplified(x::BasicSymbolic) = is_of_type(x, SIMPLIFIED) -function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple)::BasicSymbolic{T} where T - nt = getproperties(obj) - nt_new = merge(nt, patch) - Unityper.rt_constructor(obj){T}(;nt_new...) +function ConstructionBase.setproperties( + obj::BasicSymbolic{T}, patch::NamedTuple)::BasicSymbolic{T} where {T} + nt1 = getproperties(obj) + nt2 = getproperties(obj.impl) + nt1 = merge(nt1, patch) + nt2 = merge(nt2, patch) + metadata = nt1.metadata + @match obj.impl begin + Sym(_...) => _Sym(T, nt2.name; metadata) + Term(_...) => _Term(T, nt2.f, nt2.arguments; metadata) + Add(_...) => _Add(T, nt2.coeff, nt2.dict; metadata) + Mul(_...) => _Mul(T, nt2.coeff, nt2.dict; metadata) + Div(_...) => _Div(T, nt2.num, nt2.den; metadata) + Pow(_...) => _Pow(T, nt2.base, nt2.exp; metadata) + Const(_...) => _Const(nt2.val; metadata) + end end ### @@ -114,14 +146,14 @@ symtype(x) = typeof(x) # We're returning a function pointer @inline function operation(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => x.f - Add => (+) - Mul => (*) - Div => (/) - Pow => (^) - Sym => error_sym() - _ => error_on_type() + @match x.impl begin + Term(_...) => x.impl.f + Add(_...) => (+) + Mul(_...) => (*) + Div(_...) => (/) + Pow(_...) => (^) + Sym(_...) => error_sym() + Const(_...) => error_const() end end @@ -129,22 +161,23 @@ end function TermInterface.sorted_arguments(x::BasicSymbolic) args = arguments(x) - @compactified x::BasicSymbolic begin - Add => @goto ADD - Mul => @goto MUL - _ => return args + impl = x.impl + @match impl begin + Add(_...) => @goto ADD + Mul(_...) => @goto MUL + _ => return args end @label MUL - if !x.issorted[] - sort!(args, by=get_degrees) - x.issorted[] = true + if !impl.issorted[] + sort!(args, by = get_degrees) + impl.issorted[] = true end return args @label ADD - if !x.issorted[] - sort!(args, lt = monomial_lt, by=get_degrees) - x.issorted[] = true + if !impl.issorted[] + sort!(args, lt = monomial_lt, by = get_degrees) + impl.issorted[] = true end return args end @@ -154,54 +187,49 @@ end TermInterface.children(x::BasicSymbolic) = arguments(x) TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x) function TermInterface.arguments(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => return x.arguments - Add => @goto ADDMUL - Mul => @goto ADDMUL - Div => @goto DIV - Pow => @goto POW - Sym => error_sym() - _ => error_on_type() + impl = x.impl + @match impl begin + Term(_...) => return impl.arguments + Add(_...) => @goto ADDMUL + Mul(_...) => @goto ADDMUL + Div(_...) => @goto DIVPOW + Pow(_...) => @goto DIVPOW + Sym(_...) => error_sym() + Const(_...) => error_const() end @label ADDMUL E = exprtype(x) - args = x.arguments + args = impl.arguments isempty(args) || return args - siz = length(x.dict) - idcoeff = E === ADD ? iszero(x.coeff) : isone(x.coeff) + siz = length(impl.dict) + idcoeff = E === ADD ? _iszero(impl.coeff) : _isone(impl.coeff) sizehint!(args, idcoeff ? siz : siz + 1) - idcoeff || push!(args, x.coeff) + idcoeff || push!(args, impl.coeff) if isadd(x) - for (k, v) in x.dict - push!(args, applicable(*,k,v) ? k*v : - maketerm(k, *, [k, v], nothing)) + for (k, v) in impl.dict + push!(args, applicable(*, k, v) ? k * v : maketerm(k, *, [k, v], nothing)) end else # MUL - for (k, v) in x.dict + for (k, v) in impl.dict push!(args, unstable_pow(k, v)) end end return args - @label DIV - args = x.arguments - isempty(args) || return args - sizehint!(args, 2) - push!(args, x.num) - push!(args, x.den) + @label DIVPOW + args = impl.arguments return args +end - @label POW - args = x.arguments - isempty(args) || return args - sizehint!(args, 2) - push!(args, x.base) - push!(args, x.exp) - return args +function isexpr(x::BasicSymbolic) + @match x.impl begin + Sym(_...) => false + Const(_...) => false + _ => true + end end -isexpr(s::BasicSymbolic) = !issym(s) iscall(s::BasicSymbolic) = isexpr(s) @inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false @@ -212,12 +240,54 @@ iscall(s::BasicSymbolic) = isexpr(s) Returns `true` if `x` is a `Sym`. If true, `nameof` must be defined on `x` and must return a `Symbol`. """ -issym(x) = isa_SymType(Val(:Sym), x) -isterm(x) = isa_SymType(Val(:Term), x) -ismul(x) = isa_SymType(Val(:Mul), x) -isadd(x) = isa_SymType(Val(:Add), x) -ispow(x) = isa_SymType(Val(:Pow), x) -isdiv(x) = isa_SymType(Val(:Div), x) +function issym(x) + isa(x, BasicSymbolic) && @match x.impl begin + Sym(_...) => true + _ => false + end +end + +function isterm(x) + isa(x, BasicSymbolic) && @match x.impl begin + Term(_...) => true + _ => false + end +end + +function isadd(x) + isa(x, BasicSymbolic) && @match x.impl begin + Add(_...) => true + _ => false + end +end + +function ismul(x) + isa(x, BasicSymbolic) && @match x.impl begin + Mul(_...) => true + _ => false + end +end + +function ispow(x) + isa(x, BasicSymbolic) && @match x.impl begin + Pow(_...) => true + _ => false + end +end + +function isdiv(x) + isa(x, BasicSymbolic) && @match x.impl begin + Div(_...) => true + _ => false + end +end + +function isconst(x) + isa(x, BasicSymbolic) && @match x.impl begin + Const(_...) => true + _ => false + end +end ### ### Base interface @@ -251,15 +321,17 @@ function _isequal(a, b, E) if E === SYM nameof(a) === nameof(b) elseif E === ADD || E === MUL - coeff_isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict) + coeff_isequal(get_coeff(a), get_coeff(b)) && isequal(get_dict(a), get_dict(b)) elseif E === DIV - isequal(a.num, b.num) && isequal(a.den, b.den) + isequal(get_num(a), get_num(b)) && isequal(get_den(a), get_den(b)) elseif E === POW - isequal(a.exp, b.exp) && isequal(a.base, b.base) + isequal(get_exp(a), get_exp(b)) && isequal(get_base(a), get_base(b)) elseif E === TERM a1 = arguments(a) a2 = arguments(b) isequal(operation(a), operation(b)) && _allarequal(a1, a2) + elseif E === CONST + isequal(get_val(a), get_val(b)) else error_on_type() end @@ -268,7 +340,13 @@ end Base.one( s::Symbolic) = one( symtype(s)) Base.zero(s::Symbolic) = zero(symtype(s)) -Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymbolic doesn't have a name") +function Base.nameof(s::BasicSymbolic) + if issym(s) + get_name(s) + else + error("None Sym BasicSymbolic doesn't have a name") + end +end ## This is much faster than hash of an array of Any hashvec(xs, z) = foldr(hash, xs, init=z) @@ -277,6 +355,7 @@ const ADD_SALT = 0xaddaddaddaddadda % UInt const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt const DIV_SALT = 0x334b218e73bbba53 % UInt const POW_SALT = 0x2b55b97a6efb080c % UInt +const COS_SALT = 0xdc3d6b8f18b75e3c % UInt function Base.hash(s::BasicSymbolic, salt::UInt)::UInt E = exprtype(s) if E === SYM @@ -286,13 +365,13 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt h = s.hash[] !iszero(h) && return h hashoffset = isadd(s) ? ADD_SALT : SUB_SALT - h′ = hash(hashoffset, hash(s.coeff, hash(s.dict, salt))) + h′ = hash(hashoffset, hash(get_coeff(s), hash(get_dict(s), salt))) s.hash[] = h′ return h′ elseif E === DIV - return hash(s.num, hash(s.den, salt ⊻ DIV_SALT)) + return hash(get_num(s), hash(get_den(s), salt ⊻ DIV_SALT)) elseif E === POW - hash(s.exp, hash(s.base, salt ⊻ POW_SALT)) + hash(get_exp(s), hash(get_base(s), salt ⊻ POW_SALT)) elseif E === TERM !iszero(salt) && return hash(hash(s, zero(UInt)), salt) h = s.hash[] @@ -302,6 +381,8 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt h′ = hashvec(arguments(s), hash(oph, salt)) s.hash[] = h′ return h′ + elseif E === CONST + return hash(get_val(s), salt ⊻ COS_SALT) else error_on_type() end @@ -311,51 +392,79 @@ end ### Constructors ### -function Sym{T}(name::Symbol; kw...) where T - Sym{T}(; name=name, kw...) +function _Sym(::Type{T}, name::Symbol; kwargs...) where {T} + impl = Sym(name) + BasicSymbolic{T}(; impl, kwargs...) end -function Term{T}(f, args; kw...) where T - if eltype(args) !== Any - args = convert(Vector{Any}, args) +function _Term(::Type{T}, f, args; kwargs...) where {T} + if eltype(args) !== Symbolic + args = convert(Vector{Symbolic}, args) end + impl = Term(f, args) + BasicSymbolic{T}(; impl, kwargs...) +end +function _Term(f, args; kwargs...) + _Term(_promote_symtype(f, args), f, args; kwargs...) +end + +function _Const(val::T; kwargs...) where {T} + impl = Const(val) + BasicSymbolic{T}(; impl, kwargs...) +end - Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), kw...) +function Base.convert(::Type{Symbolic}, x) + _Const(x) end -function Term(f, args; metadata=NO_METADATA) - Term{_promote_symtype(f, args)}(f, args, metadata=metadata) +function Base.convert(::Type{BasicSymbolic}, x) + _Const(x) +end +function Base.convert(::Type{BasicSymbolic}, x::BasicSymbolic) + x end -function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T +function _Add(::Type{T}, coeff, dict; kwargs...) where {T} if isempty(dict) return coeff elseif _iszero(coeff) && length(dict) == 1 - k,v = first(dict) + k, v = first(dict) if _isone(v) return k else coeff, dict = makemul(v, k) - return Mul(T, coeff, dict) + return _Mul(T, coeff, dict) end end - - Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) + impl = Add(; coeff, dict) + BasicSymbolic{T}(; impl, kwargs...) end -function Mul(T, a, b; metadata=NO_METADATA, kw...) - isempty(b) && return a - if _isone(a) && length(b) == 1 - pair = first(b) +function _Mul(::Type{T}, coeff, dict; kwargs...) where {T} + isempty(dict) && return coeff + if _isone(coeff) && length(dict) == 1 + pair = first(dict) if _isone(last(pair)) # first value return first(pair) else return unstable_pow(first(pair), last(pair)) end - else - coeff = a - dict = b - Mul{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) + end + impl = Mul(; coeff, dict) + BasicSymbolic{T}(; impl, kwargs...) +end + +function _iszero(x::BasicSymbolic) + @match x.impl begin + Const(_...) => iszero(get_val(x)) + _ => false + end +end + +function _isone(x::BasicSymbolic) + @match x.impl begin + Const(_...) => isone(get_val(x)) + _ => false end end @@ -363,7 +472,7 @@ const Rat = Union{Rational, Integer} function ratcoeff(x) if ismul(x) - ratcoeff(x.coeff) + ratcoeff(get_coeff(x)) elseif x isa Rat (true, x) else @@ -374,108 +483,115 @@ ratio(x::Integer,y::Integer) = iszero(rem(x,y)) ? div(x,y) : x//y ratio(x::Rat,y::Rat) = x//y function maybe_intcoeff(x) if ismul(x) - if x.coeff isa Rational && isone(x.coeff.den) - Mul{symtype(x)}(; coeff=x.coeff.num, dict=x.dict, x.metadata, arguments=[], issorted=RefValue(false)) + coeff = get_coeff(x) + if coeff isa Rational && isone(denominator(coeff)) + _Mul(symtype(x), coeff.num, get_dict(x); metadata = x.metadata) else x end elseif x isa Rational - isone(x.den) ? x.num : x + isone(denominator(x)) ? numerator(x) : x else x end end -function Div{T}(n, d, simplified=false; metadata=nothing) where {T} - if T<:Number && !(T<:SafeReal) - n, d = quick_cancel(n, d) +function _Div(::Type{T}, num, den; kwargs...) where {T} + if T <: Number && !(T <: SafeReal) + num, den = quick_cancel(num, den) end - _iszero(n) && return zero(typeof(n)) - _isone(d) && return n - - if isdiv(n) && isdiv(d) - return Div{T}(n.num * d.den, n.den * d.num) - elseif isdiv(n) - return Div{T}(n.num, n.den * d) - elseif isdiv(d) - return Div{T}(n * d.den, d.num) + _iszero(num) && return zero(typeof(num)) + _isone(den) && return num + if isdiv(num) && isdiv(den) + return _Div(T, get_num(num) * get_den(den), get_den(num) * get_num(den)) + elseif isdiv(num) + return _Div(T, get_num(num), get_den(num) * den) + elseif isdiv(den) + return _Div(T, num * get_den(den), get_num(den)) + end + if den isa Number && _isone(-den) + return -1 * num + end + if num isa Rat && den isa Rat + return num // den # maybe called by oblivious code in simplify end - - d isa Number && _isone(-d) && return -1 * n - n isa Rat && d isa Rat && return n // d # maybe called by oblivious code in simplify # GCD coefficient upon construction - rat, nc = ratcoeff(n) + rat, nc = ratcoeff(num) if rat - rat, dc = ratcoeff(d) + rat, dc = ratcoeff(den) if rat g = gcd(nc, dc) * sign(dc) # make denominator positive invdc = ratio(1, g) - n = maybe_intcoeff(invdc * n) - d = maybe_intcoeff(invdc * d) - if d isa Number - _isone(d) && return n - _isone(-d) && return -1 * n + num = maybe_intcoeff(invdc * num) + den = maybe_intcoeff(invdc * den) + if den isa Number + if _isone(den) + return num + end + if _isone(-den) + return -1 * num + end end end end - - Div{T}(; num=n, den=d, simplified, arguments=[], metadata) + impl = Div(; num, den) + BasicSymbolic{T}(; impl, kwargs...) end - -function Div(n,d, simplified=false; kw...) - Div{promote_symtype((/), symtype(n), symtype(d))}(n, d, simplified; kw...) +function _Div(num, den; kwargs...) + _Div(promote_symtype((/), symtype(num), symtype(den)), num, den; kwargs...) end @inline function numerators(x) - isdiv(x) && return numerators(x.num) + isdiv(x) && return numerators(get_num(x)) iscall(x) && operation(x) === (*) ? arguments(x) : Any[x] end -@inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1] +@inline denominators(x) = isdiv(x) ? numerators(get_den(x)) : Any[1] -function (::Type{<:Pow{T}})(a, b; metadata=NO_METADATA) where {T} - _iszero(b) && return 1 - _isone(b) && return a - Pow{T}(; base=a, exp=b, arguments=[], metadata) +function _Pow(::Type{T}, base, exp; kwargs...) where {T} + _iszero(exp) && return 1 + _isone(exp) && return base + impl = Pow(; base, exp) + BasicSymbolic{T}(; impl, kwargs...) end - -function Pow(a, b; metadata=NO_METADATA) - Pow{promote_symtype(^, symtype(a), symtype(b))}(makepow(a, b)..., metadata=metadata) +function _Pow(base, exp; kwargs...) + _Pow(promote_symtype(^, symtype(base), symtype(exp)), makepow(base, exp)..., kwargs...) end -function toterm(t::BasicSymbolic{T}) where T +function toterm(t::BasicSymbolic{T}) where {T} E = exprtype(t) if E === SYM || E === TERM return t elseif E === ADD || E === MUL - args = Any[] - push!(args, t.coeff) - for (k, coeff) in t.dict - push!(args, coeff == 1 ? k : Term{T}(E === MUL ? (^) : (*), Any[coeff, k])) + args = BasicSymbolic[] + push!(args, get_coeff(t)) + for (k, coeff) in get_dict(t) + push!( + args, coeff == 1 ? k : _Term(T, E === MUL ? (^) : (*), [_Const(coeff), k])) end - Term{T}(operation(t), args) + _Term(T, operation(t), args) elseif E === DIV - Term{T}(/, Any[t.num, t.den]) + _Term(T, /, [get_num(t), get_den(t)]) elseif E === POW - Term{T}(^, [t.base, t.exp]) + _Term(T, ^, [get_base(t), get_exp(t)]) else error_on_type() end end -""" - makeadd(sign, coeff::Number, xs...) +"""" +$(SIGNATURES) Any Muls inside an Add should always have a coeff of 1 and the key (in Add) should instead be used to store the actual coefficient """ function makeadd(sign, coeff, xs...) - d = sdict() + d = Dict{BasicSymbolic, Any}() for x in xs if isadd(x) - coeff += x.coeff - _merge!(+, d, x.dict, filter=_iszero) + coeff += get_coeff(x) + _merge!(+, d, get_dict(x), filter = _iszero) continue end if x isa Number @@ -483,8 +599,8 @@ function makeadd(sign, coeff, xs...) continue end if ismul(x) - k = Mul(symtype(x), 1, x.dict) - v = sign * x.coeff + get(d, k, 0) + k = _Mul(symtype(x), 1, get_dict(x)) + v = sign * get_coeff(x) + get(d, k, 0) else k = x v = sign + get(d, x, 0) @@ -498,15 +614,15 @@ function makeadd(sign, coeff, xs...) coeff, d end -function makemul(coeff, xs...; d=sdict()) +function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}()) for x in xs - if ispow(x) && x.exp isa Number - d[x.base] = x.exp + get(d, x.base, 0) + if ispow(x) && get_exp(x) isa Number + d[get_base(x)] = get_exp(x) + get(d, get_base(x), 0) elseif x isa Number coeff *= x elseif ismul(x) - coeff *= x.coeff - _merge!(+, d, x.dict, filter=_iszero) + coeff *= get_coeff(x) + _merge!(+, d, get_dict(x), filter = _iszero) else v = 1 + get(d, x, 0) if _iszero(v) @@ -516,45 +632,43 @@ function makemul(coeff, xs...; d=sdict()) end end end - (coeff, d) + coeff, d end -unstable_pow(a, b) = a isa Integer && b isa Integer ? (a//1) ^ b : a ^ b +unstable_pow(a, b) = a isa Integer && b isa Integer ? (a // 1)^b : a^b function makepow(a, b) base = a exp = b if ispow(a) - base = a.base - exp = a.exp * b + base = get_base(a) + exp = get_exp(a) * b end - return (base, exp) + base, exp end function term(f, args...; type = nothing) if type === nothing - T = _promote_symtype(f, args) - else - T = type + type = _promote_symtype(f, args) end - Term{T}(f, Any[args...]) + _Term(type, f, [args...]) end """ - unflatten(t::Symbolic{T}) +$(TYPEDSIGNATURES) + Binarizes `Term`s with n-ary operations """ -function unflatten(t::Symbolic{T}) where{T} +function unflatten(t::Symbolic{T}) where {T} if iscall(t) f = operation(t) if f == (+) || f == (*) # TODO check out for other n-ary --> binary ops a = arguments(t) - return foldl((x,y) -> Term{T}(f, Any[x, y]), a) + return foldl((x, y) -> _Term(T, f, [x, y]), a) end end return t end - unflatten(t) = t function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata) @@ -587,7 +701,7 @@ function basicsymbolic(f, args, stype, metadata) end if T <: LiteralReal @goto FALLBACK - elseif all(x->symtype(x) <: Number, args) + elseif all(x -> symtype(x) <: Number, args) if f === (+) res = +(args...) if isadd(res) || isterm(res) @@ -608,7 +722,7 @@ function basicsymbolic(f, args, stype, metadata) end res elseif f == (^) && length(args) == 2 - res = args[1] ^ args[2] + res = args[1]^args[2] if ispow(res) @set! res.metadata = metadata end @@ -618,7 +732,7 @@ function basicsymbolic(f, args, stype, metadata) end else @label FALLBACK - Term{T}(f, args, metadata=metadata) + _Term(T, f, args; metadata) end end @@ -632,19 +746,38 @@ function hasmetadata(s::Symbolic, ctx) metadata(s) isa AbstractDict && haskey(metadata(s), ctx) end -issafecanon(f, s) = true -function issafecanon(f, s::Symbolic) - if isnothing(metadata(s)) || issym(s) - return true - else - _issafecanon(f, s) +""" +$(TYPEDSIGNATURES) + +Check if the symbolic expression(s) is/are safe to canonicalize with respect to the function `f`. + +This function determines if applying the canonicalization rules associated with function `f` +to the symbolic expression `s` is safe and won't lead to incorrect simplifications. It handles various cases +depending on the type of `s` and the function `f`. + +For multiple arguments, `issafecanon(f, ss...)`, it checks if canonicalization is safe for all expressions in `ss`. + +# Arguments +- `f`: The function for which canonicalization safety is being checked. +- `s`: The symbolic expression to check. +- `ss...`: A variable number of symbolic expressions to check. + +# Returns +- `true` if canonicalization is safe, `false` otherwise. +""" +function issafecanon(f, s::BasicSymbolic) + isnothing(metadata(s)) || @match s.impl begin + Sym(_...) => true + Const(_...) => true + _ => _issafecanon(f, s) end end -_issafecanon(::typeof(*), s) = !iscall(s) || !(operation(s) in (+,*,^)) -_issafecanon(::typeof(+), s) = !iscall(s) || !(operation(s) in (+,*)) -_issafecanon(::typeof(^), s) = !iscall(s) || !(operation(s) in (*, ^)) +issafecanon(f, s) = true +issafecanon(f, ss...) = all(x -> issafecanon(f, x), ss) -issafecanon(f, ss...) = all(x->issafecanon(f, x), ss) +_issafecanon(::typeof(*), s) = !iscall(s) || !(operation(s) in (+, *, ^)) +_issafecanon(::typeof(+), s) = !iscall(s) || !(operation(s) in (+, *)) +_issafecanon(::typeof(^), s) = !iscall(s) || !(operation(s) in (*, ^)) function getmetadata(s::Symbolic, ctx) md = metadata(s) @@ -703,6 +836,10 @@ const show_simplified = Ref(false) isnegative(t::Real) = t < 0 function isnegative(t) + if isconst(t) + val = get_val(t) + return isnegative(val) + end if iscall(t) && operation(t) === (*) coeff = first(arguments(t)) return isnegative(coeff) @@ -711,7 +848,7 @@ function isnegative(t) end # Term{} -setargs(t, args) = Term{symtype(t)}(operation(t), args) +setargs(t, args) = _Term(symtype(t), operation(t), args) cdrargs(args) = setargs(t, cdr(args)) print_arg(io, x::Union{Complex, Rational}; paren=true) = print(io, "(", x, ")") @@ -737,8 +874,12 @@ function remove_minus(t) !iscall(t) && return -t @assert operation(t) == (*) args = arguments(t) - @assert args[1] < 0 - Any[-args[1], args[2:end]...] + arg1 = args[1] + if isconst(arg1) + arg1 = get_val(arg1) + end + @assert arg1 < 0 + Any[-arg1, args[2:end]...] end @@ -773,17 +914,27 @@ function show_pow(io, args) end function show_mul(io, args) + if isconst(args) + print(io, get_val(args)) + return + end length(args) == 1 && return print_arg(io, *, args[1]) - minus = args[1] isa Number && args[1] == -1 - unit = args[1] isa Number && args[1] == 1 + arg1 = args[1] + if isconst(arg1) + arg1 = get_val(arg1) + end + + minus = arg1 isa Number && arg1 == -1 + unit = arg1 isa Number && arg1 == 1 - paren_scalar = (args[1] isa Complex && !_iszero(imag(args[1]))) || + paren_scalar = (arg1 isa Complex && !_iszero(imag(arg1))) || args[1] isa Rational || - (args[1] isa Number && !isfinite(args[1])) + (arg1 isa Number && !isfinite(arg1)) nostar = minus || unit || - (!paren_scalar && args[1] isa Number && !(args[2] isa Number)) + (!paren_scalar && arg1 isa Number && + !(isconst(args[2]) && get_val(args[2]) isa Number)) for (i, t) in enumerate(args) if i != 1 @@ -872,10 +1023,10 @@ showraw(io, t) = Base.show(IOContext(io, :simplify=>false), t) showraw(t) = showraw(stdout, t) function Base.show(io::IO, v::BasicSymbolic) - if issym(v) - Base.show_unquoted(io, v.name) - else - show_term(io, v) + @match v.impl begin + Sym(_...) => Base.show_unquoted(io, get_name(v)) + Const(_...) => print(io, get_val(v)) + _ => show_term(io, v) end end @@ -914,7 +1065,7 @@ promote_symtype(f, Ts...) = Any struct FnType{X<:Tuple,Y,Z} end -(f::Symbolic{<:FnType})(args...) = Term{promote_symtype(f, symtype.(args)...)}(f, Any[args...]) +(f::Symbolic{<:FnType})(args...) = _Term(promote_symtype(f, symtype.(args)...), f, [args...]) function (f::Symbolic)(args...) error("Sym $f is not callable. " * @@ -955,18 +1106,16 @@ end function _promote_symtype(f, args) if issym(f) promote_symtype(f, map(symtype, args)...) + elseif length(args) == 0 + promote_symtype(f) + elseif length(args) == 1 + promote_symtype(f, symtype(args[1])) + elseif length(args) == 2 + promote_symtype(f, symtype(args[1]), symtype(args[2])) + elseif isassociative(f) + mapfoldl(symtype, (x, y) -> promote_symtype(f, x, y), args) else - if length(args) == 0 - promote_symtype(f) - elseif length(args) == 1 - promote_symtype(f, symtype(args[1])) - elseif length(args) == 2 - promote_symtype(f, symtype(args[1]), symtype(args[2])) - elseif isassociative(f) - mapfoldl(symtype, (x,y) -> promote_symtype(f, x, y), args) - else - promote_symtype(f, map(symtype, args)...) - end + promote_symtype(f, map(symtype, args)...) end end @@ -1001,10 +1150,10 @@ macro syms(xs...) nt = _name_type(x) n, t = nt.name, nt.type T = esc(t) - :($(esc(n)) = Sym{$T}($(Expr(:quote, n)))) + :($(esc(n)) = _Sym($T, $(Expr(:quote, n)))) end Expr(:block, defs..., - :(tuple($(map(x->esc(_name_type(x).name), xs)...)))) + :(tuple($(map(x -> esc(_name_type(x).name), xs)...)))) end function syms_syntax_error() @@ -1089,141 +1238,147 @@ sub_t(a) = promote_symtype(-, symtype(a)) import Base: (+), (-), (*), (//), (/), (\), (^) function +(a::SN, b::SN) - !issafecanon(+, a,b) && return term(+, a, b) # Don't flatten if args have metadata + if isconst(a) + return get_val(a) + b + end + if isconst(b) + return get_val(b) + a + end + !issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata if isadd(a) && isadd(b) - return Add(add_t(a,b), - a.coeff + b.coeff, - _merge(+, a.dict, b.dict, filter=_iszero)) + return _Add( + add_t(a, b), get_coeff(a) + get_coeff(b), _merge(+, get_dict(a), get_dict(b), filter = _iszero)) elseif isadd(a) coeff, dict = makeadd(1, 0, b) - return Add(add_t(a,b), a.coeff + coeff, _merge(+, a.dict, dict, filter=_iszero)) + return _Add(add_t(a, b), get_coeff(a) + coeff, _merge(+, get_dict(a), dict, filter = _iszero)) elseif isadd(b) return b + a end coeff, dict = makeadd(1, 0, a, b) - Add(add_t(a,b), coeff, dict) + _Add(add_t(a, b), coeff, dict) end - function +(a::Number, b::SN) + if isconst(b) + return a + get_val(b) + end !issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata iszero(a) && return b if isadd(b) - Add(add_t(a,b), a + b.coeff, b.dict) + _Add(add_t(a, b), a + get_coeff(b), get_dict(b)) else - Add(add_t(a,b), makeadd(1, a, b)...) + _Add(add_t(a, b), makeadd(1, a, b)...) end end - +(a::SN, b::Number) = b + a - +(a::SN) = a function -(a::SN) - !issafecanon(*, a) && return term(-, a) - isadd(a) ? Add(sub_t(a), -a.coeff, mapvalues((_,v) -> -v, a.dict)) : - Add(sub_t(a), makeadd(-1, 0, a)...) + if isconst(a) + v = get_val(a) + mv = -v + return _Const(mv) + end + if !issafecanon(*, a) + return term(-, a) + end + if isadd(a) + _Add(sub_t(a), -get_coeff(a), mapvalues((_, v) -> -v, get_dict(a))) + else + _Add(sub_t(a), makeadd(-1, 0, a)...) + end end - function -(a::SN, b::SN) (!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b) - isadd(a) && isadd(b) ? Add(sub_t(a,b), - a.coeff - b.coeff, - _merge(-, a.dict, - b.dict, - filter=_iszero)) : a + (-b) + if isadd(a) && isadd(b) + _Add(sub_t(a, b), get_coeff(a) - get_coeff(b), _merge(-, get_dict(a), get_dict(b), filter = _iszero)) + else + a + (-b) + end end - -(a::Number, b::SN) = a + (-b) -(a::SN, b::Number) = a + (-b) - -mul_t(a,b) = promote_symtype(*, symtype(a), symtype(b)) +mul_t(a, b) = promote_symtype(*, symtype(a), symtype(b)) mul_t(a) = promote_symtype(*, symtype(a)) -*(a::SN) = a - function *(a::SN, b::SN) + if isconst(a) + return get_val(a) * b + end + if isconst(b) + return get_val(b) * a + end # Always make sure Div wraps Mul !issafecanon(*, a, b) && return term(*, a, b) if isdiv(a) && isdiv(b) - Div(a.num * b.num, a.den * b.den) + _Div(get_num(a) * get_num(b), get_den(a) * get_den(b)) elseif isdiv(a) - Div(a.num * b, a.den) + _Div(get_num(a) * b, get_den(a)) elseif isdiv(b) - Div(a * b.num, b.den) + _Div(a * get_num(b), get_den(b)) elseif ismul(a) && ismul(b) - Mul(mul_t(a, b), - a.coeff * b.coeff, - _merge(+, a.dict, b.dict, filter=_iszero)) + _Mul(mul_t(a, b), get_coeff(a) * get_coeff(b), + _merge(+, get_dict(a), get_dict(b), filter = _iszero)) elseif ismul(a) && ispow(b) - if b.exp isa Number - Mul(mul_t(a, b), - a.coeff, _merge(+, a.dict, Base.ImmutableDict(b.base=>b.exp), filter=_iszero)) + if get_exp(b) isa Number + _Mul(mul_t(a, b), + get_coeff(a), + _merge(+, get_dict(a), Base.ImmutableDict(get_base(b) => get_exp(b)), + filter = _iszero)) else - Mul(mul_t(a, b), - a.coeff, _merge(+, a.dict, Base.ImmutableDict(b=>1), filter=_iszero)) + _Mul(mul_t(a, b), get_coeff(a), + _merge(+, get_dict(a), Base.ImmutableDict(b => 1), filter = _iszero)) end elseif ispow(a) && ismul(b) b * a else - Mul(mul_t(a,b), makemul(1, a, b)...) + _Mul(mul_t(a, b), makemul(1, a, b)...) end end - function *(a::Number, b::SN) + if isconst(b) + return a * get_val(b) + end !issafecanon(*, b) && return term(*, a, b) if iszero(a) a elseif isone(a) b elseif isdiv(b) - Div(a*b.num, b.den) + _Div(a * get_num(b), get_den(b)) elseif isone(-a) && isadd(b) # -1(a+b) -> -a - b T = promote_symtype(+, typeof(a), symtype(b)) - Add(T, b.coeff * a, Dict{Any,Any}(k=>v*a for (k, v) in b.dict)) + _Add(T, get_coeff(b) * a, + Dict{BasicSymbolic, Any}(k => v * a for (k, v) in get_dict(b))) else - Mul(mul_t(a, b), makemul(a, b)...) + _Mul(mul_t(a, b), makemul(a, b)...) end end - -### -### Div -### - -/(a::Union{SN,Number}, b::SN) = Div(a, b) - *(a::SN, b::Number) = b * a +*(a::SN) = a -\(a::SN, b::Union{Number, SN}) = b / a - -\(a::Number, b::SN) = b / a - -/(a::SN, b::Number) = (isone(abs(b)) ? b : (b isa Integer ? 1//b : inv(b))) * a +/(a::Union{SN, Number}, b::SN) = _Div(a, b) +/(a::SN, b::Number) = (isone(abs(b)) ? b : (b isa Integer ? 1 // b : inv(b))) * a //(a::Union{SN, Number}, b::SN) = a / b - //(a::SN, b::T) where {T <: Number} = (one(T) // b) * a - -### -### Pow -### +\(a::SN, b::Union{Number, SN}) = b / a +\(a::Number, b::SN) = b / a function ^(a::SN, b) - !issafecanon(^, a,b) && return Pow(a, b) + !issafecanon(^, a, b) && return Pow(a, b) if b isa Number && iszero(b) - # fast path 1 elseif b isa Number && b < 0 - Div(1, a ^ (-b)) + _Div(1, a^(-b)) elseif ismul(a) && b isa Number - coeff = unstable_pow(a.coeff, b) - Mul(promote_symtype(^, symtype(a), symtype(b)), - coeff, mapvalues((k, v) -> b*v, a.dict)) + coeff = unstable_pow(get_coeff(a), b) + _Mul(promote_symtype(^, symtype(a), symtype(b)), + coeff, mapvalues((k, v) -> b * v, get_dict(a))) else - Pow(a, b) + _Pow(a, b) end end - -^(a::Number, b::SN) = Pow(a, b) +^(a::Number, b::SN) = _Pow(a, b) diff --git a/src/utils.jl b/src/utils.jl index 812e229fb..3ced418da 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -64,8 +64,12 @@ end sym_isa(::Type{T}) where {T} = @nospecialize(x) -> x isa T || symtype(x) <: T -isliteral(::Type{T}) where {T} = x -> x isa T -is_literal_number(x) = isliteral(Number)(x) +function is_literal_number(x) + if isconst(x) + x = get_val(x) + end + x isa Number +end # checking the type directly is faster than dynamic dispatch in type unstable code _iszero(x) = x isa Number && iszero(x) @@ -179,10 +183,15 @@ Base.length(l::LL) = length(l.v)-l.i+1 @inline car(l::LL) = l.v[l.i] @inline cdr(l::LL) = isempty(l) ? empty(l) : LL(l.v, l.i+1) -Base.length(t::Term) = length(arguments(t)) + 1 # PIRACY -Base.isempty(t::Term) = false -@inline car(t::Term) = operation(t) -@inline cdr(t::Term) = arguments(t) +function Base.length(t::BasicSymbolic) + @match t.impl begin + Term(_...) => length(arguments(t)) + 1 # PIRACY + _ => 1 + end +end +Base.isempty(t::BasicSymbolic) = false +@inline car(t::BasicSymbolic) = operation(t) +@inline cdr(t::BasicSymbolic) = arguments(t) @inline car(v) = iscall(v) ? operation(v) : first(v) @inline function cdr(v) diff --git a/test/basics.jl b/test/basics.jl index 024533c9e..2161f1648 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -1,4 +1,5 @@ -using SymbolicUtils: Symbolic, Sym, FnType, Term, Add, Mul, Pow, symtype, operation, arguments, issym, isterm, BasicSymbolic, term +using SymbolicUtils: Symbolic, FnType, symtype, operation, arguments, issym, isterm, + BasicSymbolic, term, get_name, get_coeff, get_dict, get_num, get_den using SymbolicUtils using IfElse: ifelse using Setfield @@ -9,17 +10,17 @@ using Test @syms a b::Float64 f(::Real) g(p, h(q::Real))::Int @test issym(a) && symtype(a) == Number - @test a.name === :a + @test get_name(a) === :a @test issym(b) && symtype(b) == Float64 @test nameof(b) === :b @test issym(f) - @test f.name === :f + @test get_name(f) === :f @test symtype(f) == FnType{Tuple{Real}, Number, Nothing} @test issym(g) - @test g.name === :g + @test get_name(g) === :g @test symtype(g) == FnType{Tuple{Number, FnType{Tuple{Real}, Number, Nothing}}, Int, Nothing} @test isterm(f(b)) @@ -37,14 +38,14 @@ using Test @syms (f::typeof(max))(::Real, ::AbstractFloat)::Number a::Real @test issym(f) - @test f.name == :f + @test get_name(f) == :f @test symtype(f) == FnType{Tuple{Real, AbstractFloat}, Number, typeof(max)} @test isterm(f(a, b)) @test symtype(f(a, b)) == Number @syms g(p, (h::typeof(identity))(q::Real)::Number)::Number @test issym(g) - @test g.name == :g + @test get_name(g) == :g @test symtype(g) == FnType{Tuple{Number, FnType{Tuple{Real}, Number, typeof(identity)}}, Number, Nothing} @test_throws "not a subtype of" g(a, f) @syms (f::typeof(identity))(::Real)::Number @@ -108,42 +109,42 @@ struct Ctx2 end @test isequal(substitute(1+sqrt(a), Dict(a => 2), fold=false), - 1 + term(sqrt, 2, type=Number)) + 1 + term(sqrt, 2, type = Number)) @test substitute(1+sqrt(a), Dict(a => 2), fold=true) isa Float64 end @testset "Base methods" begin @syms w::Complex z::Complex a::Real b::Real x - @test isequal(w + z, Add(Complex, 0, Dict(w=>1, z=>1))) - @test isequal(z + a, Add(Number, 0, Dict(z=>1, a=>1))) - @test isequal(a + b, Add(Real, 0, Dict(a=>1, b=>1))) - @test isequal(a + x, Add(Number, 0, Dict(a=>1, x=>1))) - @test isequal(a + z, Add(Number, 0, Dict(a=>1, z=>1))) + @test isequal(w + z, _Add(Complex, 0, Dict(w => 1, z => 1))) + @test isequal(z + a, _Add(Number, 0, Dict(z => 1, a => 1))) + @test isequal(a + b, _Add(Real, 0, Dict(a => 1, b => 1))) + @test isequal(a + x, _Add(Number, 0, Dict(a => 1, x => 1))) + @test isequal(a + z, _Add(Number, 0, Dict(a => 1, z => 1))) foo(w, z, a, b) = 1.0 SymbolicUtils.promote_symtype(::typeof(foo), args...) = Real @test SymbolicUtils._promote_symtype(foo, (w, z, a, b,)) === Real # promote_symtype of identity - @test isequal(Term(identity, [w]), Term{Complex}(identity, [w])) + @test isequal(_Term(identity, [w]), _Term(Complex, identity, [w])) @test isequal(+(w), w) @test isequal(+(a), a) - @test isequal(rem2pi(a, RoundNearest), Term{Real}(rem2pi, [a, RoundNearest])) + @test isequal(rem2pi(a, RoundNearest), _Term(Real, rem2pi, [a, RoundNearest])) # bool for f in [(==), (!=), (<=), (>=), (<), (>)] - @test isequal(f(a, 0), Term{Bool}(f, [a, 0])) - @test isequal(f(0, a), Term{Bool}(f, [0, a])) - @test isequal(f(a, a), Term{Bool}(f, [a, a])) + @test isequal(f(a, 0), _Term(Bool, f, [a, 0])) + @test isequal(f(0, a), _Term(Bool, f, [0, a])) + @test isequal(f(a, a), _Term(Bool, f, [a, a])) end @test symtype(ifelse(true, 4, 5)) == Int @test symtype(ifelse(a < 0, b, w)) == Union{Real, Complex} @test SymbolicUtils.promote_symtype(ifelse, Bool, Int, Bool) == Union{Int, Bool} @test_throws MethodError w < 0 - @test isequal(w == 0, Term{Bool}(==, [w, 0])) + @test isequal(w == 0, _Term(Bool, ==, [w, 0])) @eqtest x // 5 == (1 // 5) * x @eqtest (1//2 * x) / 5 == (1 // 10) * x @@ -198,8 +199,8 @@ end @test repr((2a)^(-2a)) == "(2a)^(-2a)" @test repr(1/2a) == "1 / (2a)" @test repr(2/(2*a)) == "1 / a" - @test repr(Term(*, [1, 1])) == "1" - @test repr(Term(*, [2, 1])) == "2*1" + @test repr(_Term(*, [1, 1])) == "1" + @test repr(_Term(*, [2, 1])) == "2*1" @test repr((a + b) - (b + c)) == "a - c" @test repr(a + -1*(b + c)) == "a - b - c" @test repr(a + -1*b) == "a - b" @@ -215,7 +216,7 @@ end @test repr(2a+1+3a^2+2b+3b^2+4a*b) == "1 + 2a + 2b + 3(a^2) + 4a*b + 3(b^2)" @syms a b[1:3] c d[1:3] - get(x, i) = term(getindex, x, i, type=Number) + get(x, i) = term(getindex, x, i; type = Number) b1, b3, d1, d2 = get(b,1),get(b,3), get(d,1), get(d,2) @test repr(a + b3 + b1 + d2 + c) == "a + b[1] + b[3] + c + d[2]" @test repr(expand((c + b3 - d1)^3)) == "b[3]^3 + 3(b[3]^2)*c - 3(b[3]^2)*d[1] + 3b[3]*(c^2) - 6b[3]*c*d[1] + 3b[3]*(d[1]^2) + c^3 - 3(c^2)*d[1] + 3c*(d[1]^2) - (d[1]^3)" @@ -235,7 +236,7 @@ end @testset "maketerm" begin @syms a b c - @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).dict, Dict(a=>1,b=>1,c=>1)) + @test isequal(get_dict(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing)), Dict(a=>1,b=>1,c=>1)) @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], nothing), b) # test that maketerm doesn't hard-code BasicSymbolic subtype @@ -290,13 +291,13 @@ end @test symtype(new_expr) == Vector{Float64} end -toterm(t) = Term{symtype(t)}(operation(t), arguments(t)) +toterm(t) = _Term(symtype(t), operation(t), arguments(t)) @testset "diffs" begin @syms a b c - @test isequal(toterm(-1c), Term{Number}(*, [-1, c])) - @test isequal(toterm(-1(a+b)), Term{Number}(+, [-1a, -b])) - @test isequal(toterm((a + b) - (b + c)), Term{Number}(+, [a, -1c])) + @test isequal(toterm(-1c), _Term(Number, *, [-1, c])) + @test isequal(toterm(-1(a+b)), _Term(Number, +, [-1a, -b])) + @test isequal(toterm((a + b) - (b + c)), _Term(Number, +, [a, -1c])) end @testset "hash" begin @@ -337,31 +338,31 @@ end end @testset "subtyping" begin - T = FnType{Tuple{T,S,Int} where {T,S}, Real} - s = Sym{T}(:t) + T = FnType{Tuple{T, S, Int} where {T, S}, Real} + s = _Sym(T, :t) @syms a b c::Int @test isequal(arguments(s(a, b, c)), [a, b, c]) end @testset "div" begin @syms x::SafeReal y::Real - @test issym((2x/2y).num) - @test (2x/3y).num.coeff == 2 - @test (2x/3y).den.coeff == 3 - @test (2x/-3x).num.coeff == -2 - @test (2x/-3x).den.coeff == 3 - @test (2.5x/3x).num.coeff == 2.5 - @test (2.5x/3x).den.coeff == 3 - @test (x/3x).den.coeff == 3 + @test issym(get_num(2x / 2y)) + @test get_coeff(get_num(2x / 3y)) == 2 + @test get_coeff(get_den(2x / 3y)) == 3 + @test get_coeff(get_num(2x / -3x)) == -2 + @test get_coeff(get_den(2x / -3x)) == 3 + @test get_coeff(get_num(2.5x / 3x)) == 2.5 + @test get_coeff(get_den(2.5x / 3x)) == 3 + @test get_coeff(get_den(x / 3x)) == 3 @syms x y - @test issym((2x/2y).num) - @test (2x/3y).num.coeff == 2 - @test (2x/3y).den.coeff == 3 - @test (2x/-3x) == -2//3 - @test (2.5x/3x).num == 2.5 - @test (2.5x/3x).den == 3 - @test (x/3x) == 1//3 + @test issym(get_num(2x / 2y)) + @test get_coeff(get_num(2x / 3y)) == 2 + @test get_coeff(get_den(2x / 3y)) == 3 + @test (2x / -3x) == -2 // 3 + @test get_num(2.5x / 3x) == 2.5 + @test get_den(2.5x / 3x) == 3 + @test (x / 3x) == 1 // 3 @test isequal(x / 1, x) @test isequal(x / -1, -x) end diff --git a/test/code.jl b/test/code.jl index c05200167..40d8110cb 100644 --- a/test/code.jl +++ b/test/code.jl @@ -226,7 +226,7 @@ nanmath_st.rewrites[:nanmath] = true for q ∈ Base.Irrational[Base.MathConstants.catalan, Base.MathConstants.γ, π, Base.MathConstants.φ, ℯ, twoπ] Base.show(io, q) s1 = String(take!(io)) - SymbolicUtils.show_term(io, SymbolicUtils.Term(identity, [q])) + SymbolicUtils.show_term(io, SymbolicUtils._Term(identity, [q])) s2 = String(take!(io)) @test s1 == s2 end diff --git a/test/fuzzlib.jl b/test/fuzzlib.jl index ae9d9a213..aa0a51286 100644 --- a/test/fuzzlib.jl +++ b/test/fuzzlib.jl @@ -207,7 +207,7 @@ function gen_expr(lvl=5) n = rand(1:5) args = [gen_expr(lvl-1) for i in 1:n] - Term{Number}(f, first.(args)), f(last.(args)...) + _Term(Number, f, first.(args)), f(last.(args)...) else f = rand((-,/)) l = gen_expr(lvl-1) @@ -217,7 +217,7 @@ function gen_expr(lvl=5) end args = [l, r] - Term{Number}(f, first.(args)), f(last.(args)...) + _Term(Number, f, first.(args)), f(last.(args)...) end end diff --git a/test/order.jl b/test/order.jl index 3b7095e95..0c5a32406 100644 --- a/test/order.jl +++ b/test/order.jl @@ -27,8 +27,8 @@ end @test istotal(b*a, a) @test istotal(a, b*a) @test !(b*a <ₑ b+a) -@test Term(^, [1,-1]) <ₑ a -@test istotal(a, Term(^, [1,-1])) +@test _Term(^, [1, -1]) <ₑ a +@test istotal(a, _Term(^, [1, -1])) @testset "operator order" begin fs = (*, -, +) @@ -77,7 +77,7 @@ end @testset "small terms" begin # this failing was a cause of a nasty stackoverflow #82 @syms a - istotal(Term(^, [a, -1]), (a + 2)) + istotal(_Term(^, [a, -1]), (a + 2)) end @testset "transitivity" begin diff --git a/test/polyform.jl b/test/polyform.jl index 03c158c9d..f7357a017 100644 --- a/test/polyform.jl +++ b/test/polyform.jl @@ -36,7 +36,7 @@ end @syms A::Vector{Real} # test that the following works - expand(Term{Real}(getindex, [A, 3]) - 3) + expand(_Term(Real, getindex, [A, 3]) - 3) end @testset "simplify_fractions with quick-cancel" begin diff --git a/test/rewrite.jl b/test/rewrite.jl index 3bb2621e3..243823379 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -38,9 +38,12 @@ end @eqtest @rule(+(~~x) => ~~x)(a + b) == [a,b] @eqtest @rule(+(~~x) => ~~x)(term(+, a, b, c)) == [a,b,c] - @eqtest @rule(+(~~x,~y, ~~x) => (~~x, ~y))(term(+,9,8,9,type=Any)) == ([9,],8) - @eqtest @rule(+(~~x,~y, ~~x) => (~~x, ~y, ~~x))(term(+,9,8,9,9,8,type=Any)) == ([9,8], 9, [9,8]) - @eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, []) + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y))(term(+, 9, 8, 9; type = Any)) == + (Symbolic[9], _Const(8)) + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y, ~~x))(term(+, 9, 8, 9, 9, 8; type = Any)) == + (Symbolic[9, 8], _Const(9), Symbolic[9, 8]) + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y, ~~x))(term(+, 6; type = Any)) == + (Symbolic[], _Const(6), Symbolic[]) end using SymbolicUtils: @capture diff --git a/test/rulesets.jl b/test/rulesets.jl index ddebedf28..e83aa4f5f 100644 --- a/test/rulesets.jl +++ b/test/rulesets.jl @@ -18,10 +18,10 @@ end @testset "Numeric" begin @syms a::Integer b c d x::Real y::Number - @eqtest simplify(Term{Real}(conj, [x])) == x - @eqtest simplify(Term{Real}(real, [x])) == x - @eqtest simplify(Term{Real}(imag, [x])) == 0 - @eqtest simplify(Term{Real}(imag, [y])) == imag(y) + @eqtest simplify(_Term(Real, conj, [x])) == x + @eqtest simplify(_Term(Real, real, [x])) == x + @eqtest simplify(_Term(Real, imag, [x])) == 0 + @eqtest simplify(_Term(Real, imag, [y])) == imag(y) @eqtest simplify(x - y) == x + -1 * y @eqtest simplify(x - sin(y)) == x + -1 * sin(y) @eqtest simplify(-sin(x)) == -1 * sin(x) @@ -44,14 +44,13 @@ end @eqtest simplify(a * b * 1 * c * d) == simplify(a * b * c * d) @eqtest simplify_fractions(x^2.0 / (x * y)^2.0) == simplify_fractions(1 / (y^2.0)) - @test simplify(Term(one, [a])) == 1 - @test simplify(Term(one, [b + 1])) == 1 - @test simplify(Term(one, [x + 2])) == 1 + @test simplify(_Term(one, [a])) == 1 + @test simplify(_Term(one, [b + 1])) == 1 + @test simplify(_Term(one, [x + 2])) == 1 - - @test simplify(Term(zero, [a])) == 0 - @test simplify(Term(zero, [b + 1])) == 0 - @test simplify(Term(zero, [x + 2])) == 0 + @test simplify(_Term(zero, [a])) == 0 + @test simplify(_Term(zero, [b + 1])) == 0 + @test simplify(_Term(zero, [x + 2])) == 0 end @testset "LiteralReal" begin @@ -77,8 +76,8 @@ end @eqtest simplify(true & (0 < a)) == (0 < a) @eqtest simplify(false & (0 < a)) == false @eqtest simplify((0 < a) & false) == false - @eqtest simplify(Term{Bool}(!, [true])) == false - @eqtest simplify(Term{Bool}(|, [false, true])) == true + @eqtest simplify(_Term(Bool, !, [true])) == false + @eqtest simplify(_Term(Bool, |, [false, true])) == true @eqtest simplify(ifelse(true, a, b)) == a @eqtest simplify(ifelse(false, a, b)) == b diff --git a/test/runtests.jl b/test/runtests.jl index 049313154..4159b7d87 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,6 +20,7 @@ include("utils.jl") if haskey(ENV, "SU_BENCHMARK_ONLY") include("benchmark.jl") else + include("types.jl") include("basics.jl") include("order.jl") include("polyform.jl") diff --git a/test/types.jl b/test/types.jl new file mode 100644 index 000000000..6bcd510ba --- /dev/null +++ b/test/types.jl @@ -0,0 +1,165 @@ +using SymbolicUtils: Symbolic, BasicSymbolic, _Sym, _Term, _Const, _Add, get_name, get_val + +@testset "Expronicon generated constructors" begin + s1 = Sym(:abc) + s2 = Sym(name = :def) + name = :ghi + s3 = Sym(; name) + bs1 = BasicSymbolic{Float64}(impl = s1) + impl = s2 + bs2 = BasicSymbolic{Int64}(; impl) + @testset "Sym" begin + @test typeof(s1) == BasicSymbolicImpl + @test_nowarn Sym(Symbol("")) + @test s1.name == :abc + @test typeof(s2.name) == Symbol + @test typeof(s1) == BasicSymbolicImpl + @test s2.name == :def + @test s3.name == :ghi + end + @testset "Term" begin + t1 = Term(sin, [bs1]) + @test typeof(t1) == BasicSymbolicImpl + @test t1.f == sin + @test isequal(t1.arguments, [bs1]) + @test typeof(t1.arguments) == Vector{Symbolic} + end + @testset "Div" begin + d1 = Div(num = bs1, den = bs2) + @test typeof(d1) == BasicSymbolicImpl + @test isequal(d1.num, bs1) + @test isequal(d1.den, bs2) + @test typeof(d1.simplified) == Base.RefValue{Bool} + @test isassigned(d1.simplified) + @test !d1.simplified[] + @test isequal(d1.arguments, [bs1, bs2]) + num = bs1 + den = bs2 + d2 = Div(; num, den) + @test isequal(d2.num, bs1) + @test isequal(d2.den, bs2) + end + @testset "Pow" begin + p1 = Pow(base = bs1, exp = bs2) + @test typeof(p1) == BasicSymbolicImpl + @test isequal(p1.base, bs1) + @test isequal(p1.exp, bs2) + @test isequal(p1.arguments, [bs1, bs2]) + base = bs1 + exp = bs2 + p2 = Pow(; base, exp) + @test isequal(p2.base, bs1) + @test isequal(p2.exp, bs2) + end + c1 = Const(1) + bc1 = BasicSymbolic{Int}(impl = c1) + c2 = Const(val = 3.14) + bc2 = BasicSymbolic{Float64}(impl = c2) + @testset "Const" begin + @test typeof(c1) == BasicSymbolicImpl + @test typeof(c1.val) == Int + @test c1.val == 1 + @test typeof(c2.val) == Float64 + @test c2.val == 3.14 + c3 = Const(big"123456789012345678901234567890") + @test typeof(c3.val) == BigInt + @test c3.val == big"123456789012345678901234567890" + c4 = Const(big"1.23456789012345678901") + @test typeof(c4.val) == BigFloat + @test c4.val == big"1.23456789012345678901" + end + coeff = bc1 + dict = Dict(bs1 => 3, bs2 => 5) + @testset "Add" begin + a1 = Add(; coeff, dict) + @test typeof(a1) == BasicSymbolicImpl + @test a1.coeff isa BasicSymbolic + @test isequal(a1.coeff, bc1) + @test typeof(a1.dict) == Dict{BasicSymbolic, Any} + @test a1.dict == dict + @test typeof(a1.arguments) == Vector{BasicSymbolic} + @test isempty(a1.arguments) + @test typeof(a1.issorted) == Base.RefValue{Bool} + @test !a1.issorted[] + end + @testset "Mul" begin + m1 = Mul(; coeff, dict) + @test typeof(m1) == BasicSymbolicImpl + @test m1.coeff isa BasicSymbolic + @test isequal(m1.coeff, bc1) + @test typeof(m1.dict) == Dict{BasicSymbolic, Any} + @test m1.dict == dict + @test typeof(m1.arguments) == Vector{BasicSymbolic} + @test isempty(m1.arguments) + @test typeof(m1.issorted) == Base.RefValue{Bool} + @test !m1.issorted[] + end + @testset "BasicSymbolic" begin + @test typeof(bs1) == BasicSymbolic{Float64} + @test bs1 isa BasicSymbolic + @test bs1 isa SymbolicUtils.Symbolic + @test bs1.metadata isa SymbolicUtils.Metadata + @test bs1.metadata == SymbolicUtils.NO_METADATA + @test typeof(bs1.hash) == Base.RefValue{UInt} + @test bs1.hash[] == SymbolicUtils.EMPTY_HASH + end +end + +@testset "Custom constructors" begin + @testset "Sym" begin + s1 = _Sym(Int64, :x) + s2 = _Sym(Float64, :y) + @test typeof(s1) == BasicSymbolic{Int64} + @test s1.metadata == SymbolicUtils.NO_METADATA + @test s1.hash[] == SymbolicUtils.EMPTY_HASH + @test get_name(s1) == :x + @test typeof(s2) == BasicSymbolic{Float64} + @test s2.metadata == SymbolicUtils.NO_METADATA + @test s2.hash[] == SymbolicUtils.EMPTY_HASH + @test get_name(s2) == :y + end + @testset "Term" begin + s1 = _Sym(Float64, :x) + s2 = _Sym(Float64, :y) + t = _Term(Float64, mod, [s1, s2]) + @test typeof(t) == BasicSymbolic{Float64} + @test t.metadata == SymbolicUtils.NO_METADATA + @test t.hash[] == SymbolicUtils.EMPTY_HASH + @test t.impl.f == mod + @test isequal(t.impl.arguments, [s1, s2]) + end + @testset "Const" begin + c1 = _Const(1.0) + @test typeof(c1) == BasicSymbolic{Float64} + @test c1.metadata == SymbolicUtils.NO_METADATA + @test c1.hash[] == SymbolicUtils.EMPTY_HASH + @test get_val(c1) == 1.0 + c2 = _Const(big"123456789012345678901234567890") + @test typeof(c2) == BasicSymbolic{BigInt} + @test c2.metadata == SymbolicUtils.NO_METADATA + @test c2.hash[] == SymbolicUtils.EMPTY_HASH + @test get_val(c2) == big"123456789012345678901234567890" + c3 = _Const(big"1.23456789012345678901") + @test typeof(c3) == BasicSymbolic{BigFloat} + @test c3.metadata == SymbolicUtils.NO_METADATA + @test c3.hash[] == SymbolicUtils.EMPTY_HASH + @test get_val(c3) == big"1.23456789012345678901" + end +end + +@testset "BasicSymbolic iszero" begin + c1 = _Const(0) + @test SymbolicUtils._iszero(c1) + c2 = _Const(1) + @test !SymbolicUtils._iszero(c2) + c3 = _Const(0.0) + @test SymbolicUtils._iszero(c3) + c4 = _Const(0.00000000000000000000000001) + @test !SymbolicUtils._iszero(c4) + c5 = _Const(big"326264532521352634435352152") + @test !SymbolicUtils._iszero(c5) + c6 = _Const(big"0.314654523452") + @test !SymbolicUtils._iszero(c6) + s = _Sym(Real, :y) + @test !SymbolicUtils._iszero(s) +end