From 898f18c6dac33cafd651374e0e609305deec0fbd Mon Sep 17 00:00:00 2001
From: a
Date: Sat, 8 Jun 2024 14:19:58 +0100
Subject: [PATCH 01/23] adapt changes
---
Project.toml | 2 +-
src/SymbolicUtils.jl | 6 ++--
src/code.jl | 4 +--
src/interface.jl | 65 --------------------------------------------
src/polyform.jl | 13 ++++-----
src/rule.jl | 2 +-
src/substitute.jl | 1 -
src/types.jl | 30 ++++----------------
8 files changed, 16 insertions(+), 107 deletions(-)
diff --git a/Project.toml b/Project.toml
index 7712058c8..451b8dace 100644
--- a/Project.toml
+++ b/Project.toml
@@ -43,7 +43,7 @@ Setfield = "0.7, 0.8, 1"
SpecialFunctions = "0.10, 1.0, 2"
StaticArrays = "0.12, 1.0"
SymbolicIndexingInterface = "0.3"
-TermInterface = "0.4"
+TermInterface = "1.0.1"
TimerOutputs = "0.5"
Unityper = "0.1.2"
julia = "1.3"
diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl
index d748f46c6..27c0a7dd0 100644
--- a/src/SymbolicUtils.jl
+++ b/src/SymbolicUtils.jl
@@ -16,12 +16,10 @@ using SymbolicIndexingInterface
import Base: +, -, *, /, //, \, ^, ImmutableDict
using ConstructionBase
using TermInterface
-import TermInterface: iscall, isexpr, issym, symtype, head, children,
+import TermInterface: iscall, isexpr, head, children,
operation, arguments, metadata, maketerm
-const istree = iscall
-Base.@deprecate_binding istree iscall
-export istree, operation, arguments, unsorted_arguments, similarterm, iscall
+export operation, arguments, unsorted_arguments, iscall
# Sym, Term,
# Add, Mul and Pow
include("types.jl")
diff --git a/src/code.jl b/src/code.jl
index 84512007a..b1a4db876 100644
--- a/src/code.jl
+++ b/src/code.jl
@@ -763,9 +763,7 @@ function cse_block!(assignments, counter, names, name, state, x)
if isterm(x)
return term(operation(x), args...)
else
- return maketerm(typeof(x), operation(x),
- args, symtype(x),
- metadata(x))
+ return maketerm(typeof(x), operation(x), args, metadata(x))
end
else
return x
diff --git a/src/interface.jl b/src/interface.jl
index 355137ecb..8cdf20e94 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -1,11 +1,3 @@
-"""
- iscall(x)
-
-Returns `true` if `x` is a term. If true, `operation`, `arguments`
-must also be defined for `x` appropriately.
-"""
-iscall(x) = false
-
"""
symtype(x)
@@ -25,60 +17,3 @@ Returns `true` if `x` is a symbol. If true, `nameof` must be defined
on `x` and must return a Symbol.
"""
issym(x) = false
-
-"""
- operation(x)
-
-If `x` is a term as defined by `iscall(x)`, `operation(x)` returns the
-head of the term if `x` represents a function call, for example, the head
-is the function being called.
-"""
-function operation end
-
-"""
- arguments(x)
-
-Get the arguments of `x`, must be defined if `iscall(x)` is `true`.
-"""
-function arguments end
-
-"""
- unsorted_arguments(x::T)
-
-If x is a term satisfying `iscall(x)` and your term type `T` provides
-an optimized implementation for storing the arguments, this function can
-be used to retrieve the arguments when the order of arguments does not matter
-but the speed of the operation does.
-"""
-unsorted_arguments(x) = arguments(x)
-arity(x) = length(unsorted_arguments(x))
-
-"""
- metadata(x)
-
-Return the metadata attached to `x`.
-"""
-metadata(x) = nothing
-
-"""
- metadata(x, md)
-
-Returns a new term which has the structure of `x` but also has
-the metadata `md` attached to it.
-"""
-function metadata(x, data)
- error("Setting metadata on $x is not possible")
-end
-
-"""
- similarterm(x, head, args, symtype=nothing; metadata=nothing, exprhead=:call)
-
-Returns a term that is in the same closure of types as `typeof(x)`,
-with `head` as the head and `args` as the arguments, `type` as the symtype
-and `metadata` as the metadata. By default this will execute `head(args...)`.
-`x` parameter can also be a `Type`. The `exprhead` keyword argument is useful
-when manipulating `Expr`s.
-
-`similarterm` is deprecated see help for `maketerm` instead.
-"""
-function similarterm end
diff --git a/src/polyform.jl b/src/polyform.jl
index 88019d5ce..8625a09df 100644
--- a/src/polyform.jl
+++ b/src/polyform.jl
@@ -121,7 +121,6 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse)
maketerm(typeof(x),
op,
map(a->PolyForm(a, pvar2sym, sym2term, vtype; Fs, recurse), args),
- symtype(x),
metadata(x))
else
x
@@ -176,11 +175,10 @@ isexpr(x::PolyForm) = true
iscall(x::Type{<:PolyForm}) = true
iscall(x::PolyForm) = true
-function maketerm(::Type{<:PolyForm}, f, args, symtype, metadata)
- basicsymbolic(t, f, args, symtype, metadata)
+function maketerm(t::Type{<:PolyForm}, f, args, metadata)
+ basicsymbolic(t, f, args, metadata)
end
-function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)},
- args, symtype, metadata)
+function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)}, args, metadata)
f(args...)
end
@@ -252,7 +250,7 @@ function unpolyize(x)
# we need a special makterm here because the default one used in Postwalk will call
# promote_symtype to get the new type, but we just want to forward that in case
# promote_symtype is not defined for some of the expressions here.
- Postwalk(identity, maketerm=(T,f,args,sT,m) -> maketerm(T, f, args, symtype(x), m))(x)
+ Postwalk(identity, maketerm=(T,f,args,sT,m) -> maketerm(T, f, args, m))(x)
end
function toterm(x::PolyForm)
@@ -305,6 +303,7 @@ function add_divs(x, y)
end
function frac_maketerm(T, f, args, stype, metadata)
+ # TODO add stype to T?
if f in (*, /, \, +, -)
f(args...)
elseif f == (^)
@@ -314,7 +313,7 @@ function frac_maketerm(T, f, args, stype, metadata)
args[1]^args[2]
end
else
- maketerm(T, f, args, stype, metadata)
+ maketerm(T, f, args, metadata)
end
end
diff --git a/src/rule.jl b/src/rule.jl
index 89b1242bd..9599b5a7d 100644
--- a/src/rule.jl
+++ b/src/rule.jl
@@ -408,7 +408,7 @@ function (acr::ACRule)(term)
if result !== nothing
# Assumption: inds are unique
length(args) == length(inds) && return result
- return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], symtype(term), metadata(term))
+ return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], metadata(term))
end
end
end
diff --git a/src/substitute.jl b/src/substitute.jl
index 99ac134a0..b3f7b668e 100644
--- a/src/substitute.jl
+++ b/src/substitute.jl
@@ -34,7 +34,6 @@ function substitute(expr, dict; fold=true)
maketerm(typeof(expr),
op,
args,
- symtype(expr),
metadata(expr))
else
expr
diff --git a/src/types.jl b/src/types.jl
index 3abf6c139..5a1c1e974 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -162,7 +162,7 @@ function unsorted_arguments(x::BasicSymbolic)
if isadd(x)
for (k, v) in x.dict
push!(args, applicable(*,k,v) ? k*v :
- maketerm(k, *, [k, v]))
+ maketerm(k, *, [k, v], nothing))
end
else # MUL
for (k, v) in x.dict
@@ -535,10 +535,12 @@ end
unflatten(t) = t
-function TermInterface.maketerm(::Type{<:BasicSymbolic}, head, args, type, metadata)
- basicsymbolic(head, args, type, metadata)
+function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata)
+ basicsymbolic(head, args, symtype(T), metadata)
end
+symtype(T::Type{<:Symbolic{T}}) where T = T
+
function basicsymbolic(f, args, stype, metadata)
if f isa Symbol
@@ -635,28 +637,6 @@ function to_symbolic(x)
x
end
-"""
- similarterm(x, op, args, symtype=nothing; metadata=nothing)
-
-"""
-function similarterm(x, op, args, symtype=nothing; metadata=nothing)
- Base.depwarn("""`similarterm` is deprecated, use `maketerm` instead.
- `similarterm(x, op, args, symtype; metadata)` is now
- `maketerm(typeof(x), op, args, symtype, metadata)`""", :similarterm)
- TermInterface.maketerm(typeof(x), op, args, symtype, metadata)
-end
-
-# Old fallback
-function similarterm(T::Type, op, args, symtype=nothing; metadata=nothing)
-
- Base.depwarn("`similarterm` is deprecated, use `maketerm` instead." *
- "See https://github.com/JuliaSymbolics/TermInterface.jl for details.", :similarterm)
- op(args...)
-end
-
-export similarterm
-
-
###
### Pretty printing
###
From 1ab085b4793e27997019f1bb309c1853063f409f Mon Sep 17 00:00:00 2001
From: a
Date: Sat, 8 Jun 2024 15:07:17 +0100
Subject: [PATCH 02/23] update types
---
src/types.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/types.jl b/src/types.jl
index 5a1c1e974..e757f1d71 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -539,7 +539,7 @@ function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata)
basicsymbolic(head, args, symtype(T), metadata)
end
-symtype(T::Type{<:Symbolic{T}}) where T = T
+symtype(::Type{<:Symbolic{T}}) where T = T
function basicsymbolic(f, args, stype, metadata)
From a1b9fe0a7f9f02094a29feaafe9c5457d90ce826 Mon Sep 17 00:00:00 2001
From: a
Date: Sat, 8 Jun 2024 22:08:39 +0100
Subject: [PATCH 03/23] remove symtype
---
src/types.jl | 2 ++
src/utils.jl | 12 ++++++------
test/basics.jl | 12 ++++++------
3 files changed, 14 insertions(+), 12 deletions(-)
diff --git a/src/types.jl b/src/types.jl
index e757f1d71..469634515 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -98,6 +98,7 @@ end
###
### Term interface
###
+symtype(x) = typeof(x)
symtype(x::Number) = typeof(x)
@inline symtype(::Symbolic{T}) where T = T
@@ -192,6 +193,7 @@ isexpr(s::BasicSymbolic) = !issym(s)
iscall(s::BasicSymbolic) = isexpr(s)
@inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false
+issym(x) = false
issym(x::BasicSymbolic) = isa_SymType(Val(:Sym), x)
isterm(x) = isa_SymType(Val(:Term), x)
ismul(x) = isa_SymType(Val(:Mul), x)
diff --git a/src/utils.jl b/src/utils.jl
index 69b6e8e2d..812e229fb 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -53,7 +53,7 @@ function fold(t)
# evaluate it
return operation(t)(tt...)
else
- return maketerm(typeof(t), operation(t), tt, symtype(t), metadata(t))
+ return maketerm(typeof(t), operation(t), tt, metadata(t))
end
else
return t
@@ -147,19 +147,19 @@ function flatten_term(⋆, x)
push!(flattened_args, t)
end
end
- maketerm(typeof(x), ⋆, flattened_args, symtype(x), metadata(x))
+ maketerm(typeof(x), ⋆, flattened_args, metadata(x))
end
function sort_args(f, t)
args = arguments(t)
if length(args) < 2
- return maketerm(typeof(t), f, args, symtype(t), metadata(t))
+ return maketerm(typeof(t), f, args, metadata(t))
elseif length(args) == 2
x, y = args
- return maketerm(typeof(t), f, x <ₑ y ? [x,y] : [y,x], symtype(t), metadata(t))
+ return maketerm(typeof(t), f, x <ₑ y ? [x,y] : [y,x], metadata(t))
end
args = args isa Tuple ? [args...] : args
- maketerm(typeof(t), f, sort(args, lt=<ₑ), symtype(t), metadata(t))
+ maketerm(typeof(t), f, sort(args, lt=<ₑ), metadata(t))
end
# Linked List interface
@@ -225,7 +225,7 @@ macro matchable(expr)
SymbolicUtils.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),))
SymbolicUtils.children(x::$name) = [SymbolicUtils.operation(x); SymbolicUtils.children(x)]
Base.length(x::$name) = $(length(fields) + 1)
- SymbolicUtils.maketerm(x::$name, f, args, type, metadata) = f(args...)
+ SymbolicUtils.maketerm(x::$name, f, args, metadata) = f(args...)
end |> esc
end
diff --git a/test/basics.jl b/test/basics.jl
index 36228324c..66c2b950d 100644
--- a/test/basics.jl
+++ b/test/basics.jl
@@ -216,18 +216,18 @@ end
@testset "maketerm" begin
@syms a b c
- @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], Number, nothing).dict, Dict(a=>1,b=>1,c=>1))
- @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], Number, nothing), b)
+ @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).dict, Dict(a=>1,b=>1,c=>1))
+ @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], nothing), b)
# test that maketerm doesn't hard-code BasicSymbolic subtype
# and is consistent with BasicSymbolic arithmetic operations
- @test isequal(SymbolicUtils.maketerm(typeof(a / b), *, [a / b, c], Number, nothing), (a / b) * c)
- @test isequal(SymbolicUtils.maketerm(typeof(a * b), *, [0, c], Number, nothing), 0)
- @test isequal(SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], Number, nothing), (a * b)^3)
+ @test isequal(SymbolicUtils.maketerm(typeof(a / b), *, [a / b, c], nothing), (a / b) * c)
+ @test isequal(SymbolicUtils.maketerm(typeof(a * b), *, [0, c], nothing), 0)
+ @test isequal(SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], nothing), (a * b)^3)
# test that maketerm sets metadata correctly
metadata = Base.ImmutableDict{DataType, Any}(Ctx1, "meta_1")
- s = SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], Number, metadata)
+ s = SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], metadata)
@test hasmetadata(s, Ctx1)
@test getmetadata(s, Ctx1) == "meta_1"
end
From eaf53e2583397576a315b864e0c6fcfb5db4756d Mon Sep 17 00:00:00 2001
From: a
Date: Sat, 8 Jun 2024 22:51:06 +0100
Subject: [PATCH 04/23] make tests pass
---
src/code.jl | 4 ++--
src/polyform.jl | 4 ++--
src/rewriters.jl | 34 +++++++++-------------------------
3 files changed, 13 insertions(+), 29 deletions(-)
diff --git a/src/code.jl b/src/code.jl
index b1a4db876..81eaacba4 100644
--- a/src/code.jl
+++ b/src/code.jl
@@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
import ..SymbolicUtils
import ..SymbolicUtils.Rewriters
import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym,
- symtype, similarterm, unsorted_arguments, metadata, isterm, term
+ symtype, unsorted_arguments, metadata, isterm, term
##== state management ==##
@@ -694,7 +694,7 @@ function _cse!(mem, expr)
iscall(expr) || return expr
op = _cse!(mem, operation(expr))
args = map(Base.Fix1(_cse!, mem), arguments(expr))
- t = similarterm(expr, op, args)
+ t = maketerm(typeof(expr), op, args, nothing)
v, dict = mem
update! = let v=v, t=t
diff --git a/src/polyform.jl b/src/polyform.jl
index 8625a09df..a8ba96223 100644
--- a/src/polyform.jl
+++ b/src/polyform.jl
@@ -250,7 +250,7 @@ function unpolyize(x)
# we need a special makterm here because the default one used in Postwalk will call
# promote_symtype to get the new type, but we just want to forward that in case
# promote_symtype is not defined for some of the expressions here.
- Postwalk(identity, maketerm=(T,f,args,sT,m) -> maketerm(T, f, args, m))(x)
+ Postwalk(identity, maketerm=(T,f,args,m) -> maketerm(T, f, args, m))(x)
end
function toterm(x::PolyForm)
@@ -302,7 +302,7 @@ function add_divs(x, y)
end
end
-function frac_maketerm(T, f, args, stype, metadata)
+function frac_maketerm(T, f, args, metadata)
# TODO add stype to T?
if f in (*, /, \, +, -)
f(args...)
diff --git a/src/rewriters.jl b/src/rewriters.jl
index 3b3bba5e5..73d199d81 100644
--- a/src/rewriters.jl
+++ b/src/rewriters.jl
@@ -167,11 +167,7 @@ end
struct Walk{ord, C, F, threaded}
rw::C
thread_cutoff::Int
- maketerm::F # XXX: for the 2.0 deprecation cycle, we actually store a function
- # that behaves like `similarterm` here, we use `compatmaker` to wrap
- # maketerm-like input to do this, with a warning if similarterm provided
- # we need this workaround to deprecate because similarterm takes value
- # but maketerm only knows the type.
+ maketerm::F
end
function instrument(x::Walk{ord, C,F,threaded}, f) where {ord,C,F,threaded}
@@ -183,25 +179,13 @@ end
using .Threads
-function compatmaker(similarterm, maketerm)
- # XXX: delete this and only use maketerm in a future release.
- if similarterm isa Nothing
- function (x, f, args, type=_promote_symtype(f, args); metadata)
- maketerm(typeof(x), f, args, type, metadata)
- end
- else
- Base.depwarn("Prewalk and Postwalk now take maketerm instead of similarterm keyword argument. similarterm(x, f, args, type; metadata) is now maketerm(typeof(x), f, args, type, metadata)", :similarterm)
- similarterm
- end
-end
-function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing)
- maker = compatmaker(similarterm, maketerm)
- Walk{:post, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker)
+
+function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm)
+ Walk{:post, typeof(rw), typeof(maketerm), threaded}(rw, thread_cutoff, maketerm)
end
-function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing)
- maker = compatmaker(similarterm, maketerm)
- Walk{:pre, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker)
+function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm)
+ Walk{:pre, typeof(rw), typeof(maketerm), threaded}(rw, thread_cutoff, maketerm)
end
struct PassThrough{C}
@@ -220,8 +204,8 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
end
if iscall(x)
- x = p.maketerm(x, operation(x), map(PassThrough(p),
- unsorted_arguments(x)), metadata=metadata(x))
+ x = p.maketerm(typeof(x), operation(x), map(PassThrough(p),
+ unsorted_arguments(x)), metadata(x))
end
return ord === :post ? p.rw(x) : x
@@ -245,7 +229,7 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F}
end
end
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
- t = p.maketerm(x, operation(x), args, metadata=metadata(x))
+ t = p.maketerm(typeof(x), operation(x), args, metadata(x))
end
return ord === :post ? p.rw(t) : t
else
From 02fde6da3db4bf0c12d1ca929c4110acfaa08d1c Mon Sep 17 00:00:00 2001
From: a
Date: Sun, 9 Jun 2024 12:30:26 +0100
Subject: [PATCH 05/23] update docs
---
docs/src/manual/interface.md | 14 +++++++++++++-
src/interface.jl | 19 -------------------
src/types.jl | 17 +++++++++++++++++
3 files changed, 30 insertions(+), 20 deletions(-)
delete mode 100644 src/interface.jl
diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md
index 899aff658..5a7a4a220 100644
--- a/docs/src/manual/interface.md
+++ b/docs/src/manual/interface.md
@@ -13,6 +13,18 @@ You can read the documentation of [TermInterface.jl](https://github.com/JuliaSym
## SymbolicUtils.jl only methods
-`promote_symtype(f, arg_symtypes...)`
+### `symtype(x)`
+
+Returns the symbolic type of `x`. By default this is just `typeof(x)`.
+Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules
+specific to numbers (such as commutativity of multiplication). Or such
+rules that may be implemented in the future.
+
+### `issym(x)`
+
+Returns `true` if `x` is a symbol. If true, `nameof` must be defined
+on `x` and must return a Symbol.
+
+### `promote_symtype(f, arg_symtypes...)`
Returns the appropriate output type of applying `f` on arguments of type `arg_symtypes`.
diff --git a/src/interface.jl b/src/interface.jl
deleted file mode 100644
index 8cdf20e94..000000000
--- a/src/interface.jl
+++ /dev/null
@@ -1,19 +0,0 @@
-"""
- symtype(x)
-
-Returns the symbolic type of `x`. By default this is just `typeof(x)`.
-Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules
-specific to numbers (such as commutativity of multiplication). Or such
-rules that may be implemented in the future.
-"""
-function symtype(x)
- typeof(x)
-end
-
-"""
- issym(x)
-
-Returns `true` if `x` is a symbol. If true, `nameof` must be defined
-on `x` and must return a Symbol.
-"""
-issym(x) = false
diff --git a/src/types.jl b/src/types.jl
index 469634515..9b6cb404b 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -98,6 +98,15 @@ end
###
### Term interface
###
+
+"""
+ symtype(x)
+
+Returns the symbolic type of `x`. By default this is just `typeof(x)`.
+Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules
+specific to numbers (such as commutativity of multiplication). Or such
+rules that may be implemented in the future.
+"""
symtype(x) = typeof(x)
symtype(x::Number) = typeof(x)
@inline symtype(::Symbolic{T}) where T = T
@@ -193,8 +202,16 @@ isexpr(s::BasicSymbolic) = !issym(s)
iscall(s::BasicSymbolic) = isexpr(s)
@inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false
+
+"""
+ issym(x)
+
+Returns `true` if `x` is a symbol. If true, `nameof` must be defined
+on `x` and must return a Symbol.
+"""
issym(x) = false
issym(x::BasicSymbolic) = isa_SymType(Val(:Sym), x)
+
isterm(x) = isa_SymType(Val(:Term), x)
ismul(x) = isa_SymType(Val(:Mul), x)
isadd(x) = isa_SymType(Val(:Add), x)
From c53654ba31de4f0ea5449b90b14c814d022dd6a6 Mon Sep 17 00:00:00 2001
From: a
Date: Sun, 9 Jun 2024 12:38:21 +0100
Subject: [PATCH 06/23] adjust docs
---
docs/src/manual/representation.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/src/manual/representation.md b/docs/src/manual/representation.md
index 997d33f3a..fea21bf1b 100644
--- a/docs/src/manual/representation.md
+++ b/docs/src/manual/representation.md
@@ -4,7 +4,7 @@ Performance of symbolic simplification depends on the datastructures used to rep
The most basic term representation simply holds a function call and stores the function and the arguments it is called with. This is done by the `Term` type in SymbolicUtils. Functions that aren't commutative or associative, such as `sin` or `hypot` are stored as `Term`s. Commutative and associative operations like `+`, `*`, and their supporting operations like `-`, `/` and `^`, when used on terms of type `<:Number`, stand to gain from the use of more efficient datastrucutres.
-All term representations must support `operation` and `arguments` functions. And they must define `istree` to return `true` when called with an instance of the type. Generic term-manipulation programs such as the rule-based rewriter make use of this interface to inspect expressions. In this way, the interface wins back the generality lost by having a zoo of term representations instead of one. (see [interface](/interface/) section for more on this.)
+All term representations must support `operation` and `arguments` functions. And they must define `iscall` and `isexpr` to return `true` when called with an instance of the type. Generic term-manipulation programs such as the rule-based rewriter make use of this interface to inspect expressions. In this way, the interface wins back the generality lost by having a zoo of term representations instead of one. (see [interface](/interface/) section for more on this.)
### Preliminary representation of arithmetic
From ffbb4e001a7f6dfa5702afd8c46e41ad593cfe2e Mon Sep 17 00:00:00 2001
From: a
Date: Tue, 11 Jun 2024 22:19:05 +0200
Subject: [PATCH 07/23] adjust from suggestion
---
src/types.jl | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/src/types.jl b/src/types.jl
index 9b6cb404b..af04d8bfa 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -108,8 +108,8 @@ specific to numbers (such as commutativity of multiplication). Or such
rules that may be implemented in the future.
"""
symtype(x) = typeof(x)
-symtype(x::Number) = typeof(x)
@inline symtype(::Symbolic{T}) where T = T
+@inline symtype(::Type{<:Symbolic{T}}) where T = T
# We're returning a function pointer
@inline function operation(x::BasicSymbolic)
@@ -558,8 +558,6 @@ function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata)
basicsymbolic(head, args, symtype(T), metadata)
end
-symtype(::Type{<:Symbolic{T}}) where T = T
-
function basicsymbolic(f, args, stype, metadata)
if f isa Symbol
From 55d1e15471d792427a6191c50caa95b681d24b98 Mon Sep 17 00:00:00 2001
From: Alessandro Cheli <17289614+0x0f0f0f@users.noreply.github.com>
Date: Tue, 11 Jun 2024 22:19:25 +0200
Subject: [PATCH 08/23] Update src/types.jl
Co-authored-by: Bowen S. Zhu
---
src/types.jl | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/src/types.jl b/src/types.jl
index af04d8bfa..bf94ae3c2 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -100,10 +100,10 @@ end
###
"""
- symtype(x)
+ symtype(x)
-Returns the symbolic type of `x`. By default this is just `typeof(x)`.
-Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules
+Returns the numeric type of `x`. By default this is just `typeof(x)`.
+Define this for your symbolic types if you want [`SymbolicUtils.simplify`](@ref) to apply rules
specific to numbers (such as commutativity of multiplication). Or such
rules that may be implemented in the future.
"""
From 139b374dac5409dad632a6a191a98daa0a12c98e Mon Sep 17 00:00:00 2001
From: Alessandro Cheli <17289614+0x0f0f0f@users.noreply.github.com>
Date: Tue, 11 Jun 2024 22:19:35 +0200
Subject: [PATCH 09/23] Update src/types.jl
Co-authored-by: Bowen S. Zhu
---
src/types.jl | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/src/types.jl b/src/types.jl
index bf94ae3c2..7a06f321c 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -204,10 +204,10 @@ iscall(s::BasicSymbolic) = isexpr(s)
@inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false
"""
- issym(x)
+ issym(x)
-Returns `true` if `x` is a symbol. If true, `nameof` must be defined
-on `x` and must return a Symbol.
+Returns `true` if `x` is a `Sym`. If true, `nameof` must be defined
+on `x` and must return a `Symbol`.
"""
issym(x) = false
issym(x::BasicSymbolic) = isa_SymType(Val(:Sym), x)
From d6da38f5a0379ed0930f09a5e3b0d32ea33f2bd9 Mon Sep 17 00:00:00 2001
From: Alessandro Cheli <17289614+0x0f0f0f@users.noreply.github.com>
Date: Tue, 11 Jun 2024 22:19:50 +0200
Subject: [PATCH 10/23] Update docs/src/manual/interface.md
Co-authored-by: Bowen S. Zhu
---
docs/src/manual/interface.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md
index 5a7a4a220..265d53074 100644
--- a/docs/src/manual/interface.md
+++ b/docs/src/manual/interface.md
@@ -22,8 +22,8 @@ rules that may be implemented in the future.
### `issym(x)`
-Returns `true` if `x` is a symbol. If true, `nameof` must be defined
-on `x` and must return a Symbol.
+Returns `true` if `x` is a `Sym`. If `true`, `nameof` must be defined
+on `x` and must return a `Symbol`.
### `promote_symtype(f, arg_symtypes...)`
From 5c63aba5d86b7957157cb8a447e2b481359df51c Mon Sep 17 00:00:00 2001
From: Alessandro Cheli <17289614+0x0f0f0f@users.noreply.github.com>
Date: Tue, 11 Jun 2024 22:20:10 +0200
Subject: [PATCH 11/23] Update docs/src/manual/interface.md
Co-authored-by: Bowen S. Zhu
---
docs/src/manual/interface.md | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md
index 265d53074..552bf7229 100644
--- a/docs/src/manual/interface.md
+++ b/docs/src/manual/interface.md
@@ -15,7 +15,9 @@ You can read the documentation of [TermInterface.jl](https://github.com/JuliaSym
### `symtype(x)`
-Returns the symbolic type of `x`. By default this is just `typeof(x)`.
+Returns the
+[numeric type](https://docs.julialang.org/en/v1/base/numbers/#Standard-Numeric-Types)
+of `x`. By default this is just `typeof(x)`.
Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules
specific to numbers (such as commutativity of multiplication). Or such
rules that may be implemented in the future.
From fc435a7b02e5babfa042665560ae26a6d1086161 Mon Sep 17 00:00:00 2001
From: Alessandro Cheli <17289614+0x0f0f0f@users.noreply.github.com>
Date: Tue, 11 Jun 2024 22:20:39 +0200
Subject: [PATCH 12/23] Update docs/src/manual/interface.md
Co-authored-by: Bowen S. Zhu
---
docs/src/manual/interface.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md
index 552bf7229..98518fb52 100644
--- a/docs/src/manual/interface.md
+++ b/docs/src/manual/interface.md
@@ -18,7 +18,7 @@ You can read the documentation of [TermInterface.jl](https://github.com/JuliaSym
Returns the
[numeric type](https://docs.julialang.org/en/v1/base/numbers/#Standard-Numeric-Types)
of `x`. By default this is just `typeof(x)`.
-Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules
+Define this for your symbolic types if you want [`SymbolicUtils.simplify`](@ref) to apply rules
specific to numbers (such as commutativity of multiplication). Or such
rules that may be implemented in the future.
From 7b72cf12d1f646b3cb2db0db11d5e749f3145542 Mon Sep 17 00:00:00 2001
From: a
Date: Tue, 11 Jun 2024 22:44:20 +0200
Subject: [PATCH 13/23] bump version
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 451b8dace..5c1cdcccb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "SymbolicUtils"
uuid = "d1185830-fcd6-423d-90d6-eec64667417b"
authors = ["Shashi Gowda"]
-version = "2.0.2"
+version = "2.1.0"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
From 02b7fa71a1c9db33a652a20f32970e6ea657e4b8 Mon Sep 17 00:00:00 2001
From: a
Date: Tue, 18 Jun 2024 13:09:58 +0200
Subject: [PATCH 14/23] use metatheory rules
---
Project.toml | 1 +
src/rule.jl | 41 ++++++++++++++++++++++-------------------
src/simplify_rules.jl | 15 +++++++++++----
3 files changed, 34 insertions(+), 23 deletions(-)
diff --git a/Project.toml b/Project.toml
index 5c1cdcccb..85e312364 100644
--- a/Project.toml
+++ b/Project.toml
@@ -15,6 +15,7 @@ DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c"
MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
diff --git a/src/rule.jl b/src/rule.jl
index 9599b5a7d..48164c04d 100644
--- a/src/rule.jl
+++ b/src/rule.jl
@@ -298,24 +298,27 @@ whether the predicate holds or not.
_In the consequent pattern_: Use `(@ctx)` to access the context object on the right hand side
of an expression.
"""
-macro rule(expr)
- @assert expr.head == :call && expr.args[1] == :(=>)
- lhs = expr.args[2]
- rhs = rewrite_rhs(expr.args[3])
- keys = Symbol[]
- lhs_term = makepattern(lhs, keys)
- unique!(keys)
- quote
- $(__source__)
- lhs_pattern = $(lhs_term)
- Rule($(QuoteNode(expr)),
- lhs_pattern,
- matcher(lhs_pattern),
- __MATCHES__ -> $(makeconsequent(rhs)),
- rule_depth($lhs_term))
- end
-end
-
+# macro rule(expr)
+# @assert expr.head == :call && expr.args[1] == :(=>)
+# lhs = expr.args[2]
+# rhs = rewrite_rhs(expr.args[3])
+# keys = Symbol[]
+# lhs_term = makepattern(lhs, keys)
+# unique!(keys)
+# quote
+# $(__source__)
+# lhs_pattern = $(lhs_term)
+# Rule($(QuoteNode(expr)),
+# lhs_pattern,
+# matcher(lhs_pattern),
+# __MATCHES__ -> $(makeconsequent(rhs)),
+# rule_depth($lhs_term))
+# end
+# end
+
+using Metatheory
+using Metatheory: @rule
+using TermInterface: isexpr
"""
@capture ex pattern
@@ -394,7 +397,7 @@ function (acr::ACRule)(term)
else
f = operation(term)
# Assume that the matcher was formed by closing over a term
- if f != operation(r.lhs) # Maybe offer a fallback if m.term errors.
+ if f != operation(r.left) # Maybe offer a fallback if m.term errors.
return nothing
end
diff --git a/src/simplify_rules.jl b/src/simplify_rules.jl
index a612036cb..26a37a73d 100644
--- a/src/simplify_rules.jl
+++ b/src/simplify_rules.jl
@@ -1,4 +1,6 @@
using .Rewriters
+using Metatheory: @rule
+
"""
is_operation(f)
Returns a single argument anonymous function predicate, that returns `true` if and only if
@@ -6,10 +8,15 @@ the argument to the predicate satisfies `iscall` and `operation(x) == f`
"""
is_operation(f) = @nospecialize(x) -> iscall(x) && (operation(x) == f)
+const isnotflatplus = isnotflat(+)
+const isnotflattimes = isnotflat(*)
+const needs_sorting_plus = needs_sorting(+)
+const needs_sorting_times = needs_sorting(*)
+
let
CANONICALIZE_PLUS = [
- @rule(~x::isnotflat(+) => flatten_term(+, ~x))
- @rule(~x::needs_sorting(+) => sort_args(+, ~x))
+ @rule(~x::isnotflatplus => flatten_term(+, ~x))
+ @rule(~x::needs_sorting_plus => sort_args(+, ~x))
@ordered_acrule(~a::is_literal_number + ~b::is_literal_number => ~a + ~b)
@acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...))
@@ -28,8 +35,8 @@ let
]
CANONICALIZE_TIMES = [
- @rule(~x::isnotflat(*) => flatten_term(*, ~x))
- @rule(~x::needs_sorting(*) => sort_args(*, ~x))
+ @rule(~x::isnotflattimes => flatten_term(*, ~x))
+ @rule(~x::needs_sorting_times => sort_args(*, ~x))
@ordered_acrule(~a::is_literal_number * ~b::is_literal_number => ~a * ~b)
@rule(*(~~x::hasrepeats) => *(merge_repeats(^, ~~x)...))
From e9ebd8f56fcc0ddfc7c2510cefd55581f125df9b Mon Sep 17 00:00:00 2001
From: Alessandro Cheli
Date: Mon, 24 Jun 2024 14:14:03 +0200
Subject: [PATCH 15/23] building
---
Project.toml | 2 +-
src/SymbolicUtils.jl | 2 +-
src/code.jl | 20 ++++++++++----------
src/inspect.jl | 2 +-
src/matchers.jl | 2 +-
src/ordering.jl | 8 ++++----
src/polyform.jl | 16 ++++++++--------
src/rewriters.jl | 8 ++++----
src/rule.jl | 4 ++--
src/simplify.jl | 2 +-
src/substitute.jl | 4 ++--
src/types.jl | 23 +++++++++++------------
src/utils.jl | 18 +++++++++---------
test/basics.jl | 4 ++--
test/rewrite.jl | 12 +++++++-----
15 files changed, 64 insertions(+), 63 deletions(-)
diff --git a/Project.toml b/Project.toml
index 5c1cdcccb..f15941122 100644
--- a/Project.toml
+++ b/Project.toml
@@ -43,7 +43,7 @@ Setfield = "0.7, 0.8, 1"
SpecialFunctions = "0.10, 1.0, 2"
StaticArrays = "0.12, 1.0"
SymbolicIndexingInterface = "0.3"
-TermInterface = "1.0.1"
+TermInterface = "2.0"
TimerOutputs = "0.5"
Unityper = "0.1.2"
julia = "1.3"
diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl
index 27c0a7dd0..4bc0eb147 100644
--- a/src/SymbolicUtils.jl
+++ b/src/SymbolicUtils.jl
@@ -19,7 +19,7 @@ using TermInterface
import TermInterface: iscall, isexpr, head, children,
operation, arguments, metadata, maketerm
-export operation, arguments, unsorted_arguments, iscall
+export operation, arguments, sorted_arguments, iscall
# Sym, Term,
# Add, Mul and Pow
include("types.jl")
diff --git a/src/code.jl b/src/code.jl
index 81eaacba4..72565b768 100644
--- a/src/code.jl
+++ b/src/code.jl
@@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
import ..SymbolicUtils
import ..SymbolicUtils.Rewriters
import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym,
- symtype, unsorted_arguments, metadata, isterm, term
+ symtype, sorted_arguments, metadata, isterm, term
##== state management ==##
@@ -115,7 +115,7 @@ function function_to_expr(op, O, st)
(get(st.rewrites, :nanmath, false) && op in NaNMathFuns) || return nothing
name = nameof(op)
fun = GlobalRef(NaNMath, name)
- args = map(Base.Fix2(toexpr, st), arguments(O))
+ args = map(Base.Fix2(toexpr, st), sorted_arguments(O))
expr = Expr(:call, fun)
append!(expr.args, args)
return expr
@@ -124,7 +124,7 @@ end
function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
out = get(st.rewrites, O, nothing)
out === nothing || return out
- args = map(Base.Fix2(toexpr, st), arguments(O))
+ args = map(Base.Fix2(toexpr, st), sorted_arguments(O))
if length(args) >= 3 && symtype(O) <: Number
x, xs = Iterators.peel(args)
foldl(xs, init=x) do a, b
@@ -138,7 +138,7 @@ function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
end
function function_to_expr(::typeof(^), O, st)
- args = arguments(O)
+ args = sorted_arguments(O)
if length(args) == 2 && args[2] isa Real && args[2] < 0
ex = args[1]
if args[2] == -1
@@ -151,7 +151,7 @@ function function_to_expr(::typeof(^), O, st)
end
function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st)
- args = arguments(O)
+ args = sorted_arguments(O)
:($(toexpr(args[1], st)) ? $(toexpr(args[2], st)) : $(toexpr(args[3], st)))
end
@@ -183,7 +183,7 @@ function toexpr(O, st)
return expr′
else
!iscall(O) && return O
- args = arguments(O)
+ args = sorted_arguments(O)
return Expr(:call, toexpr(op, st), map(x->toexpr(x, st), args)...)
end
end
@@ -693,7 +693,7 @@ end
function _cse!(mem, expr)
iscall(expr) || return expr
op = _cse!(mem, operation(expr))
- args = map(Base.Fix1(_cse!, mem), arguments(expr))
+ args = map(Base.Fix1(_cse!, mem), sorted_arguments(expr))
t = maketerm(typeof(expr), op, args, nothing)
v, dict = mem
@@ -716,7 +716,7 @@ end
function _cse(exprs::AbstractArray)
letblock = cse(Term{Any}(tuple, vec(exprs)))
- letblock.pairs, reshape(arguments(letblock.body), size(exprs))
+ letblock.pairs, reshape(sorted_arguments(letblock.body), size(exprs))
end
function cse(x::MakeArray)
@@ -744,7 +744,7 @@ end
function cse_state!(state, t)
!iscall(t) && return t
state[t] = Base.get(state, t, 0) + 1
- foreach(x->cse_state!(state, x), unsorted_arguments(t))
+ foreach(x->cse_state!(state, x), arguments(t))
end
function cse_block!(assignments, counter, names, name, state, x)
@@ -759,7 +759,7 @@ function cse_block!(assignments, counter, names, name, state, x)
return sym
end
elseif iscall(x)
- args = map(a->cse_block!(assignments, counter, names, name, state,a), unsorted_arguments(x))
+ args = map(a->cse_block!(assignments, counter, names, name, state,a), arguments(x))
if isterm(x)
return term(operation(x), args...)
else
diff --git a/src/inspect.jl b/src/inspect.jl
index 42b0b1be5..a61f6f44c 100644
--- a/src/inspect.jl
+++ b/src/inspect.jl
@@ -27,7 +27,7 @@ function AbstractTrees.nodevalue(x::BasicSymbolic)
end
function AbstractTrees.children(x::Symbolic)
- iscall(x) ? arguments(x) : isexpr(x) ? children(x) : ()
+ iscall(x) ? sorted_arguments(x) : isexpr(x) ? children(x) : ()
end
"""
diff --git a/src/matchers.jl b/src/matchers.jl
index 7f4dea537..91a6c1990 100644
--- a/src/matchers.jl
+++ b/src/matchers.jl
@@ -85,7 +85,7 @@ function matcher(segment::Segment)
end
function term_matcher(term)
- matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
+ matchers = (matcher(operation(term)), map(matcher, sorted_arguments(term))...,)
function term_matcher(success, data, bindings)
!islist(data) && return nothing
diff --git a/src/ordering.jl b/src/ordering.jl
index 3417f3f85..6e276412b 100644
--- a/src/ordering.jl
+++ b/src/ordering.jl
@@ -22,7 +22,7 @@ function get_degrees(expr)
((Symbol(expr),) => 1,)
elseif iscall(expr)
op = operation(expr)
- args = arguments(expr)
+ args = sorted_arguments(expr)
if operation(expr) == (^) && args[2] isa Number
return map(get_degrees(args[1])) do (base, pow)
(base => pow * args[2])
@@ -35,7 +35,7 @@ function get_degrees(expr)
_, idx = findmax(x->sum(last.(x), init=0), ds)
return ds[idx]
elseif operation(expr) == (getindex)
- args = arguments(expr)
+ args = sorted_arguments(expr)
return ((Symbol.(args)...,) => 1,)
else
return ((Symbol("zzzzzzz", hash(expr)),) => 1,)
@@ -62,7 +62,7 @@ function lexlt(degs1, degs2)
return false # they are equal
end
-_arglen(a) = iscall(a) ? length(unsorted_arguments(a)) : 0
+_arglen(a) = iscall(a) ? length(arguments(a)) : 0
function <ₑ(a::Tuple, b::Tuple)
for (x, y) in zip(a, b)
@@ -81,7 +81,7 @@ function <ₑ(a::BasicSymbolic, b::BasicSymbolic)
bw = monomial_lt(db, da)
if fw === bw && !isequal(a, b)
if _arglen(a) == _arglen(b)
- return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,)
+ return (operation(a), sorted_arguments(a)...,) <ₑ (operation(b), sorted_arguments(b)...,)
else
return _arglen(a) < _arglen(b)
end
diff --git a/src/polyform.jl b/src/polyform.jl
index a8ba96223..60f59f0e5 100644
--- a/src/polyform.jl
+++ b/src/polyform.jl
@@ -103,7 +103,7 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse)
end
op = operation(x)
- args = arguments(x)
+ args = sorted_arguments(x)
local_polyize(y) = polyize(y, pvar2sym, sym2term, vtype, pow, Fs, recurse)
@@ -343,7 +343,7 @@ end
function add_with_div(x, flatten=true)
(!iscall(x) || operation(x) != (+)) && return x
- aa = unsorted_arguments(x)
+ aa = arguments(x)
!any(a->isdiv(a), aa) && return x # no rewrite necessary
divs = filter(a->isdiv(a), aa)
@@ -381,16 +381,16 @@ end
function needs_div_rules(x)
(isdiv(x) && !(x.num isa Number) && !(x.den isa Number)) ||
- (iscall(x) && operation(x) === (+) && count(has_div, unsorted_arguments(x)) > 1) ||
- (iscall(x) && any(needs_div_rules, unsorted_arguments(x)))
+ (iscall(x) && operation(x) === (+) && count(has_div, arguments(x)) > 1) ||
+ (iscall(x) && any(needs_div_rules, arguments(x)))
end
function has_div(x)
- return isdiv(x) || (iscall(x) && any(has_div, unsorted_arguments(x)))
+ return isdiv(x) || (iscall(x) && any(has_div, arguments(x)))
end
flatten_pows(xs) = map(xs) do x
- ispow(x) ? Iterators.repeated(arguments(x)...) : (x,)
+ ispow(x) ? Iterators.repeated(sorted_arguments(x)...) : (x,)
end |> Iterators.flatten |> a->collect(Any,a)
coefftype(x::PolyForm) = coefftype(x.p)
@@ -414,8 +414,8 @@ Has optimized processes for `Mul` and `Pow` terms.
function quick_cancel(d)
if ispow(d) && isdiv(d.base)
return quick_cancel((d.base.num^d.exp) / (d.base.den^d.exp))
- elseif ismul(d) && any(isdiv, unsorted_arguments(d))
- return prod(unsorted_arguments(d))
+ elseif ismul(d) && any(isdiv, arguments(d))
+ return prod(arguments(d))
elseif isdiv(d)
num, den = quick_cancel(d.num, d.den)
return Div(num, den)
diff --git a/src/rewriters.jl b/src/rewriters.jl
index 73d199d81..f9c9b603a 100644
--- a/src/rewriters.jl
+++ b/src/rewriters.jl
@@ -33,7 +33,7 @@ module Rewriters
using SymbolicUtils: @timer
using TermInterface
-import SymbolicUtils: iscall, operation, arguments, unsorted_arguments, metadata, node_count, _promote_symtype
+import SymbolicUtils: iscall, operation, arguments, sorted_arguments, metadata, node_count, _promote_symtype
export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough
# Cache of printed rules to speed up @timer
@@ -205,7 +205,7 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
if iscall(x)
x = p.maketerm(typeof(x), operation(x), map(PassThrough(p),
- unsorted_arguments(x)), metadata(x))
+ arguments(x)), metadata(x))
end
return ord === :post ? p.rw(x) : x
@@ -221,14 +221,14 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F}
x = p.rw(x)
end
if iscall(x)
- _args = map(arguments(x)) do arg
+ _args = map(sorted_arguments(x)) do arg
if node_count(arg) > p.thread_cutoff
Threads.@spawn p(arg)
else
p(arg)
end
end
- args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
+ args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, sorted_arguments(x))
t = p.maketerm(typeof(x), operation(x), args, metadata(x))
end
return ord === :post ? p.rw(t) : t
diff --git a/src/rule.jl b/src/rule.jl
index 9599b5a7d..8025cf819 100644
--- a/src/rule.jl
+++ b/src/rule.jl
@@ -121,7 +121,7 @@ getdepth(r::Rule) = r.depth
function rule_depth(rule, d=0, maxdepth=0)
if iscall(rule)
- maxdepth = reduce(max, (rule_depth(r, d+1, maxdepth) for r in arguments(rule)), init=1)
+ maxdepth = reduce(max, (rule_depth(r, d+1, maxdepth) for r in sorted_arguments(rule)), init=1)
elseif rule isa Slot || rule isa Segment
maxdepth = max(d, maxdepth)
end
@@ -399,7 +399,7 @@ function (acr::ACRule)(term)
end
T = symtype(term)
- args = unsorted_arguments(term)
+ args = arguments(term)
itr = acr.sets(eachindex(args), acr.arity)
diff --git a/src/simplify.jl b/src/simplify.jl
index 87bc95954..68fe78f83 100644
--- a/src/simplify.jl
+++ b/src/simplify.jl
@@ -45,6 +45,6 @@ end
has_operation(x, op) = (iscall(x) && (operation(x) == op ||
any(a->has_operation(a, op),
- unsorted_arguments(x))))
+ arguments(x))))
Base.@deprecate simplify(x, ctx; kwargs...) simplify(x; rewriter=ctx, kwargs...)
diff --git a/src/substitute.jl b/src/substitute.jl
index b3f7b668e..8fc980c69 100644
--- a/src/substitute.jl
+++ b/src/substitute.jl
@@ -20,7 +20,7 @@ function substitute(expr, dict; fold=true)
op = substitute(operation(expr), dict; fold=fold)
if fold
canfold = !(op isa Symbolic)
- args = map(unsorted_arguments(expr)) do x
+ args = map(arguments(expr)) do x
x′ = substitute(x, dict; fold=fold)
canfold = canfold && !(x′ isa Symbolic)
x′
@@ -28,7 +28,7 @@ function substitute(expr, dict; fold=true)
canfold && return op(args...)
args
else
- args = map(x->substitute(x, dict, fold=fold), unsorted_arguments(expr))
+ args = map(x->substitute(x, dict, fold=fold), arguments(expr))
end
maketerm(typeof(expr),
diff --git a/src/types.jl b/src/types.jl
index 7a06f321c..a45287623 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -126,8 +126,8 @@ end
@inline head(x::BasicSymbolic) = operation(x)
-function arguments(x::BasicSymbolic)
- args = unsorted_arguments(x)
+function sorted_arguments(x::BasicSymbolic)
+ args = arguments(x)
@compactified x::BasicSymbolic begin
Add => @goto ADD
Mul => @goto MUL
@@ -148,9 +148,8 @@ function arguments(x::BasicSymbolic)
return args
end
-unsorted_arguments(x) = arguments(x)
children(x::BasicSymbolic) = arguments(x)
-function unsorted_arguments(x::BasicSymbolic)
+function arguments(x::BasicSymbolic)
@compactified x::BasicSymbolic begin
Term => return x.arguments
Add => @goto ADDMUL
@@ -254,8 +253,8 @@ function _isequal(a, b, E)
elseif E === POW
isequal(a.exp, b.exp) && isequal(a.base, b.base)
elseif E === TERM
- a1 = arguments(a)
- a2 = arguments(b)
+ a1 = sorted_arguments(a)
+ a2 = sorted_arguments(b)
isequal(operation(a), operation(b)) && _allarequal(a1, a2)
else
error_on_type()
@@ -296,7 +295,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
!iszero(h) && return h
op = operation(s)
oph = op isa Function ? nameof(op) : op
- h′ = hashvec(arguments(s), hash(oph, salt))
+ h′ = hashvec(sorted_arguments(s), hash(oph, salt))
s.hash[] = h′
return h′
else
@@ -426,7 +425,7 @@ end
@inline function numerators(x)
isdiv(x) && return numerators(x.num)
- iscall(x) && operation(x) === (*) ? arguments(x) : Any[x]
+ iscall(x) && operation(x) === (*) ? sorted_arguments(x) : Any[x]
end
@inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1]
@@ -545,7 +544,7 @@ function unflatten(t::Symbolic{T}) where{T}
if iscall(t)
f = operation(t)
if f == (+) || f == (*) # TODO check out for other n-ary --> binary ops
- a = arguments(t)
+ a = sorted_arguments(t)
return foldl((x,y) -> Term{T}(f, Any[x, y]), a)
end
end
@@ -662,7 +661,7 @@ const show_simplified = Ref(false)
isnegative(t::Real) = t < 0
function isnegative(t)
if iscall(t) && operation(t) === (*)
- coeff = first(arguments(t))
+ coeff = first(sorted_arguments(t))
return isnegative(coeff)
end
return false
@@ -694,7 +693,7 @@ end
function remove_minus(t)
!iscall(t) && return -t
@assert operation(t) == (*)
- args = arguments(t)
+ args = sorted_arguments(t)
@assert args[1] < 0
Any[-args[1], args[2:end]...]
end
@@ -806,7 +805,7 @@ function show_term(io::IO, t)
end
f = operation(t)
- args = arguments(t)
+ args = sorted_arguments(t)
if symtype(t) <: LiteralReal
show_call(io, f, args)
elseif f === (+)
diff --git a/src/utils.jl b/src/utils.jl
index 812e229fb..fb5ceaa36 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -48,7 +48,7 @@ end
function fold(t)
if iscall(t)
- tt = map(fold, arguments(t))
+ tt = map(fold, sorted_arguments(t))
if !any(x->x isa Symbolic, tt)
# evaluate it
return operation(t)(tt...)
@@ -74,12 +74,12 @@ _isinteger(x) = (x isa Number && isinteger(x)) || (x isa Symbolic && symtype(x)
_isreal(x) = (x isa Number && isreal(x)) || (x isa Symbolic && symtype(x) <: Real)
issortedₑ(args) = issorted(args, lt=<ₑ)
-needs_sorting(f) = x -> is_operation(f)(x) && !issortedₑ(arguments(x))
+needs_sorting(f) = x -> is_operation(f)(x) && !issortedₑ(sorted_arguments(x))
# are there nested ⋆ terms?
function isnotflat(⋆)
function (x)
- args = arguments(x)
+ args = sorted_arguments(x)
for t in args
if iscall(t) && operation(t) === (⋆)
return true
@@ -137,12 +137,12 @@ x + 2y
```
"""
function flatten_term(⋆, x)
- args = arguments(x)
+ args = sorted_arguments(x)
# flatten nested ⋆
flattened_args = []
for t in args
if iscall(t) && operation(t) === (⋆)
- append!(flattened_args, arguments(t))
+ append!(flattened_args, sorted_arguments(t))
else
push!(flattened_args, t)
end
@@ -151,7 +151,7 @@ function flatten_term(⋆, x)
end
function sort_args(f, t)
- args = arguments(t)
+ args = sorted_arguments(t)
if length(args) < 2
return maketerm(typeof(t), f, args, metadata(t))
elseif length(args) == 2
@@ -182,12 +182,12 @@ Base.length(l::LL) = length(l.v)-l.i+1
Base.length(t::Term) = length(arguments(t)) + 1 # PIRACY
Base.isempty(t::Term) = false
@inline car(t::Term) = operation(t)
-@inline cdr(t::Term) = arguments(t)
+@inline cdr(t::Term) = sorted_arguments(t)
@inline car(v) = iscall(v) ? operation(v) : first(v)
@inline function cdr(v)
if iscall(v)
- arguments(v)
+ sorted_arguments(v)
else
islist(v) ? LL(v, 2) : error("asked cdr of empty")
end
@@ -200,7 +200,7 @@ end
if n === 0
return ll
else
- iscall(ll) ? drop_n(arguments(ll), n-1) : drop_n(cdr(ll), n-1)
+ iscall(ll) ? drop_n(sorted_arguments(ll), n-1) : drop_n(cdr(ll), n-1)
end
end
@inline drop_n(ll::Union{Tuple, AbstractArray}, n) = drop_n(LL(ll, 1), n)
diff --git a/test/basics.jl b/test/basics.jl
index 66c2b950d..58c3cab0d 100644
--- a/test/basics.jl
+++ b/test/basics.jl
@@ -232,7 +232,7 @@ end
@test getmetadata(s, Ctx1) == "meta_1"
end
-toterm(t) = Term{symtype(t)}(operation(t), arguments(t))
+toterm(t) = Term{symtype(t)}(operation(t), sorted_arguments(t))
@testset "diffs" begin
@syms a b c
@@ -279,7 +279,7 @@ end
T = FnType{Tuple{T,S,Int} where {T,S}, Real}
s = Sym{T}(:t)
@syms a b c::Int
- @test isequal(arguments(s(a, b, c)), [a, b, c])
+ @test isequal(sorted_arguments(s(a, b, c)), [a, b, c])
end
@testset "div" begin
diff --git a/test/rewrite.jl b/test/rewrite.jl
index 3bb2621e3..ccc754141 100644
--- a/test/rewrite.jl
+++ b/test/rewrite.jl
@@ -1,5 +1,7 @@
@syms a b c
+using Metatheory
+
@testset "Equality" begin
@eqtest a == a
@eqtest a != b
@@ -65,7 +67,7 @@ using SymbolicUtils: @capture
ret = @capture (a + b) (+)(~~z)
@test ret
@test @isdefined z
- @test all(z .=== arguments(a + b))
+ @test all(z .=== sorted_arguments(a + b))
#a more typical way to use the @capture macro
@@ -84,24 +86,24 @@ end
ex1 = ex + c
@test SymbolicUtils.isterm(ex1)
- @test getmetadata(arguments(ex1)[1], MetaData) == :metadata
+ @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata
ex = a
ex = setmetadata(ex, MetaData, :metadata)
ex1 = ex + b
- @test getmetadata(arguments(ex1)[1], MetaData) == :metadata
+ @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata
ex = a * b
ex = setmetadata(ex, MetaData, :metadata)
ex1 = ex * c
@test SymbolicUtils.isterm(ex1)
- @test getmetadata(arguments(ex1)[1], MetaData) == :metadata
+ @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata
ex = a
ex = setmetadata(ex, MetaData, :metadata)
ex1 = ex * b
- @test getmetadata(arguments(ex1)[1], MetaData) == :metadata
+ @test getmetadata(sorted_arguments(ex1)[1], MetaData) == :metadata
end
\ No newline at end of file
From e212a2fa837261035278f6674fad6b08ae596484 Mon Sep 17 00:00:00 2001
From: Alessandro Cheli
Date: Mon, 24 Jun 2024 14:17:37 +0200
Subject: [PATCH 16/23] adjust some tests
---
src/SymbolicUtils.jl | 2 +-
src/polyform.jl | 5 +++--
src/types.jl | 6 +++---
3 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl
index 4bc0eb147..4bbb938e0 100644
--- a/src/SymbolicUtils.jl
+++ b/src/SymbolicUtils.jl
@@ -17,7 +17,7 @@ import Base: +, -, *, /, //, \, ^, ImmutableDict
using ConstructionBase
using TermInterface
import TermInterface: iscall, isexpr, head, children,
- operation, arguments, metadata, maketerm
+ operation, arguments, metadata, maketerm, sorted_arguments
export operation, arguments, sorted_arguments, iscall
# Sym, Term,
diff --git a/src/polyform.jl b/src/polyform.jl
index 60f59f0e5..235b46393 100644
--- a/src/polyform.jl
+++ b/src/polyform.jl
@@ -185,7 +185,8 @@ end
head(::PolyForm) = PolyForm
operation(x::PolyForm) = MP.nterms(x.p) == 1 ? (*) : (+)
-function arguments(x::PolyForm{T}) where {T}
+TermInterface.sorted_arguments(x::PolyForm{T}) = arguments(t)
+function TermInterface.arguments(x::PolyForm{T}) where {T}
function is_var(v)
MP.nterms(v) == 1 &&
@@ -229,7 +230,7 @@ function arguments(x::PolyForm{T}) where {T}
PolyForm{T}(t, x.pvar2sym, x.sym2term, nothing)) for t in ts]
end
end
-children(x::PolyForm) = [operation(x); arguments(x)]
+children(x::PolyForm) = arguments(x)
Base.show(io::IO, x::PolyForm) = show_term(io, x)
diff --git a/src/types.jl b/src/types.jl
index a45287623..4e0ddb7f3 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -126,7 +126,7 @@ end
@inline head(x::BasicSymbolic) = operation(x)
-function sorted_arguments(x::BasicSymbolic)
+function TermInterface.sorted_arguments(x::BasicSymbolic)
args = arguments(x)
@compactified x::BasicSymbolic begin
Add => @goto ADD
@@ -148,8 +148,8 @@ function sorted_arguments(x::BasicSymbolic)
return args
end
-children(x::BasicSymbolic) = arguments(x)
-function arguments(x::BasicSymbolic)
+TermInterface.children(x::BasicSymbolic) = arguments(x)
+function TermInterface.arguments(x::BasicSymbolic)
@compactified x::BasicSymbolic begin
Term => return x.arguments
Add => @goto ADDMUL
From 6ba21a180d09f88d04be0a346351878908418b01 Mon Sep 17 00:00:00 2001
From: Alessandro Cheli
Date: Mon, 24 Jun 2024 14:18:08 +0200
Subject: [PATCH 17/23] remove extra method
---
src/polyform.jl | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/polyform.jl b/src/polyform.jl
index 235b46393..89736b1ec 100644
--- a/src/polyform.jl
+++ b/src/polyform.jl
@@ -185,7 +185,6 @@ end
head(::PolyForm) = PolyForm
operation(x::PolyForm) = MP.nterms(x.p) == 1 ? (*) : (+)
-TermInterface.sorted_arguments(x::PolyForm{T}) = arguments(t)
function TermInterface.arguments(x::PolyForm{T}) where {T}
function is_var(v)
From dc33ea7b821e75b6c5f801a35cd5ace131563dad Mon Sep 17 00:00:00 2001
From: Alessandro Cheli
Date: Mon, 24 Jun 2024 14:27:29 +0200
Subject: [PATCH 18/23] make tests green
---
test/rewrite.jl | 2 --
1 file changed, 2 deletions(-)
diff --git a/test/rewrite.jl b/test/rewrite.jl
index ccc754141..8f4304ace 100644
--- a/test/rewrite.jl
+++ b/test/rewrite.jl
@@ -1,7 +1,5 @@
@syms a b c
-using Metatheory
-
@testset "Equality" begin
@eqtest a == a
@eqtest a != b
From 731d4e8de039a0c7ac6050e6532e6943b6c767a4 Mon Sep 17 00:00:00 2001
From: Alessandro Cheli
Date: Mon, 24 Jun 2024 17:01:20 +0200
Subject: [PATCH 19/23] adapt maketerm for promote_symtype
---
src/polyform.jl | 3 ++-
src/types.jl | 10 +++++++++-
2 files changed, 11 insertions(+), 2 deletions(-)
diff --git a/src/polyform.jl b/src/polyform.jl
index 89736b1ec..9450da554 100644
--- a/src/polyform.jl
+++ b/src/polyform.jl
@@ -176,7 +176,8 @@ iscall(x::Type{<:PolyForm}) = true
iscall(x::PolyForm) = true
function maketerm(t::Type{<:PolyForm}, f, args, metadata)
- basicsymbolic(t, f, args, metadata)
+ # TODO: this looks uncovered.
+ basicsymbolic(f, args, nothing, metadata)
end
function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)}, args, metadata)
f(args...)
diff --git a/src/types.jl b/src/types.jl
index 4e0ddb7f3..ea375a725 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -554,7 +554,15 @@ end
unflatten(t) = t
function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata)
- basicsymbolic(head, args, symtype(T), metadata)
+ st = symtype(T)
+ pst = _promote_symtype(head, args)
+ # Use promoted symtype only if not a subtype of the existing symtype of T.
+ # This is useful when calling `maketerm(BasicSymbolic{Number}, (==), [true, false])`
+ # Where the result would have a symtype of Bool.
+ # Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609
+ # TODO this should be optimized.
+ new_st = (pst === Any || pst <: st) ? st : pst
+ basicsymbolic(head, args, new_st, metadata)
end
From db0ae467dc1853723ff76056ad2b249abb3faa3c Mon Sep 17 00:00:00 2001
From: Alessandro Cheli
Date: Mon, 24 Jun 2024 17:44:35 +0200
Subject: [PATCH 20/23] hacky fix to promote symtype
---
src/types.jl | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/src/types.jl b/src/types.jl
index ea375a725..b88304415 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -561,7 +561,13 @@ function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata)
# Where the result would have a symtype of Bool.
# Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609
# TODO this should be optimized.
- new_st = (pst === Any || pst <: st) ? st : pst
+ new_st = if pst === Bool
+ pst
+ elseif pst === Any || (st === Number && pst <: st)
+ st
+ else
+ pst
+ end
basicsymbolic(head, args, new_st, metadata)
end
From c619ea9400c86155d7a1b9a324fd3a69c04d81a3 Mon Sep 17 00:00:00 2001
From: Alessandro Cheli
Date: Wed, 26 Jun 2024 12:19:57 +0200
Subject: [PATCH 21/23] iscall
---
src/SymbolicUtils.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl
index 5b4fe150a..ef9b81f33 100644
--- a/src/SymbolicUtils.jl
+++ b/src/SymbolicUtils.jl
@@ -19,7 +19,7 @@ using TermInterface
import TermInterface: iscall, isexpr, head, children,
operation, arguments, metadata, maketerm, sorted_arguments
-const istree = iscalls
+const istree = iscall
Base.@deprecate_binding istree iscall
export istree, operation, arguments, sorted_arguments, similarterm, iscall
# Sym, Term,
From e338734b776259d661f23a5395b5775df09599d9 Mon Sep 17 00:00:00 2001
From: Alessandro Cheli
Date: Wed, 26 Jun 2024 13:33:06 +0200
Subject: [PATCH 22/23] add egraphs example
---
test/egraphs.jl | 197 ++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 197 insertions(+)
create mode 100644 test/egraphs.jl
diff --git a/test/egraphs.jl b/test/egraphs.jl
new file mode 100644
index 000000000..a5801cdaa
--- /dev/null
+++ b/test/egraphs.jl
@@ -0,0 +1,197 @@
+using Metatheory
+using SymbolicUtils
+const SU = SymbolicUtils
+using SymbolicUtils: Symbolic, BasicSymbolic, unflatten, toterm, Term
+using SymbolicUtils: monadic, diadic
+using InteractiveUtils
+
+EGraphs.preprocess(t::Symbolic) = toterm(unflatten(t))
+
+"""
+Equational rewrite rules for optimizing expressions
+"""
+opt_theory = @theory a b c x y z begin
+ a + (b + c) == (a + b) + c
+ a * (b * c) == (a * b) * c
+ x + 0 --> x
+ a + b == b + a
+ a - a => 0 # is it ok?
+
+ 0 - x --> -x
+
+ a * b == b * a
+ a * x + a * y == a*(x+y)
+ -1 * a --> -a
+ a + (-1 * b) == a - b
+ x * 1 --> x
+ x * 0 --> 0
+ x/x --> 1
+ # fraction rules
+ x^-1 == 1/x
+ 1/x * a == a/x # is this needed?
+ x / (x / y) --> y
+ x * (y / z) == (x * y) / z
+ (a/b) + (c/b) --> (a+c)/b
+ (a / b) / c == a/(b*c)
+
+ # TODO prohibited rule
+ x / x --> 1
+
+ # pow rules
+ a * a == a^2
+ (a^b)^c == a^(b*c)
+ a^b * a^c == a^(b+c)
+ a^b / a^c == a^(b-c)
+ (a*b)^c == a^c * b^c
+
+ # logarithmic rules
+ # TODO variables are non-zero
+ log(x::Number) => log(x)
+ log(x * y) == log(x) + log(y)
+ log(x / y) == log(x) - log(y)
+ log(x^y) == y * log(x)
+ x^(log(y)) == y^(log(x))
+
+ # trig functions
+ sin(x)/cos(x) == tan(x)
+ cos(x)/sin(x) == cot(x)
+ sin(x)^2 + cos(x)^2 --> 1
+ sin(2a) == 2sin(a)cos(a)
+
+ sin(x)*cos(y) - cos(x)*sin(y) --> sin(x - y)
+ # hyperbolic trigonometric
+ # are these optimizing at all? dont think so
+ # sinh(x) == (ℯ^x - ℯ^(-x))/2
+ # csch(x) == 1/sinh(x)
+ # cosh(x) == (ℯ^x + ℯ^(-x))/2
+ # sech(x) == 1/cosh(x)
+ # sech(x) == 2/(ℯ^x + ℯ^(-x))
+ # tanh(x) == sinh(x)/cosh(x)
+ # tanh(x) == (ℯ^x - ℯ^(-x))/(ℯ^x + ℯ^(-x))
+ # coth(x) == 1/tanh(x)
+ # coth(x) == (ℯ^x + ℯ^-x)/(ℯ^x - ℯ^(-x))
+
+ # cosh(x)^2 - sinh(x)^2 --> 1
+ # tanh(x)^2 + sech(x)^2 --> 1
+ # coth(x)^2 - csch(x)^2 --> 1
+
+ # asinh(z) == log(z + √(z^2 + 1))
+ # acosh(z) == log(z + √(z^2 - 1))
+ # atanh(z) == log((1+z)/(1-z))/2
+ # acsch(z) == log((1+√(1+z^2)) / z )
+ # asech(z) == log((1 + √(1-z^2)) / z )
+ # acoth(z) == log( (z+1)/(z-1) )/2
+
+ # folding
+ x::Number * y::Number => x*y
+ x::Number + y::Number => x+y
+ x::Number / y::Number => x/y
+ x::Number - y::Number => x-y
+end
+# opt_theory = @theory a b c x y begin
+# a * x == x * a
+# a * x + a * y == a*(x+y)
+# -1 * a == -a
+# a + -b --> a - b
+# -b + a --> b - a
+# end
+
+
+# See
+# * https://latkin.org/blog/2014/11/09/a-simple-benchmark-of-various-math-operations/
+# * https://streamhpc.com/blog/2012-07-16/how-expensive-is-an-operation-on-a-cpu/
+# * https://github.com/triscale-innov/GFlops.jl
+# Measure the cost of expressions in terms of number of ASM instructions
+
+const op_costs = Dict()
+
+const types = [(Int64, Integer), (Float64, Real), (ComplexF64, Complex)]
+
+const io = IOBuffer()
+
+for f in vcat(monadic, [-])
+ z = get!(op_costs, nameof(f), Dict())
+ for (t, at) in types
+ try
+ InteractiveUtils.code_native(io, f, (t,))
+ catch e
+ z[(t,)] = z[(at,)] = 1
+ continue
+ end
+ str = String(take!(io))
+ z[(t,)] = z[(at,)] = length(split(str, "\n"))
+ end
+end
+
+for f in vcat(diadic, [+, -, *, /, //, ^])
+ z = get!(op_costs, nameof(f), Dict())
+ for (t1, at1) in types, (t2, at2) in types
+ try
+ InteractiveUtils.code_native(io, f, (t1, t2))
+ catch e
+ z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = 1
+ continue
+ end
+ str = String(take!(io))
+ z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = length(split(str, "\n"))
+ end
+end
+
+function getopcost(f::Function, types::Tuple)
+ sym = nameof(f)
+ if haskey(op_costs, sym) && haskey(op_costs[sym], types)
+ return op_costs[sym][types]
+ end
+
+ # print("$f $types | ")
+ io = IOBuffer()
+ try
+ InteractiveUtils.code_native(io, f, types)
+ catch e
+ op_costs[sym][types] = 1
+ return 1
+ end
+ str = String(take!(io))
+ c = length(split(str, "\n"))
+ !haskey(op_costs, sym) && (op_costs[sym] = Dict())
+ op_costs[sym][types] = c
+end
+
+getopcost(f, types::Tuple) = get(get(op_costs, f, Dict()), types, 1)
+
+function costfun(n::VecExpr, op, children_costs::Vector{Float64})
+ v_isexpr(n) || return 1
+ # types = Tuple(map(x -> getdata(g[x], SymtypeAnalysis, Real), args))
+ types = Tuple([Float64 for i in 1:v_arity(n)])
+ opc = getopcost(op, types)
+ opc + sum(children_costs)
+end
+
+denoisescalars(x, atol=1e-11) = Postwalk(Chain([
+ # 0 - x --> -x
+ @acrule *(~x::Real, sin(~y)) => 0 where isapprox(x, 0; atol=atol)
+ @acrule *(~x::Real, cos(~y)) => 0 where isapprox(x, 0; atol=atol)
+ @acrule +(~x::Real, ~y) => y where isapprox(x, 0; atol=atol)
+ @acrule +(~x::Real, ~y) => y where isapprox(x, 0; atol=atol)
+]))(x)
+
+function optimize(ex::Symbolic; params=SaturationParams(), atol=1e-13, verbose=false, kws...)
+ # ex = simplify(denoisescalars(ex, atol))
+ # println(ex)
+ # readline()
+
+ g = EGraph{BasicSymbolic}(ex)
+
+ # display(g.classes);println();
+
+ report = saturate!(g, opt_theory, params)
+ verbose && @info report
+ extr = extract!(g, costfun)
+ return extr
+end
+
+@syms x y z
+
+t = Term(+, [Term(*, [z, x]), Term(*, [z, y])])
+
+optimize(t)
\ No newline at end of file
From 0c2a4505c61f34c2a18bf21089dc215115025c94 Mon Sep 17 00:00:00 2001
From: Alessandro Cheli
Date: Thu, 27 Jun 2024 09:47:12 +0200
Subject: [PATCH 23/23] adjust e-graph integration
---
test/egraphs.jl | 109 ++++++++++++++++++++++++------------------------
test/rewrite.jl | 2 +
2 files changed, 56 insertions(+), 55 deletions(-)
diff --git a/test/egraphs.jl b/test/egraphs.jl
index a5801cdaa..d139d316c 100644
--- a/test/egraphs.jl
+++ b/test/egraphs.jl
@@ -61,26 +61,26 @@ opt_theory = @theory a b c x y z begin
sin(x)*cos(y) - cos(x)*sin(y) --> sin(x - y)
# hyperbolic trigonometric
# are these optimizing at all? dont think so
- # sinh(x) == (ℯ^x - ℯ^(-x))/2
- # csch(x) == 1/sinh(x)
- # cosh(x) == (ℯ^x + ℯ^(-x))/2
- # sech(x) == 1/cosh(x)
- # sech(x) == 2/(ℯ^x + ℯ^(-x))
- # tanh(x) == sinh(x)/cosh(x)
- # tanh(x) == (ℯ^x - ℯ^(-x))/(ℯ^x + ℯ^(-x))
- # coth(x) == 1/tanh(x)
- # coth(x) == (ℯ^x + ℯ^-x)/(ℯ^x - ℯ^(-x))
-
- # cosh(x)^2 - sinh(x)^2 --> 1
- # tanh(x)^2 + sech(x)^2 --> 1
- # coth(x)^2 - csch(x)^2 --> 1
-
- # asinh(z) == log(z + √(z^2 + 1))
- # acosh(z) == log(z + √(z^2 - 1))
- # atanh(z) == log((1+z)/(1-z))/2
- # acsch(z) == log((1+√(1+z^2)) / z )
- # asech(z) == log((1 + √(1-z^2)) / z )
- # acoth(z) == log( (z+1)/(z-1) )/2
+ sinh(x) == (ℯ^x - ℯ^(-x))/2
+ csch(x) == 1/sinh(x)
+ cosh(x) == (ℯ^x + ℯ^(-x))/2
+ sech(x) == 1/cosh(x)
+ sech(x) == 2/(ℯ^x + ℯ^(-x))
+ tanh(x) == sinh(x)/cosh(x)
+ tanh(x) == (ℯ^x - ℯ^(-x))/(ℯ^x + ℯ^(-x))
+ coth(x) == 1/tanh(x)
+ coth(x) == (ℯ^x + ℯ^-x)/(ℯ^x - ℯ^(-x))
+
+ cosh(x)^2 - sinh(x)^2 --> 1
+ tanh(x)^2 + sech(x)^2 --> 1
+ coth(x)^2 - csch(x)^2 --> 1
+
+ asinh(z) == log(z + √(z^2 + 1))
+ acosh(z) == log(z + √(z^2 - 1))
+ atanh(z) == log((1+z)/(1-z))/2
+ acsch(z) == log((1+√(1+z^2)) / z )
+ asech(z) == log((1 + √(1-z^2)) / z )
+ acoth(z) == log( (z+1)/(z-1) )/2
# folding
x::Number * y::Number => x*y
@@ -88,13 +88,6 @@ opt_theory = @theory a b c x y z begin
x::Number / y::Number => x/y
x::Number - y::Number => x-y
end
-# opt_theory = @theory a b c x y begin
-# a * x == x * a
-# a * x + a * y == a*(x+y)
-# -1 * a == -a
-# a + -b --> a - b
-# -b + a --> b - a
-# end
# See
@@ -103,41 +96,46 @@ end
# * https://github.com/triscale-innov/GFlops.jl
# Measure the cost of expressions in terms of number of ASM instructions
-const op_costs = Dict()
-const types = [(Int64, Integer), (Float64, Real), (ComplexF64, Complex)]
-
-const io = IOBuffer()
-
-for f in vcat(monadic, [-])
- z = get!(op_costs, nameof(f), Dict())
- for (t, at) in types
- try
- InteractiveUtils.code_native(io, f, (t,))
- catch e
- z[(t,)] = z[(at,)] = 1
- continue
+function make_op_costs()
+ const op_costs = Dict()
+
+ const types = [(Int64, Integer), (Float64, Real), (ComplexF64, Complex)]
+
+ const io = IOBuffer()
+
+ for f in vcat(monadic, [-])
+ z = get!(op_costs, nameof(f), Dict())
+ for (t, at) in types
+ try
+ InteractiveUtils.code_native(io, f, (t,))
+ catch e
+ z[(t,)] = z[(at,)] = 1
+ continue
+ end
+ str = String(take!(io))
+ z[(t,)] = z[(at,)] = length(split(str, "\n"))
end
- str = String(take!(io))
- z[(t,)] = z[(at,)] = length(split(str, "\n"))
end
-end
-
-for f in vcat(diadic, [+, -, *, /, //, ^])
- z = get!(op_costs, nameof(f), Dict())
- for (t1, at1) in types, (t2, at2) in types
- try
- InteractiveUtils.code_native(io, f, (t1, t2))
- catch e
- z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = 1
- continue
+
+ for f in vcat(diadic, [+, -, *, /, //, ^])
+ z = get!(op_costs, nameof(f), Dict())
+ for (t1, at1) in types, (t2, at2) in types
+ try
+ InteractiveUtils.code_native(io, f, (t1, t2))
+ catch e
+ z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = 1
+ continue
+ end
+ str = String(take!(io))
+ z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = length(split(str, "\n"))
end
- str = String(take!(io))
- z[(t1, t2)] = z[(at1, at2)] = z[(at1, t2)] = z[(t1, at2)] = length(split(str, "\n"))
end
+
+ op_costs
end
-function getopcost(f::Function, types::Tuple)
+function getopcost(op_costs, f::Function, types::Tuple)
sym = nameof(f)
if haskey(op_costs, sym) && haskey(op_costs[sym], types)
return op_costs[sym][types]
@@ -175,6 +173,7 @@ denoisescalars(x, atol=1e-11) = Postwalk(Chain([
@acrule +(~x::Real, ~y) => y where isapprox(x, 0; atol=atol)
]))(x)
+const op_costs = make_op_costs()
function optimize(ex::Symbolic; params=SaturationParams(), atol=1e-13, verbose=false, kws...)
# ex = simplify(denoisescalars(ex, atol))
# println(ex)
diff --git a/test/rewrite.jl b/test/rewrite.jl
index 8f4304ace..ccc754141 100644
--- a/test/rewrite.jl
+++ b/test/rewrite.jl
@@ -1,5 +1,7 @@
@syms a b c
+using Metatheory
+
@testset "Equality" begin
@eqtest a == a
@eqtest a != b