From 898f18c6dac33cafd651374e0e609305deec0fbd Mon Sep 17 00:00:00 2001 From: a Date: Sat, 8 Jun 2024 14:19:58 +0100 Subject: [PATCH 01/23] adapt changes --- Project.toml | 2 +- src/SymbolicUtils.jl | 6 ++-- src/code.jl | 4 +-- src/interface.jl | 65 -------------------------------------------- src/polyform.jl | 13 ++++----- src/rule.jl | 2 +- src/substitute.jl | 1 - src/types.jl | 30 ++++---------------- 8 files changed, 16 insertions(+), 107 deletions(-) diff --git a/Project.toml b/Project.toml index 7712058c8..451b8dace 100644 --- a/Project.toml +++ b/Project.toml @@ -43,7 +43,7 @@ Setfield = "0.7, 0.8, 1" SpecialFunctions = "0.10, 1.0, 2" StaticArrays = "0.12, 1.0" SymbolicIndexingInterface = "0.3" -TermInterface = "0.4" +TermInterface = "1.0.1" TimerOutputs = "0.5" Unityper = "0.1.2" julia = "1.3" diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index d748f46c6..27c0a7dd0 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -16,12 +16,10 @@ using SymbolicIndexingInterface import Base: +, -, *, /, //, \, ^, ImmutableDict using ConstructionBase using TermInterface -import TermInterface: iscall, isexpr, issym, symtype, head, children, +import TermInterface: iscall, isexpr, head, children, operation, arguments, metadata, maketerm -const istree = iscall -Base.@deprecate_binding istree iscall -export istree, operation, arguments, unsorted_arguments, similarterm, iscall +export operation, arguments, unsorted_arguments, iscall # Sym, Term, # Add, Mul and Pow include("types.jl") diff --git a/src/code.jl b/src/code.jl index 84512007a..b1a4db876 100644 --- a/src/code.jl +++ b/src/code.jl @@ -763,9 +763,7 @@ function cse_block!(assignments, counter, names, name, state, x) if isterm(x) return term(operation(x), args...) else - return maketerm(typeof(x), operation(x), - args, symtype(x), - metadata(x)) + return maketerm(typeof(x), operation(x), args, metadata(x)) end else return x diff --git a/src/interface.jl b/src/interface.jl index 355137ecb..8cdf20e94 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,11 +1,3 @@ -""" - iscall(x) - -Returns `true` if `x` is a term. If true, `operation`, `arguments` -must also be defined for `x` appropriately. -""" -iscall(x) = false - """ symtype(x) @@ -25,60 +17,3 @@ Returns `true` if `x` is a symbol. If true, `nameof` must be defined on `x` and must return a Symbol. """ issym(x) = false - -""" - operation(x) - -If `x` is a term as defined by `iscall(x)`, `operation(x)` returns the -head of the term if `x` represents a function call, for example, the head -is the function being called. -""" -function operation end - -""" - arguments(x) - -Get the arguments of `x`, must be defined if `iscall(x)` is `true`. -""" -function arguments end - -""" - unsorted_arguments(x::T) - -If x is a term satisfying `iscall(x)` and your term type `T` provides -an optimized implementation for storing the arguments, this function can -be used to retrieve the arguments when the order of arguments does not matter -but the speed of the operation does. -""" -unsorted_arguments(x) = arguments(x) -arity(x) = length(unsorted_arguments(x)) - -""" - metadata(x) - -Return the metadata attached to `x`. -""" -metadata(x) = nothing - -""" - metadata(x, md) - -Returns a new term which has the structure of `x` but also has -the metadata `md` attached to it. -""" -function metadata(x, data) - error("Setting metadata on $x is not possible") -end - -""" - similarterm(x, head, args, symtype=nothing; metadata=nothing, exprhead=:call) - -Returns a term that is in the same closure of types as `typeof(x)`, -with `head` as the head and `args` as the arguments, `type` as the symtype -and `metadata` as the metadata. By default this will execute `head(args...)`. -`x` parameter can also be a `Type`. The `exprhead` keyword argument is useful -when manipulating `Expr`s. - -`similarterm` is deprecated see help for `maketerm` instead. -""" -function similarterm end diff --git a/src/polyform.jl b/src/polyform.jl index 88019d5ce..8625a09df 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -121,7 +121,6 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) maketerm(typeof(x), op, map(a->PolyForm(a, pvar2sym, sym2term, vtype; Fs, recurse), args), - symtype(x), metadata(x)) else x @@ -176,11 +175,10 @@ isexpr(x::PolyForm) = true iscall(x::Type{<:PolyForm}) = true iscall(x::PolyForm) = true -function maketerm(::Type{<:PolyForm}, f, args, symtype, metadata) - basicsymbolic(t, f, args, symtype, metadata) +function maketerm(t::Type{<:PolyForm}, f, args, metadata) + basicsymbolic(t, f, args, metadata) end -function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)}, - args, symtype, metadata) +function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)}, args, metadata) f(args...) end @@ -252,7 +250,7 @@ function unpolyize(x) # we need a special makterm here because the default one used in Postwalk will call # promote_symtype to get the new type, but we just want to forward that in case # promote_symtype is not defined for some of the expressions here. - Postwalk(identity, maketerm=(T,f,args,sT,m) -> maketerm(T, f, args, symtype(x), m))(x) + Postwalk(identity, maketerm=(T,f,args,sT,m) -> maketerm(T, f, args, m))(x) end function toterm(x::PolyForm) @@ -305,6 +303,7 @@ function add_divs(x, y) end function frac_maketerm(T, f, args, stype, metadata) + # TODO add stype to T? if f in (*, /, \, +, -) f(args...) elseif f == (^) @@ -314,7 +313,7 @@ function frac_maketerm(T, f, args, stype, metadata) args[1]^args[2] end else - maketerm(T, f, args, stype, metadata) + maketerm(T, f, args, metadata) end end diff --git a/src/rule.jl b/src/rule.jl index 89b1242bd..9599b5a7d 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -408,7 +408,7 @@ function (acr::ACRule)(term) if result !== nothing # Assumption: inds are unique length(args) == length(inds) && return result - return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], symtype(term), metadata(term)) + return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], metadata(term)) end end end diff --git a/src/substitute.jl b/src/substitute.jl index 99ac134a0..b3f7b668e 100644 --- a/src/substitute.jl +++ b/src/substitute.jl @@ -34,7 +34,6 @@ function substitute(expr, dict; fold=true) maketerm(typeof(expr), op, args, - symtype(expr), metadata(expr)) else expr diff --git a/src/types.jl b/src/types.jl index 3abf6c139..5a1c1e974 100644 --- a/src/types.jl +++ b/src/types.jl @@ -162,7 +162,7 @@ function unsorted_arguments(x::BasicSymbolic) if isadd(x) for (k, v) in x.dict push!(args, applicable(*,k,v) ? k*v : - maketerm(k, *, [k, v])) + maketerm(k, *, [k, v], nothing)) end else # MUL for (k, v) in x.dict @@ -535,10 +535,12 @@ end unflatten(t) = t -function TermInterface.maketerm(::Type{<:BasicSymbolic}, head, args, type, metadata) - basicsymbolic(head, args, type, metadata) +function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata) + basicsymbolic(head, args, symtype(T), metadata) end +symtype(T::Type{<:Symbolic{T}}) where T = T + function basicsymbolic(f, args, stype, metadata) if f isa Symbol @@ -635,28 +637,6 @@ function to_symbolic(x) x end -""" - similarterm(x, op, args, symtype=nothing; metadata=nothing) - -""" -function similarterm(x, op, args, symtype=nothing; metadata=nothing) - Base.depwarn("""`similarterm` is deprecated, use `maketerm` instead. - `similarterm(x, op, args, symtype; metadata)` is now - `maketerm(typeof(x), op, args, symtype, metadata)`""", :similarterm) - TermInterface.maketerm(typeof(x), op, args, symtype, metadata) -end - -# Old fallback -function similarterm(T::Type, op, args, symtype=nothing; metadata=nothing) - - Base.depwarn("`similarterm` is deprecated, use `maketerm` instead." * - "See https://github.com/JuliaSymbolics/TermInterface.jl for details.", :similarterm) - op(args...) -end - -export similarterm - - ### ### Pretty printing ### From 1ab085b4793e27997019f1bb309c1853063f409f Mon Sep 17 00:00:00 2001 From: a Date: Sat, 8 Jun 2024 15:07:17 +0100 Subject: [PATCH 02/23] update types --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 5a1c1e974..e757f1d71 100644 --- a/src/types.jl +++ b/src/types.jl @@ -539,7 +539,7 @@ function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata) basicsymbolic(head, args, symtype(T), metadata) end -symtype(T::Type{<:Symbolic{T}}) where T = T +symtype(::Type{<:Symbolic{T}}) where T = T function basicsymbolic(f, args, stype, metadata) From a1b9fe0a7f9f02094a29feaafe9c5457d90ce826 Mon Sep 17 00:00:00 2001 From: a Date: Sat, 8 Jun 2024 22:08:39 +0100 Subject: [PATCH 03/23] remove symtype --- src/types.jl | 2 ++ src/utils.jl | 12 ++++++------ test/basics.jl | 12 ++++++------ 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/types.jl b/src/types.jl index e757f1d71..469634515 100644 --- a/src/types.jl +++ b/src/types.jl @@ -98,6 +98,7 @@ end ### ### Term interface ### +symtype(x) = typeof(x) symtype(x::Number) = typeof(x) @inline symtype(::Symbolic{T}) where T = T @@ -192,6 +193,7 @@ isexpr(s::BasicSymbolic) = !issym(s) iscall(s::BasicSymbolic) = isexpr(s) @inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false +issym(x) = false issym(x::BasicSymbolic) = isa_SymType(Val(:Sym), x) isterm(x) = isa_SymType(Val(:Term), x) ismul(x) = isa_SymType(Val(:Mul), x) diff --git a/src/utils.jl b/src/utils.jl index 69b6e8e2d..812e229fb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -53,7 +53,7 @@ function fold(t) # evaluate it return operation(t)(tt...) else - return maketerm(typeof(t), operation(t), tt, symtype(t), metadata(t)) + return maketerm(typeof(t), operation(t), tt, metadata(t)) end else return t @@ -147,19 +147,19 @@ function flatten_term(⋆, x) push!(flattened_args, t) end end - maketerm(typeof(x), ⋆, flattened_args, symtype(x), metadata(x)) + maketerm(typeof(x), ⋆, flattened_args, metadata(x)) end function sort_args(f, t) args = arguments(t) if length(args) < 2 - return maketerm(typeof(t), f, args, symtype(t), metadata(t)) + return maketerm(typeof(t), f, args, metadata(t)) elseif length(args) == 2 x, y = args - return maketerm(typeof(t), f, x <ₑ y ? [x,y] : [y,x], symtype(t), metadata(t)) + return maketerm(typeof(t), f, x <ₑ y ? [x,y] : [y,x], metadata(t)) end args = args isa Tuple ? [args...] : args - maketerm(typeof(t), f, sort(args, lt=<ₑ), symtype(t), metadata(t)) + maketerm(typeof(t), f, sort(args, lt=<ₑ), metadata(t)) end # Linked List interface @@ -225,7 +225,7 @@ macro matchable(expr) SymbolicUtils.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),)) SymbolicUtils.children(x::$name) = [SymbolicUtils.operation(x); SymbolicUtils.children(x)] Base.length(x::$name) = $(length(fields) + 1) - SymbolicUtils.maketerm(x::$name, f, args, type, metadata) = f(args...) + SymbolicUtils.maketerm(x::$name, f, args, metadata) = f(args...) end |> esc end diff --git a/test/basics.jl b/test/basics.jl index 36228324c..66c2b950d 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -216,18 +216,18 @@ end @testset "maketerm" begin @syms a b c - @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], Number, nothing).dict, Dict(a=>1,b=>1,c=>1)) - @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], Number, nothing), b) + @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).dict, Dict(a=>1,b=>1,c=>1)) + @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], nothing), b) # test that maketerm doesn't hard-code BasicSymbolic subtype # and is consistent with BasicSymbolic arithmetic operations - @test isequal(SymbolicUtils.maketerm(typeof(a / b), *, [a / b, c], Number, nothing), (a / b) * c) - @test isequal(SymbolicUtils.maketerm(typeof(a * b), *, [0, c], Number, nothing), 0) - @test isequal(SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], Number, nothing), (a * b)^3) + @test isequal(SymbolicUtils.maketerm(typeof(a / b), *, [a / b, c], nothing), (a / b) * c) + @test isequal(SymbolicUtils.maketerm(typeof(a * b), *, [0, c], nothing), 0) + @test isequal(SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], nothing), (a * b)^3) # test that maketerm sets metadata correctly metadata = Base.ImmutableDict{DataType, Any}(Ctx1, "meta_1") - s = SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], Number, metadata) + s = SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], metadata) @test hasmetadata(s, Ctx1) @test getmetadata(s, Ctx1) == "meta_1" end From eaf53e2583397576a315b864e0c6fcfb5db4756d Mon Sep 17 00:00:00 2001 From: a Date: Sat, 8 Jun 2024 22:51:06 +0100 Subject: [PATCH 04/23] make tests pass --- src/code.jl | 4 ++-- src/polyform.jl | 4 ++-- src/rewriters.jl | 34 +++++++++------------------------- 3 files changed, 13 insertions(+), 29 deletions(-) diff --git a/src/code.jl b/src/code.jl index b1a4db876..81eaacba4 100644 --- a/src/code.jl +++ b/src/code.jl @@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, import ..SymbolicUtils import ..SymbolicUtils.Rewriters import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, - symtype, similarterm, unsorted_arguments, metadata, isterm, term + symtype, unsorted_arguments, metadata, isterm, term ##== state management ==## @@ -694,7 +694,7 @@ function _cse!(mem, expr) iscall(expr) || return expr op = _cse!(mem, operation(expr)) args = map(Base.Fix1(_cse!, mem), arguments(expr)) - t = similarterm(expr, op, args) + t = maketerm(typeof(expr), op, args, nothing) v, dict = mem update! = let v=v, t=t diff --git a/src/polyform.jl b/src/polyform.jl index 8625a09df..a8ba96223 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -250,7 +250,7 @@ function unpolyize(x) # we need a special makterm here because the default one used in Postwalk will call # promote_symtype to get the new type, but we just want to forward that in case # promote_symtype is not defined for some of the expressions here. - Postwalk(identity, maketerm=(T,f,args,sT,m) -> maketerm(T, f, args, m))(x) + Postwalk(identity, maketerm=(T,f,args,m) -> maketerm(T, f, args, m))(x) end function toterm(x::PolyForm) @@ -302,7 +302,7 @@ function add_divs(x, y) end end -function frac_maketerm(T, f, args, stype, metadata) +function frac_maketerm(T, f, args, metadata) # TODO add stype to T? if f in (*, /, \, +, -) f(args...) diff --git a/src/rewriters.jl b/src/rewriters.jl index 3b3bba5e5..73d199d81 100644 --- a/src/rewriters.jl +++ b/src/rewriters.jl @@ -167,11 +167,7 @@ end struct Walk{ord, C, F, threaded} rw::C thread_cutoff::Int - maketerm::F # XXX: for the 2.0 deprecation cycle, we actually store a function - # that behaves like `similarterm` here, we use `compatmaker` to wrap - # maketerm-like input to do this, with a warning if similarterm provided - # we need this workaround to deprecate because similarterm takes value - # but maketerm only knows the type. + maketerm::F end function instrument(x::Walk{ord, C,F,threaded}, f) where {ord,C,F,threaded} @@ -183,25 +179,13 @@ end using .Threads -function compatmaker(similarterm, maketerm) - # XXX: delete this and only use maketerm in a future release. - if similarterm isa Nothing - function (x, f, args, type=_promote_symtype(f, args); metadata) - maketerm(typeof(x), f, args, type, metadata) - end - else - Base.depwarn("Prewalk and Postwalk now take maketerm instead of similarterm keyword argument. similarterm(x, f, args, type; metadata) is now maketerm(typeof(x), f, args, type, metadata)", :similarterm) - similarterm - end -end -function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing) - maker = compatmaker(similarterm, maketerm) - Walk{:post, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker) + +function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm) + Walk{:post, typeof(rw), typeof(maketerm), threaded}(rw, thread_cutoff, maketerm) end -function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing) - maker = compatmaker(similarterm, maketerm) - Walk{:pre, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker) +function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm) + Walk{:pre, typeof(rw), typeof(maketerm), threaded}(rw, thread_cutoff, maketerm) end struct PassThrough{C} @@ -220,8 +204,8 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F} end if iscall(x) - x = p.maketerm(x, operation(x), map(PassThrough(p), - unsorted_arguments(x)), metadata=metadata(x)) + x = p.maketerm(typeof(x), operation(x), map(PassThrough(p), + unsorted_arguments(x)), metadata(x)) end return ord === :post ? p.rw(x) : x @@ -245,7 +229,7 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F} end end args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x)) - t = p.maketerm(x, operation(x), args, metadata=metadata(x)) + t = p.maketerm(typeof(x), operation(x), args, metadata(x)) end return ord === :post ? p.rw(t) : t else From 02fde6da3db4bf0c12d1ca929c4110acfaa08d1c Mon Sep 17 00:00:00 2001 From: a Date: Sun, 9 Jun 2024 12:30:26 +0100 Subject: [PATCH 05/23] update docs --- docs/src/manual/interface.md | 14 +++++++++++++- src/interface.jl | 19 ------------------- src/types.jl | 17 +++++++++++++++++ 3 files changed, 30 insertions(+), 20 deletions(-) delete mode 100644 src/interface.jl diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md index 899aff658..5a7a4a220 100644 --- a/docs/src/manual/interface.md +++ b/docs/src/manual/interface.md @@ -13,6 +13,18 @@ You can read the documentation of [TermInterface.jl](https://github.com/JuliaSym ## SymbolicUtils.jl only methods -`promote_symtype(f, arg_symtypes...)` +### `symtype(x)` + +Returns the symbolic type of `x`. By default this is just `typeof(x)`. +Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules +specific to numbers (such as commutativity of multiplication). Or such +rules that may be implemented in the future. + +### `issym(x)` + +Returns `true` if `x` is a symbol. If true, `nameof` must be defined +on `x` and must return a Symbol. + +### `promote_symtype(f, arg_symtypes...)` Returns the appropriate output type of applying `f` on arguments of type `arg_symtypes`. diff --git a/src/interface.jl b/src/interface.jl deleted file mode 100644 index 8cdf20e94..000000000 --- a/src/interface.jl +++ /dev/null @@ -1,19 +0,0 @@ -""" - symtype(x) - -Returns the symbolic type of `x`. By default this is just `typeof(x)`. -Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules -specific to numbers (such as commutativity of multiplication). Or such -rules that may be implemented in the future. -""" -function symtype(x) - typeof(x) -end - -""" - issym(x) - -Returns `true` if `x` is a symbol. If true, `nameof` must be defined -on `x` and must return a Symbol. -""" -issym(x) = false diff --git a/src/types.jl b/src/types.jl index 469634515..9b6cb404b 100644 --- a/src/types.jl +++ b/src/types.jl @@ -98,6 +98,15 @@ end ### ### Term interface ### + +""" + symtype(x) + +Returns the symbolic type of `x`. By default this is just `typeof(x)`. +Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules +specific to numbers (such as commutativity of multiplication). Or such +rules that may be implemented in the future. +""" symtype(x) = typeof(x) symtype(x::Number) = typeof(x) @inline symtype(::Symbolic{T}) where T = T @@ -193,8 +202,16 @@ isexpr(s::BasicSymbolic) = !issym(s) iscall(s::BasicSymbolic) = isexpr(s) @inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false + +""" + issym(x) + +Returns `true` if `x` is a symbol. If true, `nameof` must be defined +on `x` and must return a Symbol. +""" issym(x) = false issym(x::BasicSymbolic) = isa_SymType(Val(:Sym), x) + isterm(x) = isa_SymType(Val(:Term), x) ismul(x) = isa_SymType(Val(:Mul), x) isadd(x) = isa_SymType(Val(:Add), x) From c53654ba31de4f0ea5449b90b14c814d022dd6a6 Mon Sep 17 00:00:00 2001 From: a Date: Sun, 9 Jun 2024 12:38:21 +0100 Subject: [PATCH 06/23] adjust docs --- docs/src/manual/representation.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/manual/representation.md b/docs/src/manual/representation.md index 997d33f3a..fea21bf1b 100644 --- a/docs/src/manual/representation.md +++ b/docs/src/manual/representation.md @@ -4,7 +4,7 @@ Performance of symbolic simplification depends on the datastructures used to rep The most basic term representation simply holds a function call and stores the function and the arguments it is called with. This is done by the `Term` type in SymbolicUtils. Functions that aren't commutative or associative, such as `sin` or `hypot` are stored as `Term`s. Commutative and associative operations like `+`, `*`, and their supporting operations like `-`, `/` and `^`, when used on terms of type `<:Number`, stand to gain from the use of more efficient datastrucutres. -All term representations must support `operation` and `arguments` functions. And they must define `istree` to return `true` when called with an instance of the type. Generic term-manipulation programs such as the rule-based rewriter make use of this interface to inspect expressions. In this way, the interface wins back the generality lost by having a zoo of term representations instead of one. (see [interface](/interface/) section for more on this.) +All term representations must support `operation` and `arguments` functions. And they must define `iscall` and `isexpr` to return `true` when called with an instance of the type. Generic term-manipulation programs such as the rule-based rewriter make use of this interface to inspect expressions. In this way, the interface wins back the generality lost by having a zoo of term representations instead of one. (see [interface](/interface/) section for more on this.) ### Preliminary representation of arithmetic From ffbb4e001a7f6dfa5702afd8c46e41ad593cfe2e Mon Sep 17 00:00:00 2001 From: a Date: Tue, 11 Jun 2024 22:19:05 +0200 Subject: [PATCH 07/23] adjust from suggestion --- src/types.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/types.jl b/src/types.jl index 9b6cb404b..af04d8bfa 100644 --- a/src/types.jl +++ b/src/types.jl @@ -108,8 +108,8 @@ specific to numbers (such as commutativity of multiplication). Or such rules that may be implemented in the future. """ symtype(x) = typeof(x) -symtype(x::Number) = typeof(x) @inline symtype(::Symbolic{T}) where T = T +@inline symtype(::Type{<:Symbolic{T}}) where T = T # We're returning a function pointer @inline function operation(x::BasicSymbolic) @@ -558,8 +558,6 @@ function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata) basicsymbolic(head, args, symtype(T), metadata) end -symtype(::Type{<:Symbolic{T}}) where T = T - function basicsymbolic(f, args, stype, metadata) if f isa Symbol From 55d1e15471d792427a6191c50caa95b681d24b98 Mon Sep 17 00:00:00 2001 From: Alessandro Cheli <17289614+0x0f0f0f@users.noreply.github.com> Date: Tue, 11 Jun 2024 22:19:25 +0200 Subject: [PATCH 08/23] Update src/types.jl Co-authored-by: Bowen S. Zhu --- src/types.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/types.jl b/src/types.jl index af04d8bfa..bf94ae3c2 100644 --- a/src/types.jl +++ b/src/types.jl @@ -100,10 +100,10 @@ end ### """ - symtype(x) + symtype(x) -Returns the symbolic type of `x`. By default this is just `typeof(x)`. -Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules +Returns the numeric type of `x`. By default this is just `typeof(x)`. +Define this for your symbolic types if you want [`SymbolicUtils.simplify`](@ref) to apply rules specific to numbers (such as commutativity of multiplication). Or such rules that may be implemented in the future. """ From 139b374dac5409dad632a6a191a98daa0a12c98e Mon Sep 17 00:00:00 2001 From: Alessandro Cheli <17289614+0x0f0f0f@users.noreply.github.com> Date: Tue, 11 Jun 2024 22:19:35 +0200 Subject: [PATCH 09/23] Update src/types.jl Co-authored-by: Bowen S. Zhu --- src/types.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/types.jl b/src/types.jl index bf94ae3c2..7a06f321c 100644 --- a/src/types.jl +++ b/src/types.jl @@ -204,10 +204,10 @@ iscall(s::BasicSymbolic) = isexpr(s) @inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false """ - issym(x) + issym(x) -Returns `true` if `x` is a symbol. If true, `nameof` must be defined -on `x` and must return a Symbol. +Returns `true` if `x` is a `Sym`. If true, `nameof` must be defined +on `x` and must return a `Symbol`. """ issym(x) = false issym(x::BasicSymbolic) = isa_SymType(Val(:Sym), x) From d6da38f5a0379ed0930f09a5e3b0d32ea33f2bd9 Mon Sep 17 00:00:00 2001 From: Alessandro Cheli <17289614+0x0f0f0f@users.noreply.github.com> Date: Tue, 11 Jun 2024 22:19:50 +0200 Subject: [PATCH 10/23] Update docs/src/manual/interface.md Co-authored-by: Bowen S. Zhu --- docs/src/manual/interface.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md index 5a7a4a220..265d53074 100644 --- a/docs/src/manual/interface.md +++ b/docs/src/manual/interface.md @@ -22,8 +22,8 @@ rules that may be implemented in the future. ### `issym(x)` -Returns `true` if `x` is a symbol. If true, `nameof` must be defined -on `x` and must return a Symbol. +Returns `true` if `x` is a `Sym`. If `true`, `nameof` must be defined +on `x` and must return a `Symbol`. ### `promote_symtype(f, arg_symtypes...)` From 5c63aba5d86b7957157cb8a447e2b481359df51c Mon Sep 17 00:00:00 2001 From: Alessandro Cheli <17289614+0x0f0f0f@users.noreply.github.com> Date: Tue, 11 Jun 2024 22:20:10 +0200 Subject: [PATCH 11/23] Update docs/src/manual/interface.md Co-authored-by: Bowen S. Zhu --- docs/src/manual/interface.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md index 265d53074..552bf7229 100644 --- a/docs/src/manual/interface.md +++ b/docs/src/manual/interface.md @@ -15,7 +15,9 @@ You can read the documentation of [TermInterface.jl](https://github.com/JuliaSym ### `symtype(x)` -Returns the symbolic type of `x`. By default this is just `typeof(x)`. +Returns the +[numeric type](https://docs.julialang.org/en/v1/base/numbers/#Standard-Numeric-Types) +of `x`. By default this is just `typeof(x)`. Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules specific to numbers (such as commutativity of multiplication). Or such rules that may be implemented in the future. From fc435a7b02e5babfa042665560ae26a6d1086161 Mon Sep 17 00:00:00 2001 From: Alessandro Cheli <17289614+0x0f0f0f@users.noreply.github.com> Date: Tue, 11 Jun 2024 22:20:39 +0200 Subject: [PATCH 12/23] Update docs/src/manual/interface.md Co-authored-by: Bowen S. Zhu --- docs/src/manual/interface.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md index 552bf7229..98518fb52 100644 --- a/docs/src/manual/interface.md +++ b/docs/src/manual/interface.md @@ -18,7 +18,7 @@ You can read the documentation of [TermInterface.jl](https://github.com/JuliaSym Returns the [numeric type](https://docs.julialang.org/en/v1/base/numbers/#Standard-Numeric-Types) of `x`. By default this is just `typeof(x)`. -Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules +Define this for your symbolic types if you want [`SymbolicUtils.simplify`](@ref) to apply rules specific to numbers (such as commutativity of multiplication). Or such rules that may be implemented in the future. From 7b72cf12d1f646b3cb2db0db11d5e749f3145542 Mon Sep 17 00:00:00 2001 From: a Date: Tue, 11 Jun 2024 22:44:20 +0200 Subject: [PATCH 13/23] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 451b8dace..5c1cdcccb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SymbolicUtils" uuid = "d1185830-fcd6-423d-90d6-eec64667417b" authors = ["Shashi Gowda"] -version = "2.0.2" +version = "2.1.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" From 02b7fa71a1c9db33a652a20f32970e6ea657e4b8 Mon Sep 17 00:00:00 2001 From: a Date: Tue, 18 Jun 2024 13:09:58 +0200 Subject: [PATCH 14/23] use metatheory rules --- Project.toml | 1 + src/rule.jl | 41 ++++++++++++++++++++++------------------- src/simplify_rules.jl | 15 +++++++++++---- 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/Project.toml b/Project.toml index 5c1cdcccb..85e312364 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c" MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" diff --git a/src/rule.jl b/src/rule.jl index 9599b5a7d..48164c04d 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -298,24 +298,27 @@ whether the predicate holds or not. _In the consequent pattern_: Use `(@ctx)` to access the context object on the right hand side of an expression. """ -macro rule(expr) - @assert expr.head == :call && expr.args[1] == :(=>) - lhs = expr.args[2] - rhs = rewrite_rhs(expr.args[3]) - keys = Symbol[] - lhs_term = makepattern(lhs, keys) - unique!(keys) - quote - $(__source__) - lhs_pattern = $(lhs_term) - Rule($(QuoteNode(expr)), - lhs_pattern, - matcher(lhs_pattern), - __MATCHES__ -> $(makeconsequent(rhs)), - rule_depth($lhs_term)) - end -end - +# macro rule(expr) +# @assert expr.head == :call && expr.args[1] == :(=>) +# lhs = expr.args[2] +# rhs = rewrite_rhs(expr.args[3]) +# keys = Symbol[] +# lhs_term = makepattern(lhs, keys) +# unique!(keys) +# quote +# $(__source__) +# lhs_pattern = $(lhs_term) +# Rule($(QuoteNode(expr)), +# lhs_pattern, +# matcher(lhs_pattern), +# __MATCHES__ -> $(makeconsequent(rhs)), +# rule_depth($lhs_term)) +# end +# end + +using Metatheory +using Metatheory: @rule +using TermInterface: isexpr """ @capture ex pattern @@ -394,7 +397,7 @@ function (acr::ACRule)(term) else f = operation(term) # Assume that the matcher was formed by closing over a term - if f != operation(r.lhs) # Maybe offer a fallback if m.term errors. + if f != operation(r.left) # Maybe offer a fallback if m.term errors. return nothing end diff --git a/src/simplify_rules.jl b/src/simplify_rules.jl index a612036cb..26a37a73d 100644 --- a/src/simplify_rules.jl +++ b/src/simplify_rules.jl @@ -1,4 +1,6 @@ using .Rewriters +using Metatheory: @rule + """ is_operation(f) Returns a single argument anonymous function predicate, that returns `true` if and only if @@ -6,10 +8,15 @@ the argument to the predicate satisfies `iscall` and `operation(x) == f` """ is_operation(f) = @nospecialize(x) -> iscall(x) && (operation(x) == f) +const isnotflatplus = isnotflat(+) +const isnotflattimes = isnotflat(*) +const needs_sorting_plus = needs_sorting(+) +const needs_sorting_times = needs_sorting(*) + let CANONICALIZE_PLUS = [ - @rule(~x::isnotflat(+) => flatten_term(+, ~x)) - @rule(~x::needs_sorting(+) => sort_args(+, ~x)) + @rule(~x::isnotflatplus => flatten_term(+, ~x)) + @rule(~x::needs_sorting_plus => sort_args(+, ~x)) @ordered_acrule(~a::is_literal_number + ~b::is_literal_number => ~a + ~b) @acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...)) @@ -28,8 +35,8 @@ let ] CANONICALIZE_TIMES = [ - @rule(~x::isnotflat(*) => flatten_term(*, ~x)) - @rule(~x::needs_sorting(*) => sort_args(*, ~x)) + @rule(~x::isnotflattimes => flatten_term(*, ~x)) + @rule(~x::needs_sorting_times => sort_args(*, ~x)) @ordered_acrule(~a::is_literal_number * ~b::is_literal_number => ~a * ~b) @rule(*(~~x::hasrepeats) => *(merge_repeats(^, ~~x)...)) From e9ebd8f56fcc0ddfc7c2510cefd55581f125df9b Mon Sep 17 00:00:00 2001 From: Alessandro Cheli Date: Mon, 24 Jun 2024 14:14:03 +0200 Subject: [PATCH 15/23] building --- Project.toml | 2 +- src/SymbolicUtils.jl | 2 +- src/code.jl | 20 ++++++++++---------- src/inspect.jl | 2 +- src/matchers.jl | 2 +- src/ordering.jl | 8 ++++---- src/polyform.jl | 16 ++++++++-------- src/rewriters.jl | 8 ++++---- src/rule.jl | 4 ++-- src/simplify.jl | 2 +- src/substitute.jl | 4 ++-- src/types.jl | 23 +++++++++++------------ src/utils.jl | 18 +++++++++--------- test/basics.jl | 4 ++-- test/rewrite.jl | 12 +++++++----- 15 files changed, 64 insertions(+), 63 deletions(-) diff --git a/Project.toml b/Project.toml index 5c1cdcccb..f15941122 100644 --- a/Project.toml +++ b/Project.toml @@ -43,7 +43,7 @@ Setfield = "0.7, 0.8, 1" SpecialFunctions = "0.10, 1.0, 2" StaticArrays = "0.12, 1.0" SymbolicIndexingInterface = "0.3" -TermInterface = "1.0.1" +TermInterface = "2.0" TimerOutputs = "0.5" Unityper = "0.1.2" julia = "1.3" diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 27c0a7dd0..4bc0eb147 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -19,7 +19,7 @@ using TermInterface import TermInterface: iscall, isexpr, head, children, operation, arguments, metadata, maketerm -export operation, arguments, unsorted_arguments, iscall +export operation, arguments, sorted_arguments, iscall # Sym, Term, # Add, Mul and Pow include("types.jl") diff --git a/src/code.jl b/src/code.jl index 81eaacba4..72565b768 100644 --- a/src/code.jl +++ b/src/code.jl @@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, import ..SymbolicUtils import ..SymbolicUtils.Rewriters import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, - symtype, unsorted_arguments, metadata, isterm, term + symtype, sorted_arguments, metadata, isterm, term ##== state management ==## @@ -115,7 +115,7 @@ function function_to_expr(op, O, st) (get(st.rewrites, :nanmath, false) && op in NaNMathFuns) || return nothing name = nameof(op) fun = GlobalRef(NaNMath, name) - args = map(Base.Fix2(toexpr, st), arguments(O)) + args = map(Base.Fix2(toexpr, st), sorted_arguments(O)) expr = Expr(:call, fun) append!(expr.args, args) return expr @@ -124,7 +124,7 @@ end function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st) out = get(st.rewrites, O, nothing) out === nothing || return out - args = map(Base.Fix2(toexpr, st), arguments(O)) + args = map(Base.Fix2(toexpr, st), sorted_arguments(O)) if length(args) >= 3 && symtype(O) <: Number x, xs = Iterators.peel(args) foldl(xs, init=x) do a, b @@ -138,7 +138,7 @@ function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st) end function function_to_expr(::typeof(^), O, st) - args = arguments(O) + args = sorted_arguments(O) if length(args) == 2 && args[2] isa Real && args[2] < 0 ex = args[1] if args[2] == -1 @@ -151,7 +151,7 @@ function function_to_expr(::typeof(^), O, st) end function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st) - args = arguments(O) + args = sorted_arguments(O) :($(toexpr(args[1], st)) ? $(toexpr(args[2], st)) : $(toexpr(args[3], st))) end @@ -183,7 +183,7 @@ function toexpr(O, st) return expr′ else !iscall(O) && return O - args = arguments(O) + args = sorted_arguments(O) return Expr(:call, toexpr(op, st), map(x->toexpr(x, st), args)...) end end @@ -693,7 +693,7 @@ end function _cse!(mem, expr) iscall(expr) || return expr op = _cse!(mem, operation(expr)) - args = map(Base.Fix1(_cse!, mem), arguments(expr)) + args = map(Base.Fix1(_cse!, mem), sorted_arguments(expr)) t = maketerm(typeof(expr), op, args, nothing) v, dict = mem @@ -716,7 +716,7 @@ end function _cse(exprs::AbstractArray) letblock = cse(Term{Any}(tuple, vec(exprs))) - letblock.pairs, reshape(arguments(letblock.body), size(exprs)) + letblock.pairs, reshape(sorted_arguments(letblock.body), size(exprs)) end function cse(x::MakeArray) @@ -744,7 +744,7 @@ end function cse_state!(state, t) !iscall(t) && return t state[t] = Base.get(state, t, 0) + 1 - foreach(x->cse_state!(state, x), unsorted_arguments(t)) + foreach(x->cse_state!(state, x), arguments(t)) end function cse_block!(assignments, counter, names, name, state, x) @@ -759,7 +759,7 @@ function cse_block!(assignments, counter, names, name, state, x) return sym end elseif iscall(x) - args = map(a->cse_block!(assignments, counter, names, name, state,a), unsorted_arguments(x)) + args = map(a->cse_block!(assignments, counter, names, name, state,a), arguments(x)) if isterm(x) return term(operation(x), args...) else diff --git a/src/inspect.jl b/src/inspect.jl index 42b0b1be5..a61f6f44c 100644 --- a/src/inspect.jl +++ b/src/inspect.jl @@ -27,7 +27,7 @@ function AbstractTrees.nodevalue(x::BasicSymbolic) end function AbstractTrees.children(x::Symbolic) - iscall(x) ? arguments(x) : isexpr(x) ? children(x) : () + iscall(x) ? sorted_arguments(x) : isexpr(x) ? children(x) : () end """ diff --git a/src/matchers.jl b/src/matchers.jl index 7f4dea537..91a6c1990 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -85,7 +85,7 @@ function matcher(segment::Segment) end function term_matcher(term) - matchers = (matcher(operation(term)), map(matcher, arguments(term))...,) + matchers = (matcher(operation(term)), map(matcher, sorted_arguments(term))...,) function term_matcher(success, data, bindings) !islist(data) && return nothing diff --git a/src/ordering.jl b/src/ordering.jl index 3417f3f85..6e276412b 100644 --- a/src/ordering.jl +++ b/src/ordering.jl @@ -22,7 +22,7 @@ function get_degrees(expr) ((Symbol(expr),) => 1,) elseif iscall(expr) op = operation(expr) - args = arguments(expr) + args = sorted_arguments(expr) if operation(expr) == (^) && args[2] isa Number return map(get_degrees(args[1])) do (base, pow) (base => pow * args[2]) @@ -35,7 +35,7 @@ function get_degrees(expr) _, idx = findmax(x->sum(last.(x), init=0), ds) return ds[idx] elseif operation(expr) == (getindex) - args = arguments(expr) + args = sorted_arguments(expr) return ((Symbol.(args)...,) => 1,) else return ((Symbol("zzzzzzz", hash(expr)),) => 1,) @@ -62,7 +62,7 @@ function lexlt(degs1, degs2) return false # they are equal end -_arglen(a) = iscall(a) ? length(unsorted_arguments(a)) : 0 +_arglen(a) = iscall(a) ? length(arguments(a)) : 0 function <ₑ(a::Tuple, b::Tuple) for (x, y) in zip(a, b) @@ -81,7 +81,7 @@ function <ₑ(a::BasicSymbolic, b::BasicSymbolic) bw = monomial_lt(db, da) if fw === bw && !isequal(a, b) if _arglen(a) == _arglen(b) - return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,) + return (operation(a), sorted_arguments(a)...,) <ₑ (operation(b), sorted_arguments(b)...,) else return _arglen(a) < _arglen(b) end diff --git a/src/polyform.jl b/src/polyform.jl index a8ba96223..60f59f0e5 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -103,7 +103,7 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) end op = operation(x) - args = arguments(x) + args = sorted_arguments(x) local_polyize(y) = polyize(y, pvar2sym, sym2term, vtype, pow, Fs, recurse) @@ -343,7 +343,7 @@ end function add_with_div(x, flatten=true) (!iscall(x) || operation(x) != (+)) && return x - aa = unsorted_arguments(x) + aa = arguments(x) !any(a->isdiv(a), aa) && return x # no rewrite necessary divs = filter(a->isdiv(a), aa) @@ -381,16 +381,16 @@ end function needs_div_rules(x) (isdiv(x) && !(x.num isa Number) && !(x.den isa Number)) || - (iscall(x) && operation(x) === (+) && count(has_div, unsorted_arguments(x)) > 1) || - (iscall(x) && any(needs_div_rules, unsorted_arguments(x))) + (iscall(x) && operation(x) === (+) && count(has_div, arguments(x)) > 1) || + (iscall(x) && any(needs_div_rules, arguments(x))) end function has_div(x) - return isdiv(x) || (iscall(x) && any(has_div, unsorted_arguments(x))) + return isdiv(x) || (iscall(x) && any(has_div, arguments(x))) end flatten_pows(xs) = map(xs) do x - ispow(x) ? Iterators.repeated(arguments(x)...) : (x,) + ispow(x) ? Iterators.repeated(sorted_arguments(x)...) : (x,) end |> Iterators.flatten |> a->collect(Any,a) coefftype(x::PolyForm) = coefftype(x.p) @@ -414,8 +414,8 @@ Has optimized processes for `Mul` and `Pow` terms. function quick_cancel(d) if ispow(d) && isdiv(d.base) return quick_cancel((d.base.num^d.exp) / (d.base.den^d.exp)) - elseif ismul(d) && any(isdiv, unsorted_arguments(d)) - return prod(unsorted_arguments(d)) + elseif ismul(d) && any(isdiv, arguments(d)) + return prod(arguments(d)) elseif isdiv(d) num, den = quick_cancel(d.num, d.den) return Div(num, den) diff --git a/src/rewriters.jl b/src/rewriters.jl index 73d199d81..f9c9b603a 100644 --- a/src/rewriters.jl +++ b/src/rewriters.jl @@ -33,7 +33,7 @@ module Rewriters using SymbolicUtils: @timer using TermInterface -import SymbolicUtils: iscall, operation, arguments, unsorted_arguments, metadata, node_count, _promote_symtype +import SymbolicUtils: iscall, operation, arguments, sorted_arguments, metadata, node_count, _promote_symtype export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough # Cache of printed rules to speed up @timer @@ -205,7 +205,7 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F} if iscall(x) x = p.maketerm(typeof(x), operation(x), map(PassThrough(p), - unsorted_arguments(x)), metadata(x)) + arguments(x)), metadata(x)) end return ord === :post ? p.rw(x) : x @@ -221,14 +221,14 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F} x = p.rw(x) end if iscall(x) - _args = map(arguments(x)) do arg + _args = map(sorted_arguments(x)) do arg if node_count(arg) > p.thread_cutoff Threads.@spawn p(arg) else p(arg) end end - args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x)) + args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, sorted_arguments(x)) t = p.maketerm(typeof(x), operation(x), args, metadata(x)) end return ord === :post ? p.rw(t) : t diff --git a/src/rule.jl b/src/rule.jl index 9599b5a7d..8025cf819 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -121,7 +121,7 @@ getdepth(r::Rule) = r.depth function rule_depth(rule, d=0, maxdepth=0) if iscall(rule) - maxdepth = reduce(max, (rule_depth(r, d+1, maxdepth) for r in arguments(rule)), init=1) + maxdepth = reduce(max, (rule_depth(r, d+1, maxdepth) for r in sorted_arguments(rule)), init=1) elseif rule isa Slot || rule isa Segment maxdepth = max(d, maxdepth) end @@ -399,7 +399,7 @@ function (acr::ACRule)(term) end T = symtype(term) - args = unsorted_arguments(term) + args = arguments(term) itr = acr.sets(eachindex(args), acr.arity) diff --git a/src/simplify.jl b/src/simplify.jl index 87bc95954..68fe78f83 100644 --- a/src/simplify.jl +++ b/src/simplify.jl @@ -45,6 +45,6 @@ end has_operation(x, op) = (iscall(x) && (operation(x) == op || any(a->has_operation(a, op), - unsorted_arguments(x)))) + arguments(x)))) Base.@deprecate simplify(x, ctx; kwargs...) simplify(x; rewriter=ctx, kwargs...) diff --git a/src/substitute.jl b/src/substitute.jl index b3f7b668e..8fc980c69 100644 --- a/src/substitute.jl +++ b/src/substitute.jl @@ -20,7 +20,7 @@ function substitute(expr, dict; fold=true) op = substitute(operation(expr), dict; fold=fold) if fold canfold = !(op isa Symbolic) - args = map(unsorted_arguments(expr)) do x + args = map(arguments(expr)) do x x′ = substitute(x, dict; fold=fold) canfold = canfold && !(x′ isa Symbolic) x′ @@ -28,7 +28,7 @@ function substitute(expr, dict; fold=true) canfold && return op(args...) args else - args = map(x->substitute(x, dict, fold=fold), unsorted_arguments(expr)) + args = map(x->substitute(x, dict, fold=fold), arguments(expr)) end maketerm(typeof(expr), diff --git a/src/types.jl b/src/types.jl index 7a06f321c..a45287623 100644 --- a/src/types.jl +++ b/src/types.jl @@ -126,8 +126,8 @@ end @inline head(x::BasicSymbolic) = operation(x) -function arguments(x::BasicSymbolic) - args = unsorted_arguments(x) +function sorted_arguments(x::BasicSymbolic) + args = arguments(x) @compactified x::BasicSymbolic begin Add => @goto ADD Mul => @goto MUL @@ -148,9 +148,8 @@ function arguments(x::BasicSymbolic) return args end -unsorted_arguments(x) = arguments(x) children(x::BasicSymbolic) = arguments(x) -function unsorted_arguments(x::BasicSymbolic) +function arguments(x::BasicSymbolic) @compactified x::BasicSymbolic begin Term => return x.arguments Add => @goto ADDMUL @@ -254,8 +253,8 @@ function _isequal(a, b, E) elseif E === POW isequal(a.exp, b.exp) && isequal(a.base, b.base) elseif E === TERM - a1 = arguments(a) - a2 = arguments(b) + a1 = sorted_arguments(a) + a2 = sorted_arguments(b) isequal(operation(a), operation(b)) && _allarequal(a1, a2) else error_on_type() @@ -296,7 +295,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt !iszero(h) && return h op = operation(s) oph = op isa Function ? nameof(op) : op - h′ = hashvec(arguments(s), hash(oph, salt)) + h′ = hashvec(sorted_arguments(s), hash(oph, salt)) s.hash[] = h′ return h′ else @@ -426,7 +425,7 @@ end @inline function numerators(x) isdiv(x) && return numerators(x.num) - iscall(x) && operation(x) === (*) ? arguments(x) : Any[x] + iscall(x) && operation(x) === (*) ? sorted_arguments(x) : Any[x] end @inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1] @@ -545,7 +544,7 @@ function unflatten(t::Symbolic{T}) where{T} if iscall(t) f = operation(t) if f == (+) || f == (*) # TODO check out for other n-ary --> binary ops - a = arguments(t) + a = sorted_arguments(t) return foldl((x,y) -> Term{T}(f, Any[x, y]), a) end end @@ -662,7 +661,7 @@ const show_simplified = Ref(false) isnegative(t::Real) = t < 0 function isnegative(t) if iscall(t) && operation(t) === (*) - coeff = first(arguments(t)) + coeff = first(sorted_arguments(t)) return isnegative(coeff) end return false @@ -694,7 +693,7 @@ end function remove_minus(t) !iscall(t) && return -t @assert operation(t) == (*) - args = arguments(t) + args = sorted_arguments(t) @assert args[1] < 0 Any[-args[1], args[2:end]...] end @@ -806,7 +805,7 @@ function show_term(io::IO, t) end f = operation(t) - args = arguments(t) + args = sorted_arguments(t) if symtype(t) <: LiteralReal show_call(io, f, args) elseif f === (+) diff --git a/src/utils.jl b/src/utils.jl index 812e229fb..fb5ceaa36 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -48,7 +48,7 @@ end function fold(t) if iscall(t) - tt = map(fold, arguments(t)) + tt = map(fold, sorted_arguments(t)) if !any(x->x isa Symbolic, tt) # evaluate it return operation(t)(tt...) @@ -74,12 +74,12 @@ _isinteger(x) = (x isa Number && isinteger(x)) || (x isa Symbolic && symtype(x) _isreal(x) = (x isa Number && isreal(x)) || (x isa Symbolic && symtype(x) <: Real) issortedₑ(args) = issorted(args, lt=<ₑ) -needs_sorting(f) = x -> is_operation(f)(x) && !issortedₑ(arguments(x)) +needs_sorting(f) = x -> is_operation(f)(x) && !issortedₑ(sorted_arguments(x)) # are there nested ⋆ terms? function isnotflat(⋆) function (x) - args = arguments(x) + args = sorted_arguments(x) for t in args if iscall(t) && operation(t) === (⋆) return true @@ -137,12 +137,12 @@ x + 2y ``` """ function flatten_term(⋆, x) - args = arguments(x) + args = sorted_arguments(x) # flatten nested ⋆ flattened_args = [] for t in args if iscall(t) && operation(t) === (⋆) - append!(flattened_args, arguments(t)) + append!(flattened_args, sorted_arguments(t)) else push!(flattened_args, t) end @@ -151,7 +151,7 @@ function flatten_term(⋆, x) end function sort_args(f, t) - args = arguments(t) + args = sorted_arguments(t) if length(args) < 2 return maketerm(typeof(t), f, args, metadata(t)) elseif length(args) == 2 @@ -182,12 +182,12 @@ Base.length(l::LL) = length(l.v)-l.i+1 Base.length(t::Term) = length(arguments(t)) + 1 # PIRACY Base.isempty(t::Term) = false @inline car(t::Term) = operation(t) -@inline cdr(t::Term) = arguments(t) +@inline cdr(t::Term) = sorted_arguments(t) @inline car(v) = iscall(v) ? operation(v) : first(v) @inline function cdr(v) if iscall(v) - arguments(v) + sorted_arguments(v) else islist(v) ? LL(v, 2) : error("asked cdr of empty") end @@ -200,7 +200,7 @@ end if n === 0 return ll else - iscall(ll) ? drop_n(arguments(ll), n-1) : drop_n(cdr(ll), n-1) + iscall(ll) ? drop_n(sorted_arguments(ll), n-1) : drop_n(cdr(ll), n-1) end end @inline drop_n(ll::Union{Tuple, AbstractArray}, n) = drop_n(LL(ll, 1), n) diff --git a/test/basics.jl b/test/basics.jl index 66c2b950d..58c3cab0d 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -232,7 +232,7 @@ end @test getmetadata(s, Ctx1) == "meta_1" end -toterm(t) = Term{symtype(t)}(operation(t), arguments(t)) +toterm(t) = Term{symtype(t)}(operation(t), sorted_arguments(t)) @testset "diffs" begin @syms a b c @@ -279,7 +279,7 @@ end T = FnType{Tuple{T,S,Int} where {T,S}, Real} s = Sym{T}(:t) @syms a b c::Int - @test isequal(arguments(s(a, b, c)), [a, b, c]) + @test isequal(sorted_arguments(s(a, b, c)), [a, b, c]) end @testset "div" begin diff --git a/test/rewrite.jl b/test/rewrite.jl index 3bb2621e3..ccc754141 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -1,5 +1,7 @@ @syms a b c +using Metatheory + @testset "Equality" begin @eqtest a == a @eqtest a != b @@ -65,7 +67,7 @@ using SymbolicUtils: @capture ret = @capture (a + b) (+)(~~z) @test ret @test @isdefined z - @test all(z .=== arguments(a + b)) + @test all(z .=== sorted_arguments(a + b)) #a more typical way to use the @capture macro @@ -84,24 +86,24 @@ end ex1 = ex + c @test SymbolicUtils.isterm(ex1) - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata ex = a ex = setmetadata(ex, MetaData, :metadata) ex1 = ex + b - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata ex = a * b ex = setmetadata(ex, MetaData, :metadata) ex1 = ex * c @test SymbolicUtils.isterm(ex1) - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata ex = a ex = setmetadata(ex, MetaData, :metadata) ex1 = ex * b - @test getmetadata(arguments(ex1)[1], MetaData) == :metadata + @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata end \ No newline at end of file From e212a2fa837261035278f6674fad6b08ae596484 Mon Sep 17 00:00:00 2001 From: Alessandro Cheli Date: Mon, 24 Jun 2024 14:17:37 +0200 Subject: [PATCH 16/23] adjust some tests --- src/SymbolicUtils.jl | 2 +- src/polyform.jl | 5 +++-- src/types.jl | 6 +++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 4bc0eb147..4bbb938e0 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -17,7 +17,7 @@ import Base: +, -, *, /, //, \, ^, ImmutableDict using ConstructionBase using TermInterface import TermInterface: iscall, isexpr, head, children, - operation, arguments, metadata, maketerm + operation, arguments, metadata, maketerm, sorted_arguments export operation, arguments, sorted_arguments, iscall # Sym, Term, diff --git a/src/polyform.jl b/src/polyform.jl index 60f59f0e5..235b46393 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -185,7 +185,8 @@ end head(::PolyForm) = PolyForm operation(x::PolyForm) = MP.nterms(x.p) == 1 ? (*) : (+) -function arguments(x::PolyForm{T}) where {T} +TermInterface.sorted_arguments(x::PolyForm{T}) = arguments(t) +function TermInterface.arguments(x::PolyForm{T}) where {T} function is_var(v) MP.nterms(v) == 1 && @@ -229,7 +230,7 @@ function arguments(x::PolyForm{T}) where {T} PolyForm{T}(t, x.pvar2sym, x.sym2term, nothing)) for t in ts] end end -children(x::PolyForm) = [operation(x); arguments(x)] +children(x::PolyForm) = arguments(x) Base.show(io::IO, x::PolyForm) = show_term(io, x) diff --git a/src/types.jl b/src/types.jl index a45287623..4e0ddb7f3 100644 --- a/src/types.jl +++ b/src/types.jl @@ -126,7 +126,7 @@ end @inline head(x::BasicSymbolic) = operation(x) -function sorted_arguments(x::BasicSymbolic) +function TermInterface.sorted_arguments(x::BasicSymbolic) args = arguments(x) @compactified x::BasicSymbolic begin Add => @goto ADD @@ -148,8 +148,8 @@ function sorted_arguments(x::BasicSymbolic) return args end -children(x::BasicSymbolic) = arguments(x) -function arguments(x::BasicSymbolic) +TermInterface.children(x::BasicSymbolic) = arguments(x) +function TermInterface.arguments(x::BasicSymbolic) @compactified x::BasicSymbolic begin Term => return x.arguments Add => @goto ADDMUL From 6ba21a180d09f88d04be0a346351878908418b01 Mon Sep 17 00:00:00 2001 From: Alessandro Cheli Date: Mon, 24 Jun 2024 14:18:08 +0200 Subject: [PATCH 17/23] remove extra method --- src/polyform.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/polyform.jl b/src/polyform.jl index 235b46393..89736b1ec 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -185,7 +185,6 @@ end head(::PolyForm) = PolyForm operation(x::PolyForm) = MP.nterms(x.p) == 1 ? (*) : (+) -TermInterface.sorted_arguments(x::PolyForm{T}) = arguments(t) function TermInterface.arguments(x::PolyForm{T}) where {T} function is_var(v) From dc33ea7b821e75b6c5f801a35cd5ace131563dad Mon Sep 17 00:00:00 2001 From: Alessandro Cheli Date: Mon, 24 Jun 2024 14:27:29 +0200 Subject: [PATCH 18/23] make tests green --- test/rewrite.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/rewrite.jl b/test/rewrite.jl index ccc754141..8f4304ace 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -1,7 +1,5 @@ @syms a b c -using Metatheory - @testset "Equality" begin @eqtest a == a @eqtest a != b From 731d4e8de039a0c7ac6050e6532e6943b6c767a4 Mon Sep 17 00:00:00 2001 From: Alessandro Cheli Date: Mon, 24 Jun 2024 17:01:20 +0200 Subject: [PATCH 19/23] adapt maketerm for promote_symtype --- src/polyform.jl | 3 ++- src/types.jl | 10 +++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/polyform.jl b/src/polyform.jl index 89736b1ec..9450da554 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -176,7 +176,8 @@ iscall(x::Type{<:PolyForm}) = true iscall(x::PolyForm) = true function maketerm(t::Type{<:PolyForm}, f, args, metadata) - basicsymbolic(t, f, args, metadata) + # TODO: this looks uncovered. + basicsymbolic(f, args, nothing, metadata) end function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)}, args, metadata) f(args...) diff --git a/src/types.jl b/src/types.jl index 4e0ddb7f3..ea375a725 100644 --- a/src/types.jl +++ b/src/types.jl @@ -554,7 +554,15 @@ end unflatten(t) = t function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata) - basicsymbolic(head, args, symtype(T), metadata) + st = symtype(T) + pst = _promote_symtype(head, args) + # Use promoted symtype only if not a subtype of the existing symtype of T. + # This is useful when calling `maketerm(BasicSymbolic{Number}, (==), [true, false])` + # Where the result would have a symtype of Bool. + # Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609 + # TODO this should be optimized. + new_st = (pst === Any || pst <: st) ? st : pst + basicsymbolic(head, args, new_st, metadata) end From db0ae467dc1853723ff76056ad2b249abb3faa3c Mon Sep 17 00:00:00 2001 From: Alessandro Cheli Date: Mon, 24 Jun 2024 17:44:35 +0200 Subject: [PATCH 20/23] hacky fix to promote symtype --- src/types.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index ea375a725..b88304415 100644 --- a/src/types.jl +++ b/src/types.jl @@ -561,7 +561,13 @@ function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata) # Where the result would have a symtype of Bool. # Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609 # TODO this should be optimized. - new_st = (pst === Any || pst <: st) ? st : pst + new_st = if pst === Bool + pst + elseif pst === Any || (st === Number && pst <: st) + st + else + pst + end basicsymbolic(head, args, new_st, metadata) end From c619ea9400c86155d7a1b9a324fd3a69c04d81a3 Mon Sep 17 00:00:00 2001 From: Alessandro Cheli Date: Wed, 26 Jun 2024 12:19:57 +0200 Subject: [PATCH 21/23] iscall --- src/SymbolicUtils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 5b4fe150a..ef9b81f33 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -19,7 +19,7 @@ using TermInterface import TermInterface: iscall, isexpr, head, children, operation, arguments, metadata, maketerm, sorted_arguments -const istree = iscalls +const istree = iscall Base.@deprecate_binding istree iscall export istree, operation, arguments, sorted_arguments, similarterm, iscall # Sym, Term, From e338734b776259d661f23a5395b5775df09599d9 Mon Sep 17 00:00:00 2001 From: Alessandro Cheli Date: Wed, 26 Jun 2024 13:33:06 +0200 Subject: [PATCH 22/23] add egraphs example --- test/egraphs.jl | 197 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 test/egraphs.jl diff --git a/test/egraphs.jl b/test/egraphs.jl new file mode 100644 index 000000000..a5801cdaa --- /dev/null +++ b/test/egraphs.jl @@ -0,0 +1,197 @@ +using Metatheory +using SymbolicUtils +const SU = SymbolicUtils +using SymbolicUtils: Symbolic, BasicSymbolic, unflatten, toterm, Term +using SymbolicUtils: monadic, diadic +using InteractiveUtils + +EGraphs.preprocess(t::Symbolic) = toterm(unflatten(t)) + +""" +Equational rewrite rules for optimizing expressions +""" +opt_theory = @theory a b c x y z begin + a + (b + c) == (a + b) + c + a * (b * c) == (a * b) * c + x + 0 --> x + a + b == b + a + a - a => 0 # is it ok? + + 0 - x --> -x + + a * b == b * a + a * x + a * y == a*(x+y) + -1 * a --> -a + a + (-1 * b) == a - b + x * 1 --> x + x * 0 --> 0 + x/x --> 1 + # fraction rules + x^-1 == 1/x + 1/x * a == a/x # is this needed? + x / (x / y) --> y + x * (y / z) == (x * y) / z + (a/b) + (c/b) --> (a+c)/b + (a / b) / c == a/(b*c) + + # TODO prohibited rule + x / x --> 1 + + # pow rules + a * a == a^2 + (a^b)^c == a^(b*c) + a^b * a^c == a^(b+c) + a^b / a^c == a^(b-c) + (a*b)^c == a^c * b^c + + # logarithmic rules + # TODO variables are non-zero + log(x::Number) => log(x) + log(x * y) == log(x) + log(y) + log(x / y) == log(x) - log(y) + log(x^y) == y * log(x) + x^(log(y)) == y^(log(x)) + + # trig functions + sin(x)/cos(x) == tan(x) + cos(x)/sin(x) == cot(x) + sin(x)^2 + cos(x)^2 --> 1 + sin(2a) == 2sin(a)cos(a) + + sin(x)*cos(y) - cos(x)*sin(y) --> sin(x - y) + # hyperbolic trigonometric + # are these optimizing at all? dont think so + # sinh(x) == (ℯ^x - ℯ^(-x))/2 + # csch(x) == 1/sinh(x) + # cosh(x) == (ℯ^x + ℯ^(-x))/2 + # sech(x) == 1/cosh(x) + # sech(x) == 2/(ℯ^x + ℯ^(-x)) + # tanh(x) == sinh(x)/cosh(x) + # tanh(x) == (ℯ^x - ℯ^(-x))/(ℯ^x + ℯ^(-x)) + # coth(x) == 1/tanh(x) + # coth(x) == (ℯ^x + ℯ^-x)/(ℯ^x - ℯ^(-x)) + + # cosh(x)^2 - sinh(x)^2 --> 1 + # tanh(x)^2 + sech(x)^2 --> 1 + # coth(x)^2 - csch(x)^2 --> 1 + + # asinh(z) == log(z + √(z^2 + 1)) + # acosh(z) == log(z + √(z^2 - 1)) + # atanh(z) == log((1+z)/(1-z))/2 + # acsch(z) == log((1+√(1+z^2)) / z ) + # asech(z) == log((1 + √(1-z^2)) / z ) + # acoth(z) == log( (z+1)/(z-1) )/2 + + # folding + x::Number * y::Number => x*y + x::Number + y::Number => x+y + x::Number / y::Number => x/y + x::Number - y::Number => x-y +end +# opt_theory = @theory a b c x y begin +# a * x == x * a +# a * x + a * y == a*(x+y) +# -1 * a == -a +# a + -b --> a - b +# -b + a --> b - a +# end + + +# See +# * https://latkin.org/blog/2014/11/09/a-simple-benchmark-of-various-math-operations/ +# * https://streamhpc.com/blog/2012-07-16/how-expensive-is-an-operation-on-a-cpu/ +# * https://github.com/triscale-innov/GFlops.jl +# Measure the cost of expressions in terms of number of ASM instructions + +const op_costs = Dict() + +const types = [(Int64, Integer), (Float64, Real), (ComplexF64, Complex)] + +const io = IOBuffer() + +for f in vcat(monadic, [-]) + z = get!(op_costs, nameof(f), Dict()) + for (t, at) in types + try + InteractiveUtils.code_native(io, f, (t,)) + catch e + z[(t,)] = z[(at,)] = 1 + continue + end + str = String(take!(io)) + z[(t,)] = z[(at,)] = length(split(str, "\n")) + end +end + +for f in vcat(diadic, [+, -, *, /, //, ^]) + z = get!(op_costs, nameof(f), Dict()) + for (t1, at1) in types, (t2, at2) in types + try + InteractiveUtils.code_native(io, f, (t1, t2)) + catch e + z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = 1 + continue + end + str = String(take!(io)) + z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = length(split(str, "\n")) + end +end + +function getopcost(f::Function, types::Tuple) + sym = nameof(f) + if haskey(op_costs, sym) && haskey(op_costs[sym], types) + return op_costs[sym][types] + end + + # print("$f $types | ") + io = IOBuffer() + try + InteractiveUtils.code_native(io, f, types) + catch e + op_costs[sym][types] = 1 + return 1 + end + str = String(take!(io)) + c = length(split(str, "\n")) + !haskey(op_costs, sym) && (op_costs[sym] = Dict()) + op_costs[sym][types] = c +end + +getopcost(f, types::Tuple) = get(get(op_costs, f, Dict()), types, 1) + +function costfun(n::VecExpr, op, children_costs::Vector{Float64}) + v_isexpr(n) || return 1 + # types = Tuple(map(x -> getdata(g[x], SymtypeAnalysis, Real), args)) + types = Tuple([Float64 for i in 1:v_arity(n)]) + opc = getopcost(op, types) + opc + sum(children_costs) +end + +denoisescalars(x, atol=1e-11) = Postwalk(Chain([ + # 0 - x --> -x + @acrule *(~x::Real, sin(~y)) => 0 where isapprox(x, 0; atol=atol) + @acrule *(~x::Real, cos(~y)) => 0 where isapprox(x, 0; atol=atol) + @acrule +(~x::Real, ~y) => y where isapprox(x, 0; atol=atol) + @acrule +(~x::Real, ~y) => y where isapprox(x, 0; atol=atol) +]))(x) + +function optimize(ex::Symbolic; params=SaturationParams(), atol=1e-13, verbose=false, kws...) + # ex = simplify(denoisescalars(ex, atol)) + # println(ex) + # readline() + + g = EGraph{BasicSymbolic}(ex) + + # display(g.classes);println(); + + report = saturate!(g, opt_theory, params) + verbose && @info report + extr = extract!(g, costfun) + return extr +end + +@syms x y z + +t = Term(+, [Term(*, [z, x]), Term(*, [z, y])]) + +optimize(t) \ No newline at end of file From 0c2a4505c61f34c2a18bf21089dc215115025c94 Mon Sep 17 00:00:00 2001 From: Alessandro Cheli Date: Thu, 27 Jun 2024 09:47:12 +0200 Subject: [PATCH 23/23] adjust e-graph integration --- test/egraphs.jl | 109 ++++++++++++++++++++++++------------------------ test/rewrite.jl | 2 + 2 files changed, 56 insertions(+), 55 deletions(-) diff --git a/test/egraphs.jl b/test/egraphs.jl index a5801cdaa..d139d316c 100644 --- a/test/egraphs.jl +++ b/test/egraphs.jl @@ -61,26 +61,26 @@ opt_theory = @theory a b c x y z begin sin(x)*cos(y) - cos(x)*sin(y) --> sin(x - y) # hyperbolic trigonometric # are these optimizing at all? dont think so - # sinh(x) == (ℯ^x - ℯ^(-x))/2 - # csch(x) == 1/sinh(x) - # cosh(x) == (ℯ^x + ℯ^(-x))/2 - # sech(x) == 1/cosh(x) - # sech(x) == 2/(ℯ^x + ℯ^(-x)) - # tanh(x) == sinh(x)/cosh(x) - # tanh(x) == (ℯ^x - ℯ^(-x))/(ℯ^x + ℯ^(-x)) - # coth(x) == 1/tanh(x) - # coth(x) == (ℯ^x + ℯ^-x)/(ℯ^x - ℯ^(-x)) - - # cosh(x)^2 - sinh(x)^2 --> 1 - # tanh(x)^2 + sech(x)^2 --> 1 - # coth(x)^2 - csch(x)^2 --> 1 - - # asinh(z) == log(z + √(z^2 + 1)) - # acosh(z) == log(z + √(z^2 - 1)) - # atanh(z) == log((1+z)/(1-z))/2 - # acsch(z) == log((1+√(1+z^2)) / z ) - # asech(z) == log((1 + √(1-z^2)) / z ) - # acoth(z) == log( (z+1)/(z-1) )/2 + sinh(x) == (ℯ^x - ℯ^(-x))/2 + csch(x) == 1/sinh(x) + cosh(x) == (ℯ^x + ℯ^(-x))/2 + sech(x) == 1/cosh(x) + sech(x) == 2/(ℯ^x + ℯ^(-x)) + tanh(x) == sinh(x)/cosh(x) + tanh(x) == (ℯ^x - ℯ^(-x))/(ℯ^x + ℯ^(-x)) + coth(x) == 1/tanh(x) + coth(x) == (ℯ^x + ℯ^-x)/(ℯ^x - ℯ^(-x)) + + cosh(x)^2 - sinh(x)^2 --> 1 + tanh(x)^2 + sech(x)^2 --> 1 + coth(x)^2 - csch(x)^2 --> 1 + + asinh(z) == log(z + √(z^2 + 1)) + acosh(z) == log(z + √(z^2 - 1)) + atanh(z) == log((1+z)/(1-z))/2 + acsch(z) == log((1+√(1+z^2)) / z ) + asech(z) == log((1 + √(1-z^2)) / z ) + acoth(z) == log( (z+1)/(z-1) )/2 # folding x::Number * y::Number => x*y @@ -88,13 +88,6 @@ opt_theory = @theory a b c x y z begin x::Number / y::Number => x/y x::Number - y::Number => x-y end -# opt_theory = @theory a b c x y begin -# a * x == x * a -# a * x + a * y == a*(x+y) -# -1 * a == -a -# a + -b --> a - b -# -b + a --> b - a -# end # See @@ -103,41 +96,46 @@ end # * https://github.com/triscale-innov/GFlops.jl # Measure the cost of expressions in terms of number of ASM instructions -const op_costs = Dict() -const types = [(Int64, Integer), (Float64, Real), (ComplexF64, Complex)] - -const io = IOBuffer() - -for f in vcat(monadic, [-]) - z = get!(op_costs, nameof(f), Dict()) - for (t, at) in types - try - InteractiveUtils.code_native(io, f, (t,)) - catch e - z[(t,)] = z[(at,)] = 1 - continue +function make_op_costs() + const op_costs = Dict() + + const types = [(Int64, Integer), (Float64, Real), (ComplexF64, Complex)] + + const io = IOBuffer() + + for f in vcat(monadic, [-]) + z = get!(op_costs, nameof(f), Dict()) + for (t, at) in types + try + InteractiveUtils.code_native(io, f, (t,)) + catch e + z[(t,)] = z[(at,)] = 1 + continue + end + str = String(take!(io)) + z[(t,)] = z[(at,)] = length(split(str, "\n")) end - str = String(take!(io)) - z[(t,)] = z[(at,)] = length(split(str, "\n")) end -end - -for f in vcat(diadic, [+, -, *, /, //, ^]) - z = get!(op_costs, nameof(f), Dict()) - for (t1, at1) in types, (t2, at2) in types - try - InteractiveUtils.code_native(io, f, (t1, t2)) - catch e - z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = 1 - continue + + for f in vcat(diadic, [+, -, *, /, //, ^]) + z = get!(op_costs, nameof(f), Dict()) + for (t1, at1) in types, (t2, at2) in types + try + InteractiveUtils.code_native(io, f, (t1, t2)) + catch e + z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = 1 + continue + end + str = String(take!(io)) + z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = length(split(str, "\n")) end - str = String(take!(io)) - z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = length(split(str, "\n")) end + + op_costs end -function getopcost(f::Function, types::Tuple) +function getopcost(op_costs, f::Function, types::Tuple) sym = nameof(f) if haskey(op_costs, sym) && haskey(op_costs[sym], types) return op_costs[sym][types] @@ -175,6 +173,7 @@ denoisescalars(x, atol=1e-11) = Postwalk(Chain([ @acrule +(~x::Real, ~y) => y where isapprox(x, 0; atol=atol) ]))(x) +const op_costs = make_op_costs() function optimize(ex::Symbolic; params=SaturationParams(), atol=1e-13, verbose=false, kws...) # ex = simplify(denoisescalars(ex, atol)) # println(ex) diff --git a/test/rewrite.jl b/test/rewrite.jl index 8f4304ace..ccc754141 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -1,5 +1,7 @@ @syms a b c +using Metatheory + @testset "Equality" begin @eqtest a == a @eqtest a != b