From 39d8791955391894f63820bf2a0f5b031082e707 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 12 Jun 2024 21:36:33 -0700 Subject: [PATCH 001/140] Add Expronicon v0.8 --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 7712058c8..c7d9c11df 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" +Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -35,6 +36,7 @@ ConstructionBase = "1.1" DataStructures = "0.18" DocStringExtensions = "0.8, 0.9" DynamicPolynomials = "0.5" +Expronicon = "~0.8" IfElse = "0.1" LabelledArrays = "1.5" MultivariatePolynomials = "0.5" From dd7a7b2ef7f177b962f65f8c37f11abf626b4fbd Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 18 Jun 2024 13:26:16 -0700 Subject: [PATCH 002/140] Add MLStyle for pattern matching --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index c7d9c11df..cc03478b8 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" @@ -39,6 +40,7 @@ DynamicPolynomials = "0.5" Expronicon = "~0.8" IfElse = "0.1" LabelledArrays = "1.5" +MLStyle = "0.4" MultivariatePolynomials = "0.5" NaNMath = "0.3, 1" Setfield = "0.7, 0.8, 1" From e15c015148a872803f1d0bdd65ca235634ea2b27 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 18 Jun 2024 13:29:59 -0700 Subject: [PATCH 003/140] Remove Unityper --- Project.toml | 2 -- src/SymbolicUtils.jl | 1 - 2 files changed, 3 deletions(-) diff --git a/Project.toml b/Project.toml index cc03478b8..86bcea4cc 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415" [compat] AbstractTrees = "0.4" @@ -49,7 +48,6 @@ StaticArrays = "0.12, 1.0" SymbolicIndexingInterface = "0.3" TermInterface = "0.4" TimerOutputs = "0.5" -Unityper = "0.1.2" julia = "1.3" [extras] diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index d748f46c6..949bdcfbf 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -7,7 +7,6 @@ using DocStringExtensions export @syms, term, showraw, hasmetadata, getmetadata, setmetadata -using Unityper using TermInterface using DataStructures using Setfield From ffaefd8895ec7aca3ab48fbd1fe77b97b7116356 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 18 Jun 2024 13:36:35 -0700 Subject: [PATCH 004/140] Add `CONST` enum member value --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 3abf6c139..2d11556ac 100644 --- a/src/types.jl +++ b/src/types.jl @@ -8,7 +8,7 @@ abstract type Symbolic{T} end ### Uni-type design ### -@enum ExprType::UInt8 SYM TERM ADD MUL POW DIV +@enum ExprType::UInt8 SYM TERM ADD MUL POW DIV CONST const Metadata = Union{Nothing,Base.ImmutableDict{DataType,Any}} const NO_METADATA = nothing From 4eec0efc8d84b2a9044ea55735a56a4fa672f0e9 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 18 Jun 2024 13:53:09 -0700 Subject: [PATCH 005/140] Migrate `BasicSymbolic` to Expronicon ADTs --- src/SymbolicUtils.jl | 3 ++ src/types.jl | 82 +++++++++++++++++++------------------------- 2 files changed, 39 insertions(+), 46 deletions(-) diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 949bdcfbf..b15b4d061 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -21,6 +21,9 @@ import TermInterface: iscall, isexpr, issym, symtype, head, children, const istree = iscall Base.@deprecate_binding istree iscall export istree, operation, arguments, unsorted_arguments, similarterm, iscall + +using Base: RefValue +using Expronicon.ADT: @adt # Sym, Term, # Add, Mul and Pow include("types.jl") diff --git a/src/types.jl b/src/types.jl index 2d11556ac..8a844658e 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,66 +1,56 @@ -#------------------- -#-------------------- -#### Symbolic -#-------------------- abstract type Symbolic{T} end -### -### Uni-type design -### - @enum ExprType::UInt8 SYM TERM ADD MUL POW DIV CONST -const Metadata = Union{Nothing,Base.ImmutableDict{DataType,Any}} +const Metadata = Union{Nothing, Base.ImmutableDict{DataType, Any}} const NO_METADATA = nothing +const EMPTY_HASH = UInt(0) -sdict(kv...) = Dict{Any, Any}(kv...) +sdict(kv...) = Dict{BasicSymbolic, Any}(kv...) -using Base: RefValue -const EMPTY_ARGS = [] -const EMPTY_HASH = RefValue(UInt(0)) -const NOT_SORTED = RefValue(false) -const EMPTY_DICT = sdict() -const EMPTY_DICT_T = typeof(EMPTY_DICT) -@compactify show_methods=false begin - @abstract struct BasicSymbolic{T} <: Symbolic{T} - metadata::Metadata = NO_METADATA +@adt BasicSymbolicImpl begin + struct Sym + name::Symbol = :OOF end - struct Sym{T} <: BasicSymbolic{T} - name::Symbol = :OOF + struct Term + f::Any = identity + arguments::Vector{BasicSymbolic} = BasicSymbolic[] end - struct Term{T} <: BasicSymbolic{T} - f::Any = identity # base/num if Pow; issorted if Add/Dict - arguments::Vector{Any} = EMPTY_ARGS - hash::RefValue{UInt} = EMPTY_HASH + struct Add + coeff::BasicSymbolic + dict::Dict{BasicSymbolic, Any} = sdict() + arguments::Vector{BasicSymbolic} = BasicSymbolic[] + issorted::RefValue{Bool} = Ref(false) end - struct Mul{T} <: BasicSymbolic{T} - coeff::Any = 0 # exp/den if Pow - dict::EMPTY_DICT_T = EMPTY_DICT - hash::RefValue{UInt} = EMPTY_HASH - arguments::Vector{Any} = EMPTY_ARGS - issorted::RefValue{Bool} = NOT_SORTED + struct Mul + coeff::BasicSymbolic + dict::Dict{BasicSymbolic, Any} + arguments::Vector{BasicSymbolic} = BasicSymbolic[] + issorted::RefValue{Bool} = Ref(false) end - struct Add{T} <: BasicSymbolic{T} - coeff::Any = 0 # exp/den if Pow - dict::EMPTY_DICT_T = EMPTY_DICT - hash::RefValue{UInt} = EMPTY_HASH - arguments::Vector{Any} = EMPTY_ARGS - issorted::RefValue{Bool} = NOT_SORTED + struct Div + num::BasicSymbolic + den::BasicSymbolic + simplified::RefValue{Bool} = Ref(false) + arguments::Vector{BasicSymbolic} = BasicSymbolic[] end - struct Div{T} <: BasicSymbolic{T} - num::Any = 1 - den::Any = 1 - simplified::Bool = false - arguments::Vector{Any} = EMPTY_ARGS + struct Pow + base::BasicSymbolic + exp::BasicSymbolic + arguments::Vector{BasicSymbolic} = BasicSymbolic[] end - struct Pow{T} <: BasicSymbolic{T} - base::Any = 1 - exp::Any = 1 - arguments::Vector{Any} = EMPTY_ARGS + struct Const + val::Any end end +Base.@kwdef struct BasicSymbolic{T} <: Symbolic{T} + x::BasicSymbolicImpl + metadata::Metadata = NO_METADATA + hash::RefValue{UInt} = Ref(EMPTY_HASH) +end + function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) ScalarSymbolic() end From 989907b6314263a6855d1dc13515ecaaa4d12459 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 18 Jun 2024 13:55:35 -0700 Subject: [PATCH 006/140] Replace `Unityper.@compactified` with `MLStyle.@match` --- src/SymbolicUtils.jl | 1 + src/types.jl | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index b15b4d061..4d351a244 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -24,6 +24,7 @@ export istree, operation, arguments, unsorted_arguments, similarterm, iscall using Base: RefValue using Expronicon.ADT: @adt +using MLStyle: @match # Sym, Term, # Add, Mul and Pow include("types.jl") diff --git a/src/types.jl b/src/types.jl index 8a844658e..88b32c4f1 100644 --- a/src/types.jl +++ b/src/types.jl @@ -56,13 +56,14 @@ function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) end function exprtype(x::BasicSymbolic) - @compactified x::BasicSymbolic begin + @match x::BasicSymbolic begin Term => TERM Add => ADD Mul => MUL Div => DIV Pow => POW Sym => SYM + Const => CONST _ => error_on_type() end end @@ -93,7 +94,7 @@ symtype(x::Number) = typeof(x) # We're returning a function pointer @inline function operation(x::BasicSymbolic) - @compactified x::BasicSymbolic begin + @match x::BasicSymbolic begin Term => x.f Add => (+) Mul => (*) @@ -108,7 +109,7 @@ end function arguments(x::BasicSymbolic) args = unsorted_arguments(x) - @compactified x::BasicSymbolic begin + @match x::BasicSymbolic begin Add => @goto ADD Mul => @goto MUL _ => return args @@ -131,7 +132,7 @@ end unsorted_arguments(x) = arguments(x) children(x::BasicSymbolic) = arguments(x) function unsorted_arguments(x::BasicSymbolic) - @compactified x::BasicSymbolic begin + @match x::BasicSymbolic begin Term => return x.arguments Add => @goto ADDMUL Mul => @goto ADDMUL From 608e23ac231305f3ae24484c289a81f404038ffc Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 18 Jun 2024 14:29:27 -0700 Subject: [PATCH 007/140] Adapt functions to new struct definition --- src/types.jl | 54 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/src/types.jl b/src/types.jl index 88b32c4f1..3b95ca424 100644 --- a/src/types.jl +++ b/src/types.jl @@ -46,7 +46,7 @@ sdict(kv...) = Dict{BasicSymbolic, Any}(kv...) end Base.@kwdef struct BasicSymbolic{T} <: Symbolic{T} - x::BasicSymbolicImpl + impl::BasicSymbolicImpl metadata::Metadata = NO_METADATA hash::RefValue{UInt} = Ref(EMPTY_HASH) end @@ -56,7 +56,7 @@ function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) end function exprtype(x::BasicSymbolic) - @match x::BasicSymbolic begin + @match x.impl begin Term => TERM Add => ADD Mul => MUL @@ -71,6 +71,7 @@ end # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") +@noinline error_const() = error("Const doesn't have a operation or arguments!") @noinline error_property(E, s) = error("$E doesn't have field $s") # We can think about bits later @@ -94,13 +95,14 @@ symtype(x::Number) = typeof(x) # We're returning a function pointer @inline function operation(x::BasicSymbolic) - @match x::BasicSymbolic begin + @match x.impl begin Term => x.f Add => (+) Mul => (*) Div => (/) Pow => (^) Sym => error_sym() + Const => error_const() _ => error_on_type() end end @@ -109,7 +111,7 @@ end function arguments(x::BasicSymbolic) args = unsorted_arguments(x) - @match x::BasicSymbolic begin + @match x.impl begin Add => @goto ADD Mul => @goto MUL _ => return args @@ -132,50 +134,51 @@ end unsorted_arguments(x) = arguments(x) children(x::BasicSymbolic) = arguments(x) function unsorted_arguments(x::BasicSymbolic) - @match x::BasicSymbolic begin + @match x.impl begin Term => return x.arguments Add => @goto ADDMUL Mul => @goto ADDMUL Div => @goto DIV Pow => @goto POW Sym => error_sym() + Const => error_const() _ => error_on_type() end @label ADDMUL E = exprtype(x) - args = x.arguments + args = x.impl.arguments isempty(args) || return args - siz = length(x.dict) - idcoeff = E === ADD ? iszero(x.coeff) : isone(x.coeff) + siz = length(x.impl.dict) + idcoeff = E === ADD ? iszero(x.impl.coeff) : isone(x.impl.coeff) sizehint!(args, idcoeff ? siz : siz + 1) - idcoeff || push!(args, x.coeff) + idcoeff || push!(args, x.impl.coeff) if isadd(x) - for (k, v) in x.dict + for (k, v) in x.impl.dict push!(args, applicable(*,k,v) ? k*v : maketerm(k, *, [k, v])) end else # MUL - for (k, v) in x.dict + for (k, v) in x.impl.dict push!(args, unstable_pow(k, v)) end end return args @label DIV - args = x.arguments + args = x.impl.arguments isempty(args) || return args sizehint!(args, 2) - push!(args, x.num) - push!(args, x.den) + push!(args, x.impl.num) + push!(args, x.impl.den) return args @label POW - args = x.arguments + args = x.impl.arguments isempty(args) || return args sizehint!(args, 2) - push!(args, x.base) - push!(args, x.exp) + push!(args, x.impl.base) + push!(args, x.impl.exp) return args end @@ -220,15 +223,17 @@ function _isequal(a, b, E) if E === SYM nameof(a) === nameof(b) elseif E === ADD || E === MUL - coeff_isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict) + coeff_isequal(a.impl.coeff, b.impl.coeff) && isequal(a.impl.dict, b.impl.dict) elseif E === DIV - isequal(a.num, b.num) && isequal(a.den, b.den) + isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den) elseif E === POW - isequal(a.exp, b.exp) && isequal(a.base, b.base) + isequal(a.impl.exp, b.impl.exp) && isequal(a.impl.base, b.impl.base) elseif E === TERM a1 = arguments(a) a2 = arguments(b) isequal(operation(a), operation(b)) && _allarequal(a1, a2) + elseif E === CONST + isequal(a.impl.val, b.impl.val) else error_on_type() end @@ -246,6 +251,7 @@ const ADD_SALT = 0xaddaddaddaddadda % UInt const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt const DIV_SALT = 0x334b218e73bbba53 % UInt const POW_SALT = 0x2b55b97a6efb080c % UInt +const COS_SALT = 0xdc3d6b8f18b75e3c % UInt function Base.hash(s::BasicSymbolic, salt::UInt)::UInt E = exprtype(s) if E === SYM @@ -255,13 +261,13 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt h = s.hash[] !iszero(h) && return h hashoffset = isadd(s) ? ADD_SALT : SUB_SALT - h′ = hash(hashoffset, hash(s.coeff, hash(s.dict, salt))) + h′ = hash(hashoffset, hash(s.impl.coeff, hash(s.impl.dict, salt))) s.hash[] = h′ return h′ elseif E === DIV - return hash(s.num, hash(s.den, salt ⊻ DIV_SALT)) + return hash(s.impl.num, hash(s.impl.den, salt ⊻ DIV_SALT)) elseif E === POW - hash(s.exp, hash(s.base, salt ⊻ POW_SALT)) + hash(s.impl.exp, hash(s.impl.base, salt ⊻ POW_SALT)) elseif E === TERM !iszero(salt) && return hash(hash(s, zero(UInt)), salt) h = s.hash[] @@ -271,6 +277,8 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt h′ = hashvec(arguments(s), hash(oph, salt)) s.hash[] = h′ return h′ + elseif E === CONST + return hash(s.impl.val, salt ⊻ COS_SALT) else error_on_type() end From 93e394271b51ad6551606024cf7ed12b03960790 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 18 Jun 2024 17:18:49 -0700 Subject: [PATCH 008/140] Use `@kwdef` instead of `Base.@kwdef` --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 3b95ca424..2dcf0c325 100644 --- a/src/types.jl +++ b/src/types.jl @@ -45,7 +45,7 @@ sdict(kv...) = Dict{BasicSymbolic, Any}(kv...) end end -Base.@kwdef struct BasicSymbolic{T} <: Symbolic{T} +@kwdef struct BasicSymbolic{T} <: Symbolic{T} impl::BasicSymbolicImpl metadata::Metadata = NO_METADATA hash::RefValue{UInt} = Ref(EMPTY_HASH) From 6e5f1706fd80890527e059120afcd8004e6def13 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 18 Jun 2024 19:19:16 -0700 Subject: [PATCH 009/140] Remove the default value for `Sym.name` --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 2dcf0c325..cf44c0f9d 100644 --- a/src/types.jl +++ b/src/types.jl @@ -11,7 +11,7 @@ sdict(kv...) = Dict{BasicSymbolic, Any}(kv...) @adt BasicSymbolicImpl begin struct Sym - name::Symbol = :OOF + name::Symbol end struct Term f::Any = identity From a7fd87a00e7033b0f5ea083e0ebf1f2f47dd9c6e Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 18 Jun 2024 19:22:57 -0700 Subject: [PATCH 010/140] Update `Sym` constructor --- src/types.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/types.jl b/src/types.jl index cf44c0f9d..417ff1470 100644 --- a/src/types.jl +++ b/src/types.jl @@ -288,8 +288,9 @@ end ### Constructors ### -function Sym{T}(name::Symbol; kw...) where T - Sym{T}(; name=name, kw...) +function _Sym(::Type{T}, name::Symbol; kwargs...) where {T} + impl = Sym(name) + BasicSymbolic{T}(; impl, kwargs...) end function Term{T}(f, args; kw...) where T From d3dcddeb78d35bfeaecc1ee5d593558f39c92da5 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 18 Jun 2024 19:39:16 -0700 Subject: [PATCH 011/140] Update `Term` constructor --- src/types.jl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/types.jl b/src/types.jl index 417ff1470..2c7132c6c 100644 --- a/src/types.jl +++ b/src/types.jl @@ -293,16 +293,15 @@ function _Sym(::Type{T}, name::Symbol; kwargs...) where {T} BasicSymbolic{T}(; impl, kwargs...) end -function Term{T}(f, args; kw...) where T - if eltype(args) !== Any - args = convert(Vector{Any}, args) +function _Term(::Type{T}, f, args; kwargs...) where {T} + if eltype(args) !== BasicSymbolic + args = convert(Vector{BasicSymbolic}, args) end - - Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), kw...) + impl = Term(f, args) + BasicSymbolic{T}(; impl, kwargs...) end - -function Term(f, args; metadata=NO_METADATA) - Term{_promote_symtype(f, args)}(f, args, metadata=metadata) +function _Term(f, args; kwargs...) + _Term(_promote_symtype(f, args), f, args; kwargs...) end function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T From ad6dff7f43843691c0d2944ca1403c8603e7db8f Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 18 Jun 2024 19:40:19 -0700 Subject: [PATCH 012/140] Update `_promote_symtype` --- src/types.jl | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/types.jl b/src/types.jl index 2c7132c6c..295ba533a 100644 --- a/src/types.jl +++ b/src/types.jl @@ -915,18 +915,16 @@ end function _promote_symtype(f, args) if issym(f) promote_symtype(f, map(symtype, args)...) + elseif length(args) == 0 + promote_symtype(f) + elseif length(args) == 1 + promote_symtype(f, symtype(args[1])) + elseif length(args) == 2 + promote_symtype(f, symtype(args[1]), symtype(args[2])) + elseif isassociative(f) + mapfoldl(symtype, (x, y) -> promote_symtype(f, x, y), args) else - if length(args) == 0 - promote_symtype(f) - elseif length(args) == 1 - promote_symtype(f, symtype(args[1])) - elseif length(args) == 2 - promote_symtype(f, symtype(args[1]), symtype(args[2])) - elseif isassociative(f) - mapfoldl(symtype, (x,y) -> promote_symtype(f, x, y), args) - else - promote_symtype(f, map(symtype, args)...) - end + promote_symtype(f, map(symtype, args)...) end end From 94f7ef1a0f11110898e272b9d212c6923523b9c3 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 18 Jun 2024 22:26:38 -0700 Subject: [PATCH 013/140] Update `Add` constructor --- src/types.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/types.jl b/src/types.jl index 295ba533a..982f8e54f 100644 --- a/src/types.jl +++ b/src/types.jl @@ -304,20 +304,20 @@ function _Term(f, args; kwargs...) _Term(_promote_symtype(f, args), f, args; kwargs...) end -function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T +function _Add(::Type{T}, coeff, dict; kwargs...) where {T} if isempty(dict) return coeff elseif _iszero(coeff) && length(dict) == 1 - k,v = first(dict) + k, v = first(dict) if _isone(v) return k else coeff, dict = makemul(v, k) - return Mul(T, coeff, dict) + return _Mul(T, coeff, dict) end end - - Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) + impl = Add(; coeff, dict) + BasicSymbolic{T}(; impl, kwargs...) end function Mul(T, a, b; metadata=NO_METADATA, kw...) From f97813a152bfa3fa488dbe25e0f5a7f550b5dbc5 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 18 Jun 2024 22:32:36 -0700 Subject: [PATCH 014/140] Update `Mul` constructor [skip ci] --- src/types.jl | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/types.jl b/src/types.jl index 982f8e54f..24a180dc5 100644 --- a/src/types.jl +++ b/src/types.jl @@ -320,20 +320,18 @@ function _Add(::Type{T}, coeff, dict; kwargs...) where {T} BasicSymbolic{T}(; impl, kwargs...) end -function Mul(T, a, b; metadata=NO_METADATA, kw...) - isempty(b) && return a - if _isone(a) && length(b) == 1 - pair = first(b) +function _Mul(::Type{T}, coeff, dict; kwargs...) where {T} + isempty(dict) && return coeff + if _isone(coeff) && length(dict) == 1 + pair = first(dict) if _isone(last(pair)) # first value return first(pair) else return unstable_pow(first(pair), last(pair)) end - else - coeff = a - dict = b - Mul{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) end + impl = Mul(; coeff, dict) + BasicSymbolic{T}(; impl, kwargs...) end const Rat = Union{Rational, Integer} From 42da470c8e1ef9f7fcf523eb10831764d8154eec Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 19 Jun 2024 14:47:42 -0700 Subject: [PATCH 015/140] Remove default value for `Add.dict` --- src/types.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/types.jl b/src/types.jl index 24a180dc5..d9764caf5 100644 --- a/src/types.jl +++ b/src/types.jl @@ -6,9 +6,6 @@ const Metadata = Union{Nothing, Base.ImmutableDict{DataType, Any}} const NO_METADATA = nothing const EMPTY_HASH = UInt(0) -sdict(kv...) = Dict{BasicSymbolic, Any}(kv...) - - @adt BasicSymbolicImpl begin struct Sym name::Symbol @@ -19,7 +16,7 @@ sdict(kv...) = Dict{BasicSymbolic, Any}(kv...) end struct Add coeff::BasicSymbolic - dict::Dict{BasicSymbolic, Any} = sdict() + dict::Dict{BasicSymbolic, Any} arguments::Vector{BasicSymbolic} = BasicSymbolic[] issorted::RefValue{Bool} = Ref(false) end From 44af526b962ef091827176be5ba3f8948c3bb764 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 19 Jun 2024 15:15:27 -0700 Subject: [PATCH 016/140] Create `arguments` for `Div` and `Pow` at construction step --- src/types.jl | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/src/types.jl b/src/types.jl index d9764caf5..a0cc2afef 100644 --- a/src/types.jl +++ b/src/types.jl @@ -30,12 +30,12 @@ const EMPTY_HASH = UInt(0) num::BasicSymbolic den::BasicSymbolic simplified::RefValue{Bool} = Ref(false) - arguments::Vector{BasicSymbolic} = BasicSymbolic[] + arguments::Vector{BasicSymbolic} = [num, den] end struct Pow base::BasicSymbolic exp::BasicSymbolic - arguments::Vector{BasicSymbolic} = BasicSymbolic[] + arguments::Vector{BasicSymbolic} = [base, exp] end struct Const val::Any @@ -135,8 +135,8 @@ function unsorted_arguments(x::BasicSymbolic) Term => return x.arguments Add => @goto ADDMUL Mul => @goto ADDMUL - Div => @goto DIV - Pow => @goto POW + Div => @goto DIVPOW + Pow => @goto DIVPOW Sym => error_sym() Const => error_const() _ => error_on_type() @@ -162,20 +162,8 @@ function unsorted_arguments(x::BasicSymbolic) end return args - @label DIV + @label DIVPOW args = x.impl.arguments - isempty(args) || return args - sizehint!(args, 2) - push!(args, x.impl.num) - push!(args, x.impl.den) - return args - - @label POW - args = x.impl.arguments - isempty(args) || return args - sizehint!(args, 2) - push!(args, x.impl.base) - push!(args, x.impl.exp) return args end From 4bd021c682330d9b15dc12ab1226df58dcf1123b Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 19 Jun 2024 15:43:20 -0700 Subject: [PATCH 017/140] Remove default arguments for `Term` struct fields --- src/types.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/types.jl b/src/types.jl index a0cc2afef..9b36b5292 100644 --- a/src/types.jl +++ b/src/types.jl @@ -11,8 +11,8 @@ const EMPTY_HASH = UInt(0) name::Symbol end struct Term - f::Any = identity - arguments::Vector{BasicSymbolic} = BasicSymbolic[] + f::Any + arguments::Vector{BasicSymbolic} end struct Add coeff::BasicSymbolic From 552e7f97fa2c3485b319b888b85a31b3aafbe42c Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 19 Jun 2024 16:04:06 -0700 Subject: [PATCH 018/140] Add tests for fundamental types and functions --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index ad533ae15..cc8acf554 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,6 +25,7 @@ include("utils.jl") if haskey(ENV, "SU_BENCHMARK_ONLY") include("benchmark.jl") else + include("types.jl") include("basics.jl") include("order.jl") include("polyform.jl") From 12d9c71324defe0e2b27ea71cfbf8e115d9e3b18 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 19 Jun 2024 16:10:12 -0700 Subject: [PATCH 019/140] Test Expronicon generated constructors [skip ci] --- test/types.jl | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 test/types.jl diff --git a/test/types.jl b/test/types.jl new file mode 100644 index 000000000..140d18970 --- /dev/null +++ b/test/types.jl @@ -0,0 +1,26 @@ +using SymbolicUtils: BasicSymbolic + +@testset "Expronicon generated constructors" begin + s1 = Sym(:abc) + s2 = Sym(name = :def) + name = :ghi + s3 = Sym(; name) + bs1 = BasicSymbolic{Float64}(impl = s1) + impl = s2 + bs2 = BasicSymbolic{Int64}(; impl) + @testset "Sym" begin + @test s1.name == :abc + @test typeof(s1) == BasicSymbolicImpl + @test s2.name == :def + @test s3.name == :ghi + end + @testset "Term" begin + t1 = Term(sin, [bs1]) + @test t1.f == sin + @test t1.arguments == [bs1] + @test typeof(t1.arguments) == Vector{BasicSymbolic} + @test_throws MethodError Term(sin, [s1]) + @test_throws MethodError Term(sin, [1]) + @test_throws MethodError Term(sin, [2.0]) + end +end \ No newline at end of file From 26cee9f659aa16dbddd7b4368edca2e8a3f64a99 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 19 Jun 2024 17:22:22 -0700 Subject: [PATCH 020/140] Add tests for `Sym` --- test/types.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/types.jl b/test/types.jl index 140d18970..48288a1d8 100644 --- a/test/types.jl +++ b/test/types.jl @@ -9,7 +9,9 @@ using SymbolicUtils: BasicSymbolic impl = s2 bs2 = BasicSymbolic{Int64}(; impl) @testset "Sym" begin + @test_nowarn Sym(Symbol("")) @test s1.name == :abc + @test typeof(s2.name) == Symbol @test typeof(s1) == BasicSymbolicImpl @test s2.name == :def @test s3.name == :ghi @@ -23,4 +25,4 @@ using SymbolicUtils: BasicSymbolic @test_throws MethodError Term(sin, [1]) @test_throws MethodError Term(sin, [2.0]) end -end \ No newline at end of file +end From d2d21b9d470aa913a2cf05dc9997c3d3fdd01c3f Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 19 Jun 2024 17:47:26 -0700 Subject: [PATCH 021/140] Test Expronicon generated constructors for `Div` [skip ci] --- test/types.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/types.jl b/test/types.jl index 48288a1d8..563a1c052 100644 --- a/test/types.jl +++ b/test/types.jl @@ -25,4 +25,24 @@ using SymbolicUtils: BasicSymbolic @test_throws MethodError Term(sin, [1]) @test_throws MethodError Term(sin, [2.0]) end + @testset "Div" begin + d1 = Div(num = bs1, den = bs2) + @test typeof(d1.num) == BasicSymbolic{Float64} + @test typeof(d1.den) == BasicSymbolic{Int64} + @test d1.num == bs1 + @test d1.den == bs2 + @test typeof(d1.simplified) == Base.RefValue{Bool} + @test isassigned(d1.simplified) + @test !d1.simplified[] + @test typeof(d1.arguments) == Vector{BasicSymbolic} + @test d1.arguments == [bs1, bs2] + num = bs1 + den = bs2 + d2 = Div(; num, den) + @test d2.num == bs1 + @test d2.den == bs2 + @test_throws MethodError Div(num = s1, den = bs2) + @test_throws MethodError Div(num = bs1, den = s2) + @test_throws MethodError Div(num = s1, den = s2) + end end From 559c73253718af7464267e8bf03c84036147545a Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 20 Jun 2024 21:50:44 -0700 Subject: [PATCH 022/140] Test Expronicon `Pow` --- test/types.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/types.jl b/test/types.jl index 563a1c052..76a67cdb7 100644 --- a/test/types.jl +++ b/test/types.jl @@ -45,4 +45,21 @@ using SymbolicUtils: BasicSymbolic @test_throws MethodError Div(num = bs1, den = s2) @test_throws MethodError Div(num = s1, den = s2) end + @testset "Pow" begin + p1 = Pow(base = bs1, exp = bs2) + @test typeof(p1.base) == BasicSymbolic{Float64} + @test typeof(p1.exp) == BasicSymbolic{Int64} + @test p1.base == bs1 + @test p1.exp == bs2 + @test typeof(p1.arguments) == Vector{BasicSymbolic} + @test p1.arguments == [bs1, bs2] + base = bs1 + exp = bs2 + p2 = Pow(; base, exp) + @test p2.base == bs1 + @test p2.exp == bs2 + @test_throws MethodError Pow(base = s1, exp = bs2) + @test_throws MethodError Pow(base = bs1, exp = s2) + @test_throws MethodError Pow(base = s1, exp = s2) + end end From 3d808ab1daa080e253a7283756efd99d692766d9 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 20 Jun 2024 22:19:32 -0700 Subject: [PATCH 023/140] Test Expronicon `Const` [skip ci] --- test/types.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/types.jl b/test/types.jl index 76a67cdb7..f602f6407 100644 --- a/test/types.jl +++ b/test/types.jl @@ -62,4 +62,20 @@ using SymbolicUtils: BasicSymbolic @test_throws MethodError Pow(base = bs1, exp = s2) @test_throws MethodError Pow(base = s1, exp = s2) end + c1 = Const(1) + bc1 = BasicSymbolic{Int}(impl = c1) + c2 = Const(val = 3.14) + bc2 = BasicSymbolic{Float64}(impl = c2) + @testset "Const" begin + @test typeof(c1.val) == Int + @test c1.val == 1 + @test typeof(c2.val) == Float64 + @test c2.val == 3.14 + c3 = Const(big"123456789012345678901234567890") + @test typeof(c3.val) == BigInt + @test c3.val == big"123456789012345678901234567890" + c4 = Const(big"1.23456789012345678901") + @test typeof(c4.val) == BigFloat + @test c4.val == big"1.23456789012345678901" + end end From 10bae2af595a124b1258c3904a1bca97ca24c93b Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 20 Jun 2024 22:40:40 -0700 Subject: [PATCH 024/140] Add Expronicon type check tests --- test/types.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/types.jl b/test/types.jl index f602f6407..36de1b764 100644 --- a/test/types.jl +++ b/test/types.jl @@ -9,6 +9,7 @@ using SymbolicUtils: BasicSymbolic impl = s2 bs2 = BasicSymbolic{Int64}(; impl) @testset "Sym" begin + @test typeof(s1) == BasicSymbolicImpl @test_nowarn Sym(Symbol("")) @test s1.name == :abc @test typeof(s2.name) == Symbol @@ -18,6 +19,7 @@ using SymbolicUtils: BasicSymbolic end @testset "Term" begin t1 = Term(sin, [bs1]) + @test typeof(t1) == BasicSymbolicImpl @test t1.f == sin @test t1.arguments == [bs1] @test typeof(t1.arguments) == Vector{BasicSymbolic} @@ -27,6 +29,7 @@ using SymbolicUtils: BasicSymbolic end @testset "Div" begin d1 = Div(num = bs1, den = bs2) + @test typeof(d1) == BasicSymbolicImpl @test typeof(d1.num) == BasicSymbolic{Float64} @test typeof(d1.den) == BasicSymbolic{Int64} @test d1.num == bs1 @@ -47,6 +50,7 @@ using SymbolicUtils: BasicSymbolic end @testset "Pow" begin p1 = Pow(base = bs1, exp = bs2) + @test typeof(p1) == BasicSymbolicImpl @test typeof(p1.base) == BasicSymbolic{Float64} @test typeof(p1.exp) == BasicSymbolic{Int64} @test p1.base == bs1 @@ -67,6 +71,7 @@ using SymbolicUtils: BasicSymbolic c2 = Const(val = 3.14) bc2 = BasicSymbolic{Float64}(impl = c2) @testset "Const" begin + @test typeof(c1) == BasicSymbolicImpl @test typeof(c1.val) == Int @test c1.val == 1 @test typeof(c2.val) == Float64 From f52bbfe867bae4cf1e0a31fb7289dee2ed055113 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 20 Jun 2024 23:22:29 -0700 Subject: [PATCH 025/140] Test Expronicon `Add` & `Mul` --- test/types.jl | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/test/types.jl b/test/types.jl index 36de1b764..0c91c920b 100644 --- a/test/types.jl +++ b/test/types.jl @@ -83,4 +83,30 @@ using SymbolicUtils: BasicSymbolic @test typeof(c4.val) == BigFloat @test c4.val == big"1.23456789012345678901" end + coeff = bc1 + dict = Dict(bs1 => 3, bs2 => 5) + @testset "Add" begin + a1 = Add(; coeff, dict) + @test typeof(a1) == BasicSymbolicImpl + @test a1.coeff isa BasicSymbolic + @test isequal(a1.coeff, bc1) + @test typeof(a1.dict) == Dict{BasicSymbolic, Any} + @test a1.dict == dict + @test typeof(a1.arguments) == Vector{BasicSymbolic} + @test isempty(a1.arguments) + @test typeof(a1.issorted) == Base.RefValue{Bool} + @test !a1.issorted[] + end + @testset "Mul" begin + m1 = Mul(; coeff, dict) + @test typeof(m1) == BasicSymbolicImpl + @test m1.coeff isa BasicSymbolic + @test isequal(m1.coeff, bc1) + @test typeof(m1.dict) == Dict{BasicSymbolic, Any} + @test m1.dict == dict + @test typeof(m1.arguments) == Vector{BasicSymbolic} + @test isempty(m1.arguments) + @test typeof(m1.issorted) == Base.RefValue{Bool} + @test !m1.issorted[] + end end From 6906ced7165a910b617afd02edbe695e4de183d7 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 20 Jun 2024 23:29:47 -0700 Subject: [PATCH 026/140] Test `BasicSymbolic` `@kwdef` keyword-based constructor [skip ci] --- test/types.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/types.jl b/test/types.jl index 0c91c920b..3eec246dd 100644 --- a/test/types.jl +++ b/test/types.jl @@ -109,4 +109,13 @@ using SymbolicUtils: BasicSymbolic @test typeof(m1.issorted) == Base.RefValue{Bool} @test !m1.issorted[] end + @testset "BasicSymbolic" begin + @test typeof(bs1) == BasicSymbolic{Float64} + @test bs1 isa BasicSymbolic + @test bs1 isa SymbolicUtils.Symbolic + @test bs1.metadata isa SymbolicUtils.Metadata + @test bs1.metadata == SymbolicUtils.NO_METADATA + @test typeof(bs1.hash) == Base.RefValue{UInt} + @test bs1.hash[] == SymbolicUtils.EMPTY_HASH + end end From a771dab2ee953894fcd9da50d719b1a0cf0fb6db Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 27 Jun 2024 22:18:10 -0700 Subject: [PATCH 027/140] Add `_iszero(x::BasicSymbolic)` --- src/types.jl | 7 +++++++ test/types.jl | 19 ++++++++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index fe062aa0f..d80aa3c92 100644 --- a/src/types.jl +++ b/src/types.jl @@ -323,6 +323,13 @@ function _Mul(::Type{T}, coeff, dict; kwargs...) where {T} BasicSymbolic{T}(; impl, kwargs...) end +function _iszero(x::BasicSymbolic) + @match x.impl begin + Const(_...) => iszero(x.impl.val) + _ => false + end +end + const Rat = Union{Rational, Integer} function ratcoeff(x) diff --git a/test/types.jl b/test/types.jl index 3eec246dd..0968b7921 100644 --- a/test/types.jl +++ b/test/types.jl @@ -1,4 +1,4 @@ -using SymbolicUtils: BasicSymbolic +using SymbolicUtils: BasicSymbolic, _Sym, _Term, _Const, _Add @testset "Expronicon generated constructors" begin s1 = Sym(:abc) @@ -119,3 +119,20 @@ using SymbolicUtils: BasicSymbolic @test bs1.hash[] == SymbolicUtils.EMPTY_HASH end end + +@testset "BasicSymbolic iszero" begin + c1 = _Const(0) + @test SymbolicUtils._iszero(c1) + c2 = _Const(1) + @test !SymbolicUtils._iszero(c2) + c3 = _Const(0.0) + @test SymbolicUtils._iszero(c3) + c4 = _Const(0.00000000000000000000000001) + @test !SymbolicUtils._iszero(c4) + c5 = _Const(big"326264532521352634435352152") + @test !SymbolicUtils._iszero(c5) + c6 = _Const(big"0.314654523452") + @test !SymbolicUtils._iszero(c6) + s = _Sym(Real, :y) + @test !SymbolicUtils._iszero(s) +end From 2e682536105bde4a09ca0a4eca5a494f6febde07 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 27 Jun 2024 22:18:59 -0700 Subject: [PATCH 028/140] Add a `Const` custom constructor --- src/types.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/types.jl b/src/types.jl index d80aa3c92..8fdfb6916 100644 --- a/src/types.jl +++ b/src/types.jl @@ -293,6 +293,11 @@ function _Term(f, args; kwargs...) _Term(_promote_symtype(f, args), f, args; kwargs...) end +function _Const(val::T; kwargs...) where {T} + impl = Const(val) + BasicSymbolic{T}(; impl, kwargs...) +end + function _Add(::Type{T}, coeff, dict; kwargs...) where {T} if isempty(dict) return coeff From 1d50a8d98ca247e41fff7b9ad6ca3624673fa533 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 27 Jun 2024 22:32:22 -0700 Subject: [PATCH 029/140] Test custom constructors for `Sym`, `Term` & `Const` --- test/types.jl | 46 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/test/types.jl b/test/types.jl index 0968b7921..899d598df 100644 --- a/test/types.jl +++ b/test/types.jl @@ -41,7 +41,7 @@ using SymbolicUtils: BasicSymbolic, _Sym, _Term, _Const, _Add @test d1.arguments == [bs1, bs2] num = bs1 den = bs2 - d2 = Div(; num, den) + d2 = Div(; num, den) @test d2.num == bs1 @test d2.den == bs2 @test_throws MethodError Div(num = s1, den = bs2) @@ -59,7 +59,7 @@ using SymbolicUtils: BasicSymbolic, _Sym, _Term, _Const, _Add @test p1.arguments == [bs1, bs2] base = bs1 exp = bs2 - p2 = Pow(; base, exp) + p2 = Pow(; base, exp) @test p2.base == bs1 @test p2.exp == bs2 @test_throws MethodError Pow(base = s1, exp = bs2) @@ -120,6 +120,48 @@ using SymbolicUtils: BasicSymbolic, _Sym, _Term, _Const, _Add end end +@testset "Custom constructors" begin + @testset "Sym" begin + s1 = _Sym(Int64, :x) + s2 = _Sym(Float64, :y) + @test typeof(s1) == BasicSymbolic{Int64} + @test s1.metadata == SymbolicUtils.NO_METADATA + @test s1.hash[] == SymbolicUtils.EMPTY_HASH + @test s1.impl.name == :x + @test typeof(s2) == BasicSymbolic{Float64} + @test s2.metadata == SymbolicUtils.NO_METADATA + @test s2.hash[] == SymbolicUtils.EMPTY_HASH + @test s2.impl.name == :y + end + @testset "Term" begin + s1 = _Sym(Float64, :x) + s2 = _Sym(Float64, :y) + t = _Term(Float64, mod, [s1, s2]) + @test typeof(t) == BasicSymbolic{Float64} + @test t.metadata == SymbolicUtils.NO_METADATA + @test t.hash[] == SymbolicUtils.EMPTY_HASH + @test t.impl.f == mod + @test t.impl.arguments == [s1, s2] + end + @testset "Const" begin + c1 = _Const(1.0) + @test typeof(c1) == BasicSymbolic{Float64} + @test c1.metadata == SymbolicUtils.NO_METADATA + @test c1.hash[] == SymbolicUtils.EMPTY_HASH + @test c1.impl.val == 1.0 + c2 = _Const(big"123456789012345678901234567890") + @test typeof(c2) == BasicSymbolic{BigInt} + @test c2.metadata == SymbolicUtils.NO_METADATA + @test c2.hash[] == SymbolicUtils.EMPTY_HASH + @test c2.impl.val == big"123456789012345678901234567890" + c3 = _Const(big"1.23456789012345678901") + @test typeof(c3) == BasicSymbolic{BigFloat} + @test c3.metadata == SymbolicUtils.NO_METADATA + @test c3.hash[] == SymbolicUtils.EMPTY_HASH + @test c3.impl.val == big"1.23456789012345678901" + end +end + @testset "BasicSymbolic iszero" begin c1 = _Const(0) @test SymbolicUtils._iszero(c1) From e645ba695dca50571c7d761bf3595bf8de0092a7 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 27 Jun 2024 22:35:01 -0700 Subject: [PATCH 030/140] Correct MLStyle `@match` in `exprtype` [skip ci] --- src/types.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/types.jl b/src/types.jl index 8fdfb6916..a21029889 100644 --- a/src/types.jl +++ b/src/types.jl @@ -54,14 +54,14 @@ end function exprtype(x::BasicSymbolic) @match x.impl begin - Term => TERM - Add => ADD - Mul => MUL - Div => DIV - Pow => POW - Sym => SYM - Const => CONST - _ => error_on_type() + Sym(_...) => SYM + Term(_...) => TERM + Add(_...) => ADD + Mul(_...) => MUL + Div(_...) => DIV + Pow(_...) => POW + Const(_...) => CONST + _ => error_on_type() end end From 5c8dd149e1315e6b17dd176b891268e1f57d9a13 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 10:50:32 -0700 Subject: [PATCH 031/140] Correct MLStyle `@match` in `operation` --- src/types.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/types.jl b/src/types.jl index a21029889..31aa3d1a2 100644 --- a/src/types.jl +++ b/src/types.jl @@ -93,14 +93,14 @@ symtype(x::Number) = typeof(x) # We're returning a function pointer @inline function operation(x::BasicSymbolic) @match x.impl begin - Term => x.f - Add => (+) - Mul => (*) - Div => (/) - Pow => (^) - Sym => error_sym() - Const => error_const() - _ => error_on_type() + Term(_...) => x.impl.f + Add(_...) => (+) + Mul(_...) => (*) + Div(_...) => (/) + Pow(_...) => (^) + Sym(_...) => error_sym() + Const(_...) => error_const() + _ => error_on_type() end end From 17721733eb3b2c2c3af6a3a3f8541572b36df62a Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 10:58:06 -0700 Subject: [PATCH 032/140] Correct MLStyle `@match` in `sorted_arguments` --- src/types.jl | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/types.jl b/src/types.jl index 31aa3d1a2..6a15d855e 100644 --- a/src/types.jl +++ b/src/types.jl @@ -108,22 +108,23 @@ end function sorted_arguments(x::BasicSymbolic) args = arguments(x) - @match x.impl begin - Add => @goto ADD - Mul => @goto MUL - _ => return args + impl = x.impl + @match impl begin + Add(_...) => @goto ADD + Mul(_...) => @goto MUL + _ => return args end @label MUL - if !x.issorted[] - sort!(args, by=get_degrees) - x.issorted[] = true + if !impl.issorted[] + sort!(args, by = get_degrees) + impl.issorted[] = true end return args @label ADD - if !x.issorted[] - sort!(args, lt = monomial_lt, by=get_degrees) - x.issorted[] = true + if !impl.issorted[] + sort!(args, lt = monomial_lt, by = get_degrees) + impl.issorted[] = true end return args end From 9e4049afd61b3c25341c8d6b56bef363fadee09a Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 11:29:38 -0700 Subject: [PATCH 033/140] Fix MLStyle `@match` in `arguments` --- src/types.jl | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/types.jl b/src/types.jl index 6a15d855e..af62c18e7 100644 --- a/src/types.jl +++ b/src/types.jl @@ -136,39 +136,39 @@ sorted_children(x::BasicSymbolic) = sorted_arguments(x) @deprecate unsorted_arguments(x) arguments(x) function arguments(x::BasicSymbolic) - @match x.impl begin - Term => return x.arguments - Add => @goto ADDMUL - Mul => @goto ADDMUL - Div => @goto DIVPOW - Pow => @goto DIVPOW - Sym => error_sym() - Const => error_const() - _ => error_on_type() + impl = x.impl + @match impl begin + Term(_...) => return impl.arguments + Add(_...) => @goto ADDMUL + Mul(_...) => @goto ADDMUL + Div(_...) => @goto DIVPOW + Pow(_...) => @goto DIVPOW + Sym(_...) => error_sym() + Const(_...) => error_const() + _ => error_on_type() end @label ADDMUL E = exprtype(x) - args = x.impl.arguments + args = impl.arguments isempty(args) || return args - siz = length(x.impl.dict) - idcoeff = E === ADD ? iszero(x.impl.coeff) : isone(x.impl.coeff) + siz = length(impl.dict) + idcoeff = E === ADD ? iszero(impl.coeff) : isone(impl.coeff) sizehint!(args, idcoeff ? siz : siz + 1) - idcoeff || push!(args, x.impl.coeff) + idcoeff || push!(args, impl.coeff) if isadd(x) - for (k, v) in x.impl.dict - push!(args, applicable(*,k,v) ? k*v : - maketerm(k, *, [k, v])) + for (k, v) in impl.dict + push!(args, applicable(*, k, v) ? k * v : maketerm(k, *, [k, v])) end else # MUL - for (k, v) in x.impl.dict + for (k, v) in impl.dict push!(args, unstable_pow(k, v)) end end return args @label DIVPOW - args = x.impl.arguments + args = impl.arguments return args end From dda1206e51b7d394670fc8058ecffd602529b3ea Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 11:51:50 -0700 Subject: [PATCH 034/140] Rewrite `is...` functions --- src/types.jl | 55 +++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 7 deletions(-) diff --git a/src/types.jl b/src/types.jl index af62c18e7..84ac75e36 100644 --- a/src/types.jl +++ b/src/types.jl @@ -175,13 +175,54 @@ end isexpr(s::BasicSymbolic) = !issym(s) iscall(s::BasicSymbolic) = isexpr(s) -@inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false -issym(x::BasicSymbolic) = isa_SymType(Val(:Sym), x) -isterm(x) = isa_SymType(Val(:Term), x) -ismul(x) = isa_SymType(Val(:Mul), x) -isadd(x) = isa_SymType(Val(:Add), x) -ispow(x) = isa_SymType(Val(:Pow), x) -isdiv(x) = isa_SymType(Val(:Div), x) +function issym(x) + isa(x, BasicSymbolic) && @match x.impl begin + Sym(_...) => true + _ => false + end +end + +function isterm(x) + isa(x, BasicSymbolic) && @match x.impl begin + Term(_...) => true + _ => false + end +end + +function isadd(x) + isa(x, BasicSymbolic) && @match x.impl begin + Add(_...) => true + _ => false + end +end + +function ismul(x) + isa(x, BasicSymbolic) && @match x.impl begin + Mul(_...) => true + _ => false + end +end + +function ispow(x) + isa(x, BasicSymbolic) && @match x.impl begin + Pow(_...) => true + _ => false + end +end + +function isdiv(x) + isa(x, BasicSymbolic) && @match x.impl begin + Div(_...) => true + _ => false + end +end + +function isconst(x) + isa(x, BasicSymbolic) && @match x.impl begin + Const(_...) => true + _ => false + end +end ### ### Base interface From 385ad06799b8d70937a0b2b116eb858f5e70adcf Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 11:53:43 -0700 Subject: [PATCH 035/140] Adapt `nameof` to new struct structure --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 84ac75e36..ce85b1c01 100644 --- a/src/types.jl +++ b/src/types.jl @@ -273,7 +273,7 @@ end Base.one( s::Symbolic) = one( symtype(s)) Base.zero(s::Symbolic) = zero(symtype(s)) -Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymbolic doesn't have a name") +Base.nameof(s::BasicSymbolic) = issym(s) ? s.impl.name : error("None Sym BasicSymbolic doesn't have a name") ## This is much faster than hash of an array of Any hashvec(xs, z) = foldr(hash, xs, init=z) From b2e3a0db5bc2ed0468cc473500313fb4849a782a Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 12:11:51 -0700 Subject: [PATCH 036/140] Modify `Div` custom constructors --- src/types.jl | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/src/types.jl b/src/types.jl index ce85b1c01..1be6ffd90 100644 --- a/src/types.jl +++ b/src/types.jl @@ -404,21 +404,19 @@ function maybe_intcoeff(x) end end -function Div{T}(n, d, simplified=false; metadata=nothing) where {T} - if T<:Number && !(T<:SafeReal) - n, d = quick_cancel(n, d) - end - _iszero(n) && return zero(typeof(n)) - _isone(d) && return n - - if isdiv(n) && isdiv(d) - return Div{T}(n.num * d.den, n.den * d.num) - elseif isdiv(n) - return Div{T}(n.num, n.den * d) +function _Div(::Type{T}, num, den; kwargs...) where {T} + if T <: Number && !(T <: SafeReal) + num, den = quick_cancel(num, den) + end + _iszero(num) && return zero(typeof(num)) + _isone(den) && return den + if isdiv(num) && isdiv(den) + return _Div(T, num.impl.num * den.impl.den, num.impl.den * den.impl.num) + elseif isdiv(num) + return _Div(T, num.impl.num, num.impl.den * den) elseif isdiv(d) - return Div{T}(n * d.den, d.num) + return _Div(T, num * den.impl.den, den.impl.num) end - d isa Number && _isone(-d) && return -1 * n n isa Rat && d isa Rat && return n // d # maybe called by oblivious code in simplify @@ -438,11 +436,11 @@ function Div{T}(n, d, simplified=false; metadata=nothing) where {T} end end - Div{T}(; num=n, den=d, simplified, arguments=[], metadata) + impl = Div(; num, den) + BasicSymbolic{T}(; impl, kwargs...) end - -function Div(n,d, simplified=false; kw...) - Div{promote_symtype((/), symtype(n), symtype(d))}(n, d, simplified; kw...) +function _Div(num, den; kwargs...) + Div{promote_symtype((/), symtype(num), symtype(den))}(num, den; kwargs...) end @inline function numerators(x) From 911789aa8ab092168f7e35e36dcfa55c6dca11b0 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 12:15:18 -0700 Subject: [PATCH 037/140] Adapt `numerators` & `denominators` to new struct structure --- src/types.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/types.jl b/src/types.jl index 1be6ffd90..713e617ee 100644 --- a/src/types.jl +++ b/src/types.jl @@ -444,11 +444,11 @@ function _Div(num, den; kwargs...) end @inline function numerators(x) - isdiv(x) && return numerators(x.num) + isdiv(x) && return numerators(x.impl.num) iscall(x) && operation(x) === (*) ? arguments(x) : Any[x] end -@inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1] +@inline denominators(x) = isdiv(x) ? numerators(x.impl.den) : Any[1] function (::Type{<:Pow{T}})(a, b; metadata=NO_METADATA) where {T} _iszero(b) && return 1 From 776942b9c73d5977e7138c998ed03d01aa17df2b Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 12:18:41 -0700 Subject: [PATCH 038/140] Rewrite `Pow` custom constructors --- src/types.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/types.jl b/src/types.jl index 713e617ee..fa065b65c 100644 --- a/src/types.jl +++ b/src/types.jl @@ -450,14 +450,14 @@ end @inline denominators(x) = isdiv(x) ? numerators(x.impl.den) : Any[1] -function (::Type{<:Pow{T}})(a, b; metadata=NO_METADATA) where {T} - _iszero(b) && return 1 - _isone(b) && return a - Pow{T}(; base=a, exp=b, arguments=[], metadata) +function _Pow(::Type{T}, base, exp; kwargs...) where {T} + _iszero(exp) && return 1 + _isone(exp) && return a + impl = (; base, exp) + BasicSymbolic{T}(; impl, kwargs...) end - -function Pow(a, b; metadata=NO_METADATA) - Pow{promote_symtype(^, symtype(a), symtype(b))}(makepow(a, b)..., metadata=metadata) +function _Pow(base, exp; kwargs...) + Pow{promote_symtype(^, symtype(base), symtype(b))}(makepow(base, exp)..., kwargs...) end function toterm(t::BasicSymbolic{T}) where T From fdd345e6b0efbc601f616e82529d1f53238839a6 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 12:22:59 -0700 Subject: [PATCH 039/140] Rewrite `toterm` --- src/types.jl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/types.jl b/src/types.jl index fa065b65c..09e2354f5 100644 --- a/src/types.jl +++ b/src/types.jl @@ -460,21 +460,22 @@ function _Pow(base, exp; kwargs...) Pow{promote_symtype(^, symtype(base), symtype(b))}(makepow(base, exp)..., kwargs...) end -function toterm(t::BasicSymbolic{T}) where T +function toterm(t::BasicSymbolic{T}) where {T} E = exprtype(t) if E === SYM || E === TERM return t elseif E === ADD || E === MUL - args = Any[] - push!(args, t.coeff) - for (k, coeff) in t.dict - push!(args, coeff == 1 ? k : Term{T}(E === MUL ? (^) : (*), Any[coeff, k])) + args = BasicSymbolic[] + push!(args, t.impl.coeff) + for (k, coeff) in t.impl.dict + push!( + args, coeff == 1 ? k : _Term(T, E === MUL ? (^) : (*), [_Const(coeff), k])) end - Term{T}(operation(t), args) + _Term(T, operation(t), args) elseif E === DIV - Term{T}(/, Any[t.num, t.den]) + _Term(T, /, [t.impl.num, t.impl.den]) elseif E === POW - Term{T}(^, [t.base, t.exp]) + _Term(T, ^, [t.impl.base, t.impl.exp]) else error_on_type() end From 21161cc882d7aa59039341e1ac5bd811976cea72 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 12:29:21 -0700 Subject: [PATCH 040/140] Rewrite `makeadd`, `makemul`, `makepow` --- src/types.jl | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/types.jl b/src/types.jl index 09e2354f5..c37fe7642 100644 --- a/src/types.jl +++ b/src/types.jl @@ -481,18 +481,18 @@ function toterm(t::BasicSymbolic{T}) where {T} end end -""" - makeadd(sign, coeff::Number, xs...) +"""" +$(SIGNATURES) Any Muls inside an Add should always have a coeff of 1 and the key (in Add) should instead be used to store the actual coefficient """ function makeadd(sign, coeff, xs...) - d = sdict() + d = Dict{BasicSymbolic, Any}() for x in xs if isadd(x) - coeff += x.coeff - _merge!(+, d, x.dict, filter=_iszero) + coeff += x.impl.coeff + _merge!(+, d, x.impl.dict, filter = _iszero) continue end if x isa Number @@ -500,8 +500,8 @@ function makeadd(sign, coeff, xs...) continue end if ismul(x) - k = Mul(symtype(x), 1, x.dict) - v = sign * x.coeff + get(d, k, 0) + k = _Mul(symtype(x), 1, x.dict) + v = sign * x.impl.coeff + get(d, k, 0) else k = x v = sign + get(d, x, 0) @@ -515,15 +515,15 @@ function makeadd(sign, coeff, xs...) coeff, d end -function makemul(coeff, xs...; d=sdict()) +function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}()) for x in xs - if ispow(x) && x.exp isa Number - d[x.base] = x.exp + get(d, x.base, 0) + if ispow(x) && x.impl.exp isa Number + d[x.impl.base] = x.impl.exp + get(d, x.impl.base, 0) elseif x isa Number coeff *= x elseif ismul(x) - coeff *= x.coeff - _merge!(+, d, x.dict, filter=_iszero) + coeff *= x.impl.coeff + _merge!(+, d, x.impl.dict, filter = _iszero) else v = 1 + get(d, x, 0) if _iszero(v) @@ -533,19 +533,19 @@ function makemul(coeff, xs...; d=sdict()) end end end - (coeff, d) + coeff, d end -unstable_pow(a, b) = a isa Integer && b isa Integer ? (a//1) ^ b : a ^ b +unstable_pow(a, b) = a isa Integer && b isa Integer ? (a // 1)^b : a^b function makepow(a, b) base = a exp = b if ispow(a) - base = a.base - exp = a.exp * b + base = a.impl.base + exp = a.impl.exp * b end - return (base, exp) + base, exp end function term(f, args...; type = nothing) From 7e04640cca06623690fc0fc5e4a0d2155596aa47 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 12:31:24 -0700 Subject: [PATCH 041/140] Rewrite `term` --- src/types.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/types.jl b/src/types.jl index c37fe7642..f43bc2ec4 100644 --- a/src/types.jl +++ b/src/types.jl @@ -548,13 +548,11 @@ function makepow(a, b) base, exp end -function term(f, args...; type = nothing) - if type === nothing +function term(f, args...; T = nothing) + if T === nothing T = _promote_symtype(f, args) - else - T = type end - Term{T}(f, Any[args...]) + _Term(T, f, [args...]) end """ From 1803d3e406d96ff9abea2b0a1fc88cf922151c4e Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 12:33:25 -0700 Subject: [PATCH 042/140] Modify `unflatten` --- src/types.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/types.jl b/src/types.jl index f43bc2ec4..633674312 100644 --- a/src/types.jl +++ b/src/types.jl @@ -556,20 +556,20 @@ function term(f, args...; T = nothing) end """ - unflatten(t::Symbolic{T}) +$(TYPEDSIGNATURES) + Binarizes `Term`s with n-ary operations """ -function unflatten(t::Symbolic{T}) where{T} +function unflatten(t::Symbolic{T}) where {T} if iscall(t) f = operation(t) if f == (+) || f == (*) # TODO check out for other n-ary --> binary ops a = arguments(t) - return foldl((x,y) -> Term{T}(f, Any[x, y]), a) + return foldl((x, y) -> _Term(T, f, [x, y]), a) end end return t end - unflatten(t) = t function TermInterface.maketerm(::Type{<:BasicSymbolic}, head, args, type, metadata) From 209a64d854470fba385f2d3b445524ec6caa8658 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 12:35:29 -0700 Subject: [PATCH 043/140] Modify `basicsymbolic` --- src/types.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/types.jl b/src/types.jl index 633674312..ca549ee6e 100644 --- a/src/types.jl +++ b/src/types.jl @@ -586,15 +586,16 @@ function basicsymbolic(f, args, stype, metadata) T = _promote_symtype(f, args) end if T <: LiteralReal - Term{T}(f, args, metadata=metadata) - elseif T <: Number && (f in (+, *) || (f in (/, ^) && length(args) == 2)) && all(x->symtype(x) <: Number, args) + _Term(T, f, args; metadata) + elseif T <: Number && (f in (+, *) || (f in (/, ^) && length(args) == 2)) && + all(x -> symtype(x) <: Number, args) res = f(args...) if res isa Symbolic @set! res.metadata = metadata end return res else - Term{T}(f, args, metadata=metadata) + _Term(T, f, args; metadata) end end From bb42582ac7743360715f3ec8b6dd26590ecd301d Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 12:37:25 -0700 Subject: [PATCH 044/140] Modify `setargs` --- src/types.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/types.jl b/src/types.jl index ca549ee6e..cb393b143 100644 --- a/src/types.jl +++ b/src/types.jl @@ -576,7 +576,6 @@ function TermInterface.maketerm(::Type{<:BasicSymbolic}, head, args, type, metad basicsymbolic(head, args, type, metadata) end - function basicsymbolic(f, args, stype, metadata) if f isa Symbol error("$f must not be a Symbol") @@ -710,7 +709,7 @@ function isnegative(t) end # Term{} -setargs(t, args) = Term{symtype(t)}(operation(t), args) +setargs(t, args) = _Term(symtype(t), operation(t), args) cdrargs(args) = setargs(t, cdr(args)) print_arg(io, x::Union{Complex, Rational}; paren=true) = print(io, "(", x, ")") From 4e459bbbd8db1e60d3d44f296ec62b0c36f0c2e0 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 12:39:53 -0700 Subject: [PATCH 045/140] Modify `(f::Symbolic{<:FnType})(args...)` --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index cb393b143..7b181f56d 100644 --- a/src/types.jl +++ b/src/types.jl @@ -912,7 +912,7 @@ promote_symtype(f, Ts...) = Any struct FnType{X<:Tuple,Y} end -(f::Symbolic{<:FnType})(args...) = Term{promote_symtype(f, symtype.(args)...)}(f, Any[args...]) +(f::Symbolic{<:FnType})(args...) = _Term(promote_symtype(f, symtype.(args)...), f, [args...]) function (f::Symbolic)(args...) error("Sym $f is not callable. " * From 5a60f5a125ff9f02b389a5656918297de88600c5 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 12:44:15 -0700 Subject: [PATCH 046/140] Modify `@syms` --- src/types.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/types.jl b/src/types.jl index 7b181f56d..74dbfad3c 100644 --- a/src/types.jl +++ b/src/types.jl @@ -996,12 +996,10 @@ macro syms(xs...) defs = map(xs) do x n, t = _name_type(x) T = esc(t) - nt = _name_type(x) - n, t = nt.name, nt.type - :($(esc(n)) = Sym{$T}($(Expr(:quote, n)))) + :($(esc(n)) = _Sym($T, $(Expr(:quote, n)))) end Expr(:block, defs..., - :(tuple($(map(x->esc(_name_type(x).name), xs)...)))) + :(tuple($(map(x -> esc(_name_type(x).name), xs)...)))) end function syms_syntax_error() From 3d487a04aaad3d247b33b650bc84d5ba7e725118 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 12:54:16 -0700 Subject: [PATCH 047/140] Modify `+`, `-`, `*`, `/`, `\`, `^` [skip ci] --- src/types.jl | 104 ++++++++++++++++++++------------------------------- 1 file changed, 41 insertions(+), 63 deletions(-) diff --git a/src/types.jl b/src/types.jl index 74dbfad3c..cd35d4652 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1076,87 +1076,78 @@ sub_t(a) = promote_symtype(-, symtype(a)) import Base: (+), (-), (*), (//), (/), (\), (^) function +(a::SN, b::SN) - !issafecanon(+, a,b) && return term(+, a, b) # Don't flatten if args have metadata + !issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata if isadd(a) && isadd(b) - return Add(add_t(a,b), - a.coeff + b.coeff, - _merge(+, a.dict, b.dict, filter=_iszero)) + return _Add( + add_t(a, b), a.coeff + b.coeff, _merge(+, a.dict, b.dict, filter = _iszero)) elseif isadd(a) coeff, dict = makeadd(1, 0, b) - return Add(add_t(a,b), a.coeff + coeff, _merge(+, a.dict, dict, filter=_iszero)) + return _Add(add_t(a, b), a.coeff + coeff, _merge(+, a.dict, dict, filter = _iszero)) elseif isadd(b) return b + a end coeff, dict = makeadd(1, 0, a, b) - Add(add_t(a,b), coeff, dict) + _Add(add_t(a, b), coeff, dict) end - function +(a::Number, b::SN) !issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata iszero(a) && return b if isadd(b) - Add(add_t(a,b), a + b.coeff, b.dict) + _Add(add_t(a, b), a + b.coeff, b.dict) else - Add(add_t(a,b), makeadd(1, a, b)...) + _Add(add_t(a, b), makeadd(1, a, b)...) end end - +(a::SN, b::Number) = b + a - +(a::SN) = a function -(a::SN) !issafecanon(*, a) && return term(-, a) - isadd(a) ? Add(sub_t(a), -a.coeff, mapvalues((_,v) -> -v, a.dict)) : - Add(sub_t(a), makeadd(-1, 0, a)...) + isadd(a) ? _Add(sub_t(a), -a.coeff, mapvalues((_, v) -> -v, a.dict)) : + _Add(sub_t(a), makeadd(-1, 0, a)...) end - function -(a::SN, b::SN) (!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b) - isadd(a) && isadd(b) ? Add(sub_t(a,b), - a.coeff - b.coeff, - _merge(-, a.dict, - b.dict, - filter=_iszero)) : a + (-b) + if isadd(a) && isadd(b) + _Add(sub_t(a, b), a.coeff - b.coeff, _merge(-, a.dict, b.dict, filter = _iszero)) + else + a + (-b) + end end - -(a::Number, b::SN) = a + (-b) -(a::SN, b::Number) = a + (-b) - -mul_t(a,b) = promote_symtype(*, symtype(a), symtype(b)) +mul_t(a, b) = promote_symtype(*, symtype(a), symtype(b)) mul_t(a) = promote_symtype(*, symtype(a)) -*(a::SN) = a - function *(a::SN, b::SN) # Always make sure Div wraps Mul !issafecanon(*, a, b) && return term(*, a, b) if isdiv(a) && isdiv(b) - Div(a.num * b.num, a.den * b.den) + _Div(a.impl.num * b.impl.num, a.impl.den * b.impl.den) elseif isdiv(a) - Div(a.num * b, a.den) + _Div(a.impl.num * b, a.impl.den) elseif isdiv(b) - Div(a * b.num, b.den) + _Div(a * b.impl.num, b.impl.den) elseif ismul(a) && ismul(b) - Mul(mul_t(a, b), - a.coeff * b.coeff, - _merge(+, a.dict, b.dict, filter=_iszero)) + _Mul(mul_t(a, b), a.impl.coeff * b.impl.coeff, + _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) elseif ismul(a) && ispow(b) if b.exp isa Number - Mul(mul_t(a, b), - a.coeff, _merge(+, a.dict, Base.ImmutableDict(b.base=>b.exp), filter=_iszero)) + _Mul(mul_t(a, b), + a.impl.coeff, + _merge(+, a.impl.dict, Base.ImmutableDict(b.impl.base => b.impl.exp), + filter = _iszero)) else - Mul(mul_t(a, b), - a.coeff, _merge(+, a.dict, Base.ImmutableDict(b=>1), filter=_iszero)) + _Mul(mul_t(a, b), a.impl.coeff, + _merge(+, a.impl.dict, Base.ImmutableDict(b => 1), filter = _iszero)) end elseif ispow(a) && ismul(b) b * a else - Mul(mul_t(a,b), makemul(1, a, b)...) + _Mul(mul_t(a, b), makemul(1, a, b)...) end end - function *(a::Number, b::SN) !issafecanon(*, b) && return term(*, a, b) if iszero(a) @@ -1164,53 +1155,40 @@ function *(a::Number, b::SN) elseif isone(a) b elseif isdiv(b) - Div(a*b.num, b.den) + Div(a * b.impl.num, b.impl.den) elseif isone(-a) && isadd(b) # -1(a+b) -> -a - b T = promote_symtype(+, typeof(a), symtype(b)) - Add(T, b.coeff * a, Dict{Any,Any}(k=>v*a for (k, v) in b.dict)) + _Add(T, b.impl.coeff * a, + Dict{BasicSymbolic, Any}(k => v * a for (k, v) in b.impl.dict)) else - Mul(mul_t(a, b), makemul(a, b)...) + _Mul(mul_t(a, b), makemul(a, b)...) end end - -### -### Div -### - -/(a::Union{SN,Number}, b::SN) = Div(a, b) - *(a::SN, b::Number) = b * a +*(a::SN) = a -\(a::SN, b::Union{Number, SN}) = b / a - -\(a::Number, b::SN) = b / a - -/(a::SN, b::Number) = (isone(abs(b)) ? b : (b isa Integer ? 1//b : inv(b))) * a +/(a::Union{SN, Number}, b::SN) = _Div(a, b) +/(a::SN, b::Number) = (isone(abs(b)) ? b : (b isa Integer ? 1 // b : inv(b))) * a //(a::Union{SN, Number}, b::SN) = a / b - //(a::SN, b::T) where {T <: Number} = (one(T) // b) * a - -### -### Pow -### +\(a::SN, b::Union{Number, SN}) = b / a +\(a::Number, b::SN) = b / a function ^(a::SN, b) - !issafecanon(^, a,b) && return Pow(a, b) + !issafecanon(^, a, b) && return Pow(a, b) if b isa Number && iszero(b) - # fast path 1 elseif b isa Number && b < 0 - Div(1, a ^ (-b)) + _Div(1, a^(-b)) elseif ismul(a) && b isa Number coeff = unstable_pow(a.coeff, b) - Mul(promote_symtype(^, symtype(a), symtype(b)), - coeff, mapvalues((k, v) -> b*v, a.dict)) + _Mul(promote_symtype(^, symtype(a), symtype(b)), + coeff, mapvalues((k, v) -> b * v, a.dict)) else Pow(a, b) end end - -^(a::Number, b::SN) = Pow(a, b) +^(a::Number, b::SN) = _Pow(a, b) From 0b0b2ba4de8b1c29b431435e630c1644dd0a8b47 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 13:08:56 -0700 Subject: [PATCH 048/140] Modify `show` --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index cd35d4652..a4c2ce8de 100644 --- a/src/types.jl +++ b/src/types.jl @@ -871,7 +871,7 @@ showraw(t) = showraw(stdout, t) function Base.show(io::IO, v::BasicSymbolic) if issym(v) - Base.show_unquoted(io, v.name) + Base.show_unquoted(io, v.impl.name) else show_term(io, v) end From 3cea20f88109a5c9ef49704024db4c939335cd21 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 28 Jun 2024 15:07:41 -0700 Subject: [PATCH 049/140] Fix calling `term` [skip ci] --- src/methods.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/methods.jl b/src/methods.jl index aa760a3bf..e78bac63f 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -177,9 +177,9 @@ for (f, Domain) in [(==) => Number, (!=) => Number, xor => Bool] @eval begin promote_symtype(::$(typeof(f)), ::Type{<:$Domain}, ::Type{<:$Domain}) = Bool - (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::$Domain) = term($f, a, b, type=Bool) - (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::Symbolic{<:$Domain}) = term($f, a, b, type=Bool) - (::$(typeof(f)))(a::$Domain, b::Symbolic{<:$Domain}) = term($f, a, b, type=Bool) + (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::$Domain) = term($f, a, b; T = Bool) + (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::Symbolic{<:$Domain}) = term($f, a, b; T = Bool) + (::$(typeof(f)))(a::$Domain, b::Symbolic{<:$Domain}) = term($f, a, b; T = Bool) end end From 07ffb87e0b15ad2881b745724ff3fa041ce46abd Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 1 Jul 2024 15:20:45 -0700 Subject: [PATCH 050/140] Fix `isone` & `iszero` for `BasicSymbolic` --- src/types.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index a4c2ce8de..ef2c00504 100644 --- a/src/types.jl +++ b/src/types.jl @@ -153,7 +153,7 @@ function arguments(x::BasicSymbolic) args = impl.arguments isempty(args) || return args siz = length(impl.dict) - idcoeff = E === ADD ? iszero(impl.coeff) : isone(impl.coeff) + idcoeff = E === ADD ? _iszero(impl.coeff) : _isone(impl.coeff) sizehint!(args, idcoeff ? siz : siz + 1) idcoeff || push!(args, impl.coeff) if isadd(x) @@ -377,6 +377,13 @@ function _iszero(x::BasicSymbolic) end end +function _isone(x::BasicSymbolic) + @match x.impl begin + Const(_...) => isone(x.impl.val) + _ => false + end +end + const Rat = Union{Rational, Integer} function ratcoeff(x) From 014c09718a0de370c0803fdfcd67ea3c0212c880 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 1 Jul 2024 15:24:33 -0700 Subject: [PATCH 051/140] Add `convert` methods for `BasicSymbolic` --- src/types.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/types.jl b/src/types.jl index ef2c00504..e8a863d60 100644 --- a/src/types.jl +++ b/src/types.jl @@ -340,6 +340,13 @@ function _Const(val::T; kwargs...) where {T} BasicSymbolic{T}(; impl, kwargs...) end +function Base.convert(::Type{BasicSymbolic}, x) + _Const(x) +end +function Base.convert(::Type{BasicSymbolic}, x::BasicSymbolic) + x +end + function _Add(::Type{T}, coeff, dict; kwargs...) where {T} if isempty(dict) return coeff From c556359a14d1ce1015917d1b961b9e5cdd912774 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 1 Jul 2024 15:28:24 -0700 Subject: [PATCH 052/140] Fix field access in `+` --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index e8a863d60..d69427eb8 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1107,7 +1107,7 @@ function +(a::Number, b::SN) !issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata iszero(a) && return b if isadd(b) - _Add(add_t(a, b), a + b.coeff, b.dict) + _Add(add_t(a, b), a + b.impl.coeff, b.impl.dict) else _Add(add_t(a, b), makeadd(1, a, b)...) end From 621780986b9b485cc8855a3ce94ceffbb1a65c2a Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 1 Jul 2024 15:36:04 -0700 Subject: [PATCH 053/140] Fix `isexpr` for `Const` --- src/types.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index d69427eb8..1f52a5ac2 100644 --- a/src/types.jl +++ b/src/types.jl @@ -172,7 +172,14 @@ function arguments(x::BasicSymbolic) return args end -isexpr(s::BasicSymbolic) = !issym(s) +function isexpr(x::BasicSymbolic) + @match x.impl begin + Sym(_...) => false + Const(_...) => false + _ => false + end +end + iscall(s::BasicSymbolic) = isexpr(s) function issym(x) From 4646d122b0979a631000ff68605dbd597446e9ca Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 1 Jul 2024 15:36:36 -0700 Subject: [PATCH 054/140] Fix `Base.show` for `Const` --- src/types.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/types.jl b/src/types.jl index 1f52a5ac2..d7970b8e7 100644 --- a/src/types.jl +++ b/src/types.jl @@ -891,10 +891,10 @@ showraw(io, t) = Base.show(IOContext(io, :simplify=>false), t) showraw(t) = showraw(stdout, t) function Base.show(io::IO, v::BasicSymbolic) - if issym(v) - Base.show_unquoted(io, v.impl.name) - else - show_term(io, v) + @match v.impl begin + Sym(_...) => Base.show_unquoted(io, v.impl.name) + Const(_...) => print(io, v.impl.val) + _ => show_term(io, v) end end From 98ec69268150ff6854881a9e57fc4a5680d57d03 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 1 Jul 2024 15:53:33 -0700 Subject: [PATCH 055/140] Fix `issafecanon` for `Const` and add docstring --- src/types.jl | 39 +++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/src/types.jl b/src/types.jl index d7970b8e7..bb5861b20 100644 --- a/src/types.jl +++ b/src/types.jl @@ -629,19 +629,38 @@ function hasmetadata(s::Symbolic, ctx) metadata(s) isa AbstractDict && haskey(metadata(s), ctx) end -issafecanon(f, s) = true -function issafecanon(f, s::Symbolic) - if isnothing(metadata(s)) || issym(s) - return true - else - _issafecanon(f, s) +""" +$(TYPEDSIGNATURES) + +Check if the symbolic expression(s) is/are safe to canonicalize with respect to the function `f`. + +This function determines if applying the canonicalization rules associated with function `f` +to the symbolic expression `s` is safe and won't lead to incorrect simplifications. It handles various cases +depending on the type of `s` and the function `f`. + +For multiple arguments, `issafecanon(f, ss...)`, it checks if canonicalization is safe for all expressions in `ss`. + +# Arguments +- `f`: The function for which canonicalization safety is being checked. +- `s`: The symbolic expression to check. +- `ss...`: A variable number of symbolic expressions to check. + +# Returns +- `true` if canonicalization is safe, `false` otherwise. +""" +function issafecanon(f, s::BasicSymbolic) + isnothing(metadata(s)) || @match s.impl begin + Sym(_...) => true + Const(_...) => true + _ => _issafecanon(f, s) end end -_issafecanon(::typeof(*), s) = !iscall(s) || !(operation(s) in (+,*,^)) -_issafecanon(::typeof(+), s) = !iscall(s) || !(operation(s) in (+,*)) -_issafecanon(::typeof(^), s) = !iscall(s) || !(operation(s) in (*, ^)) +issafecanon(f, s) = true +issafecanon(f, ss...) = all(x -> issafecanon(f, x), ss) -issafecanon(f, ss...) = all(x->issafecanon(f, x), ss) +_issafecanon(::typeof(*), s) = !iscall(s) || !(operation(s) in (+, *, ^)) +_issafecanon(::typeof(+), s) = !iscall(s) || !(operation(s) in (+, *)) +_issafecanon(::typeof(^), s) = !iscall(s) || !(operation(s) in (*, ^)) function getmetadata(s::Symbolic, ctx) md = metadata(s) From ed3fcc3ed11344615f8981e4e9226cd00a1cc995 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 1 Jul 2024 15:58:34 -0700 Subject: [PATCH 056/140] Modify `+` for `Const` [skip ci] --- src/types.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/types.jl b/src/types.jl index bb5861b20..2b0803df0 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1130,6 +1130,9 @@ function +(a::SN, b::SN) _Add(add_t(a, b), coeff, dict) end function +(a::Number, b::SN) + if isconst(b) + return _Const(a + b.impl.val) + end !issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata iszero(a) && return b if isadd(b) From f519322f6f9c457be7505bd718a75a3a4fda5f6c Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 3 Aug 2024 22:45:15 -0700 Subject: [PATCH 057/140] Remove unreachable error case in type matching --- src/types.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/types.jl b/src/types.jl index d8a00350b..4fadee46d 100644 --- a/src/types.jl +++ b/src/types.jl @@ -61,7 +61,6 @@ function exprtype(x::BasicSymbolic) Div(_...) => DIV Pow(_...) => POW Const(_...) => CONST - _ => error_on_type() end end @@ -111,7 +110,6 @@ symtype(x) = typeof(x) Pow(_...) => (^) Sym(_...) => error_sym() Const(_...) => error_const() - _ => error_on_type() end end @@ -154,7 +152,6 @@ function TermInterface.arguments(x::BasicSymbolic) Pow(_...) => @goto DIVPOW Sym(_...) => error_sym() Const(_...) => error_const() - _ => error_on_type() end @label ADDMUL From e77e47089f59cb53cdd5fbb3e81445e34ece7ef1 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 3 Aug 2024 23:46:21 -0700 Subject: [PATCH 058/140] Fix function argument type --- src/utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 812e229fb..438a6ad7c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -179,10 +179,10 @@ Base.length(l::LL) = length(l.v)-l.i+1 @inline car(l::LL) = l.v[l.i] @inline cdr(l::LL) = isempty(l) ? empty(l) : LL(l.v, l.i+1) -Base.length(t::Term) = length(arguments(t)) + 1 # PIRACY -Base.isempty(t::Term) = false -@inline car(t::Term) = operation(t) -@inline cdr(t::Term) = arguments(t) +Base.length(t::BasicSymbolic) = length(arguments(t)) + 1 # PIRACY +Base.isempty(t::BasicSymbolic) = false +@inline car(t::BasicSymbolic) = operation(t) +@inline cdr(t::BasicSymbolic) = arguments(t) @inline car(v) = iscall(v) ? operation(v) : first(v) @inline function cdr(v) From 8840e23b41354627346b14ace4a5f1a33dec133b Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 3 Aug 2024 23:54:14 -0700 Subject: [PATCH 059/140] Update argument name in `makepattern` `Expr` [skip ci] --- src/rule.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rule.jl b/src/rule.jl index e1531bfe2..cf78ed948 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -67,10 +67,10 @@ function makepattern(expr, keys) makeslot(expr.args[2], keys) end else - :(term($(map(x->makepattern(x, keys), expr.args)...); type=Any)) + :(term($(map(x -> makepattern(x, keys), expr.args)...); T = Any)) end elseif expr.head === :ref - :(term(getindex, $(map(x->makepattern(x, keys), expr.args)...); type=Any)) + :(term(getindex, $(map(x -> makepattern(x, keys), expr.args)...); T = Any)) elseif expr.head === :$ return esc(expr.args[1]) else From 145c3033ab9c48220970d3f8f03ca750efb1e346 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 4 Aug 2024 10:23:18 -0700 Subject: [PATCH 060/140] Replace deprecated `Term` constructor --- src/methods.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/methods.jl b/src/methods.jl index e78bac63f..b0b222677 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -186,7 +186,7 @@ end for f in [!, ~] @eval begin promote_symtype(::$(typeof(f)), ::Type{<:Bool}) = Bool - (::$(typeof(f)))(s::Symbolic{Bool}) = Term{Bool}(!, [s]) + (::$(typeof(f)))(s::Symbolic{Bool}) = _Term(Bool, !, [s]) end end From 1c2371042b9866f58300b39fd0e071e0117e391b Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 4 Aug 2024 10:27:16 -0700 Subject: [PATCH 061/140] Remove unnecessary tests for `Term` constructor --- test/types.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/types.jl b/test/types.jl index 899d598df..551fd8d86 100644 --- a/test/types.jl +++ b/test/types.jl @@ -23,9 +23,6 @@ using SymbolicUtils: BasicSymbolic, _Sym, _Term, _Const, _Add @test t1.f == sin @test t1.arguments == [bs1] @test typeof(t1.arguments) == Vector{BasicSymbolic} - @test_throws MethodError Term(sin, [s1]) - @test_throws MethodError Term(sin, [1]) - @test_throws MethodError Term(sin, [2.0]) end @testset "Div" begin d1 = Div(num = bs1, den = bs2) From 4547bde348bd85f6fdd26c35070ba45b0f52de8d Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 4 Aug 2024 10:29:28 -0700 Subject: [PATCH 062/140] Remove unnecessary tests for `Pow` constructor --- test/types.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/types.jl b/test/types.jl index 551fd8d86..c34a3ff50 100644 --- a/test/types.jl +++ b/test/types.jl @@ -59,9 +59,6 @@ using SymbolicUtils: BasicSymbolic, _Sym, _Term, _Const, _Add p2 = Pow(; base, exp) @test p2.base == bs1 @test p2.exp == bs2 - @test_throws MethodError Pow(base = s1, exp = bs2) - @test_throws MethodError Pow(base = bs1, exp = s2) - @test_throws MethodError Pow(base = s1, exp = s2) end c1 = Const(1) bc1 = BasicSymbolic{Int}(impl = c1) From 99549b80c10f7191e1c83210cbe19ee6212d3f40 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 4 Aug 2024 10:30:14 -0700 Subject: [PATCH 063/140] Remove unnecessary tests for `Div` constructor [skip ci] --- test/types.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/types.jl b/test/types.jl index c34a3ff50..1c7f3cd03 100644 --- a/test/types.jl +++ b/test/types.jl @@ -41,9 +41,6 @@ using SymbolicUtils: BasicSymbolic, _Sym, _Term, _Const, _Add d2 = Div(; num, den) @test d2.num == bs1 @test d2.den == bs2 - @test_throws MethodError Div(num = s1, den = bs2) - @test_throws MethodError Div(num = bs1, den = s2) - @test_throws MethodError Div(num = s1, den = s2) end @testset "Pow" begin p1 = Pow(base = bs1, exp = bs2) From df99cff2b9dd426970dd0788010443ee342c2fde Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 5 Aug 2024 15:24:00 -0700 Subject: [PATCH 064/140] Change `==` to `isequal` for testing --- test/types.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/types.jl b/test/types.jl index 1c7f3cd03..0c7911274 100644 --- a/test/types.jl +++ b/test/types.jl @@ -21,7 +21,7 @@ using SymbolicUtils: BasicSymbolic, _Sym, _Term, _Const, _Add t1 = Term(sin, [bs1]) @test typeof(t1) == BasicSymbolicImpl @test t1.f == sin - @test t1.arguments == [bs1] + @test isequal(t1.arguments, [bs1]) @test typeof(t1.arguments) == Vector{BasicSymbolic} end @testset "Div" begin @@ -29,33 +29,33 @@ using SymbolicUtils: BasicSymbolic, _Sym, _Term, _Const, _Add @test typeof(d1) == BasicSymbolicImpl @test typeof(d1.num) == BasicSymbolic{Float64} @test typeof(d1.den) == BasicSymbolic{Int64} - @test d1.num == bs1 - @test d1.den == bs2 + @test isequal(d1.num, bs1) + @test isequal(d1.den, bs2) @test typeof(d1.simplified) == Base.RefValue{Bool} @test isassigned(d1.simplified) @test !d1.simplified[] @test typeof(d1.arguments) == Vector{BasicSymbolic} - @test d1.arguments == [bs1, bs2] + @test isequal(d1.arguments, [bs1, bs2]) num = bs1 den = bs2 d2 = Div(; num, den) - @test d2.num == bs1 - @test d2.den == bs2 + @test isequal(d2.num, bs1) + @test isequal(d2.den, bs2) end @testset "Pow" begin p1 = Pow(base = bs1, exp = bs2) @test typeof(p1) == BasicSymbolicImpl @test typeof(p1.base) == BasicSymbolic{Float64} @test typeof(p1.exp) == BasicSymbolic{Int64} - @test p1.base == bs1 - @test p1.exp == bs2 + @test isequal(p1.base, bs1) + @test isequal(p1.exp, bs2) @test typeof(p1.arguments) == Vector{BasicSymbolic} - @test p1.arguments == [bs1, bs2] + @test isequal(p1.arguments, [bs1, bs2]) base = bs1 exp = bs2 p2 = Pow(; base, exp) - @test p2.base == bs1 - @test p2.exp == bs2 + @test isequal(p2.base, bs1) + @test isequal(p2.exp, bs2) end c1 = Const(1) bc1 = BasicSymbolic{Int}(impl = c1) @@ -132,7 +132,7 @@ end @test t.metadata == SymbolicUtils.NO_METADATA @test t.hash[] == SymbolicUtils.EMPTY_HASH @test t.impl.f == mod - @test t.impl.arguments == [s1, s2] + @test isequal(t.impl.arguments, [s1, s2]) end @testset "Const" begin c1 = _Const(1.0) From f6f557bd20c6de410b8ac0afb8ac9728f8f8fc77 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 5 Aug 2024 15:25:12 -0700 Subject: [PATCH 065/140] Adapt `@syms` tests to new type structure --- test/basics.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/basics.jl b/test/basics.jl index e59b008e3..888ffb8f9 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -9,16 +9,16 @@ using Test @syms a b::Float64 f(::Real) g(p, h(q::Real))::Int @test issym(a) && symtype(a) == Number - @test a.name === :a + @test a.impl.name === :a @test issym(b) && symtype(b) == Float64 @test nameof(b) === :b @test issym(f) - @test f.name === :f + @test f.impl.name === :f @test issym(g) - @test g.name === :g + @test g.impl.name === :g @test isterm(f(b)) @test symtype(f(b)) === Number From 87bee712ce38b18c42e360db04686beb4fd530cd Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 5 Aug 2024 21:48:27 -0700 Subject: [PATCH 066/140] Revert field type of `Div` and `Pow` from `BasicSymbolic` to `Any` --- src/types.jl | 12 ++++++------ test/types.jl | 6 ------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/types.jl b/src/types.jl index 4fadee46d..3ab33d7db 100644 --- a/src/types.jl +++ b/src/types.jl @@ -27,15 +27,15 @@ const EMPTY_HASH = UInt(0) issorted::RefValue{Bool} = Ref(false) end struct Div - num::BasicSymbolic - den::BasicSymbolic + num::Any + den::Any simplified::RefValue{Bool} = Ref(false) - arguments::Vector{BasicSymbolic} = [num, den] + arguments::Vector{Any} = [num, den] end struct Pow - base::BasicSymbolic - exp::BasicSymbolic - arguments::Vector{BasicSymbolic} = [base, exp] + base::Any + exp::Any + arguments::Vector{Any} = [base, exp] end struct Const val::Any diff --git a/test/types.jl b/test/types.jl index 0c7911274..43bb1f76c 100644 --- a/test/types.jl +++ b/test/types.jl @@ -27,14 +27,11 @@ using SymbolicUtils: BasicSymbolic, _Sym, _Term, _Const, _Add @testset "Div" begin d1 = Div(num = bs1, den = bs2) @test typeof(d1) == BasicSymbolicImpl - @test typeof(d1.num) == BasicSymbolic{Float64} - @test typeof(d1.den) == BasicSymbolic{Int64} @test isequal(d1.num, bs1) @test isequal(d1.den, bs2) @test typeof(d1.simplified) == Base.RefValue{Bool} @test isassigned(d1.simplified) @test !d1.simplified[] - @test typeof(d1.arguments) == Vector{BasicSymbolic} @test isequal(d1.arguments, [bs1, bs2]) num = bs1 den = bs2 @@ -45,11 +42,8 @@ using SymbolicUtils: BasicSymbolic, _Sym, _Term, _Const, _Add @testset "Pow" begin p1 = Pow(base = bs1, exp = bs2) @test typeof(p1) == BasicSymbolicImpl - @test typeof(p1.base) == BasicSymbolic{Float64} - @test typeof(p1.exp) == BasicSymbolic{Int64} @test isequal(p1.base, bs1) @test isequal(p1.exp, bs2) - @test typeof(p1.arguments) == Vector{BasicSymbolic} @test isequal(p1.arguments, [bs1, bs2]) base = bs1 exp = bs2 From 0db3e9f9904b0b17f66b5f952cb36b0f6ef14653 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 5 Aug 2024 21:58:12 -0700 Subject: [PATCH 067/140] Fix `Pow` function call in `^` --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 3ab33d7db..6cf8a10eb 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1260,7 +1260,7 @@ function ^(a::SN, b) _Mul(promote_symtype(^, symtype(a), symtype(b)), coeff, mapvalues((k, v) -> b * v, a.dict)) else - Pow(a, b) + Pow(base = a, exp = b) end end ^(a::Number, b::SN) = _Pow(a, b) From 3f34dc0415df6e9a2023f5bb93485304191b8f32 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 5 Aug 2024 23:18:28 -0700 Subject: [PATCH 068/140] Replace `Unityper.rt_constructor` with `BasicSymbolic` constructor [skip ci] --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 6cf8a10eb..6346420c5 100644 --- a/src/types.jl +++ b/src/types.jl @@ -80,7 +80,7 @@ const SIMPLIFIED = 0x01 << 0 function ConstructionBase.setproperties_object(obj::BasicSymbolic{T}, patch)::BasicSymbolic{T} where T nt = getproperties(obj) nt_new = merge(nt, patch) - Unityper.rt_constructor(obj){T}(;nt_new...) + BasicSymbolic{T}(; nt_new...) end ### From 4cb4323bbbe48ce01103b552f0634922e75ca1f0 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 6 Aug 2024 00:42:22 -0700 Subject: [PATCH 069/140] Fix `Pow` construction in `^` [skip ci] --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 6346420c5..31e44c21f 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1260,7 +1260,7 @@ function ^(a::SN, b) _Mul(promote_symtype(^, symtype(a), symtype(b)), coeff, mapvalues((k, v) -> b * v, a.dict)) else - Pow(base = a, exp = b) + _Pow(a, b) end end ^(a::Number, b::SN) = _Pow(a, b) From 4beca4443fd86e8e33dce7ee985e29bfe52bb2aa Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 6 Aug 2024 00:48:13 -0700 Subject: [PATCH 070/140] Fix typo in `Pow` constructor definition --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 31e44c21f..8e8b7fa73 100644 --- a/src/types.jl +++ b/src/types.jl @@ -492,7 +492,7 @@ function _Pow(::Type{T}, base, exp; kwargs...) where {T} BasicSymbolic{T}(; impl, kwargs...) end function _Pow(base, exp; kwargs...) - Pow{promote_symtype(^, symtype(base), symtype(b))}(makepow(base, exp)..., kwargs...) + Pow{promote_symtype(^, symtype(base), symtype(exp))}(makepow(base, exp)..., kwargs...) end function toterm(t::BasicSymbolic{T}) where {T} From a026af77c6fc8184e3b1a925fb27708a6715a8f6 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 6 Aug 2024 23:05:52 -0700 Subject: [PATCH 071/140] Fix `Pow` construction [skip ci] --- src/types.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/types.jl b/src/types.jl index 8e8b7fa73..d5a284f25 100644 --- a/src/types.jl +++ b/src/types.jl @@ -488,11 +488,11 @@ end function _Pow(::Type{T}, base, exp; kwargs...) where {T} _iszero(exp) && return 1 _isone(exp) && return a - impl = (; base, exp) + impl = Pow(; base, exp) BasicSymbolic{T}(; impl, kwargs...) end function _Pow(base, exp; kwargs...) - Pow{promote_symtype(^, symtype(base), symtype(exp))}(makepow(base, exp)..., kwargs...) + _Pow(promote_symtype(^, symtype(base), symtype(exp)), makepow(base, exp)..., kwargs...) end function toterm(t::BasicSymbolic{T}) where {T} From 26c71502a93560365de07c69e8c20133cb1d69db Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 6 Aug 2024 23:48:17 -0700 Subject: [PATCH 072/140] Fix `isexpr` for composite symbolic expressions [skip ci] --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index d5a284f25..dd216996e 100644 --- a/src/types.jl +++ b/src/types.jl @@ -182,7 +182,7 @@ function isexpr(x::BasicSymbolic) @match x.impl begin Sym(_...) => false Const(_...) => false - _ => false + _ => true end end From 98d5c38b924d74bfbbf7bf88d7fe4458d10dc258 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 7 Aug 2024 00:19:40 -0700 Subject: [PATCH 073/140] Revert `Add` & `Mul` `coeff` type from `BasicSymbolic` to `Any` --- src/types.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/types.jl b/src/types.jl index dd216996e..531c2880e 100644 --- a/src/types.jl +++ b/src/types.jl @@ -15,13 +15,13 @@ const EMPTY_HASH = UInt(0) arguments::Vector{BasicSymbolic} end struct Add - coeff::BasicSymbolic + coeff::Any dict::Dict{BasicSymbolic, Any} arguments::Vector{BasicSymbolic} = BasicSymbolic[] issorted::RefValue{Bool} = Ref(false) end struct Mul - coeff::BasicSymbolic + coeff::Any dict::Dict{BasicSymbolic, Any} arguments::Vector{BasicSymbolic} = BasicSymbolic[] issorted::RefValue{Bool} = Ref(false) From 3e73790b1c41d2f278d3c07e7dbe39127d0a858d Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 7 Aug 2024 00:46:09 -0700 Subject: [PATCH 074/140] Unroll `Const` in addition operation --- src/types.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 531c2880e..11712ac2f 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1145,6 +1145,12 @@ sub_t(a) = promote_symtype(-, symtype(a)) import Base: (+), (-), (*), (//), (/), (\), (^) function +(a::SN, b::SN) + if isconst(a) + return a.impl.val + b + end + if isconst(b) + return b.impl.val + a + end !issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata if isadd(a) && isadd(b) return _Add( @@ -1160,7 +1166,7 @@ function +(a::SN, b::SN) end function +(a::Number, b::SN) if isconst(b) - return _Const(a + b.impl.val) + return a + b.impl.val end !issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata iszero(a) && return b From b6f6456c04c0bf5892e142b66f9547795f7256c6 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 7 Aug 2024 00:46:43 -0700 Subject: [PATCH 075/140] Fix `term` keyword argument function call [skip ci] --- test/basics.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/basics.jl b/test/basics.jl index 888ffb8f9..c3a7bb6a3 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -91,7 +91,7 @@ struct Ctx2 end @test isequal(substitute(1+sqrt(a), Dict(a => 2), fold=false), - 1 + term(sqrt, 2, type=Number)) + 1 + term(sqrt, 2, T=Number)) @test substitute(1+sqrt(a), Dict(a => 2), fold=true) isa Float64 end From 59bc77afc3043e45fbad9c0df3efed3a6f64eebb Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 7 Aug 2024 23:48:38 -0700 Subject: [PATCH 076/140] Fix `Add` construction in tests --- test/basics.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/basics.jl b/test/basics.jl index c3a7bb6a3..9ff00c2c3 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -98,11 +98,11 @@ end @testset "Base methods" begin @syms w::Complex z::Complex a::Real b::Real x - @test isequal(w + z, Add(Complex, 0, Dict(w=>1, z=>1))) - @test isequal(z + a, Add(Number, 0, Dict(z=>1, a=>1))) - @test isequal(a + b, Add(Real, 0, Dict(a=>1, b=>1))) - @test isequal(a + x, Add(Number, 0, Dict(a=>1, x=>1))) - @test isequal(a + z, Add(Number, 0, Dict(a=>1, z=>1))) + @test isequal(w + z, _Add(Complex, 0, Dict(w => 1, z => 1))) + @test isequal(z + a, _Add(Number, 0, Dict(z => 1, a => 1))) + @test isequal(a + b, _Add(Real, 0, Dict(a => 1, b => 1))) + @test isequal(a + x, _Add(Number, 0, Dict(a => 1, x => 1))) + @test isequal(a + z, _Add(Number, 0, Dict(a => 1, z => 1))) foo(w, z, a, b) = 1.0 SymbolicUtils.promote_symtype(::typeof(foo), args...) = Real From 2031008ec41b72bfeacceaf2229fb7da1bd317f7 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 8 Aug 2024 00:19:52 -0700 Subject: [PATCH 077/140] Fix `Term` construction in tests --- test/basics.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/basics.jl b/test/basics.jl index 9ff00c2c3..7b1627535 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -109,24 +109,24 @@ end @test SymbolicUtils._promote_symtype(foo, (w, z, a, b,)) === Real # promote_symtype of identity - @test isequal(Term(identity, [w]), Term{Complex}(identity, [w])) + @test isequal(_Term(identity, [w]), _Term(Complex, identity, [w])) @test isequal(+(w), w) @test isequal(+(a), a) - @test isequal(rem2pi(a, RoundNearest), Term{Real}(rem2pi, [a, RoundNearest])) + @test isequal(rem2pi(a, RoundNearest), _Term(Real, rem2pi, [a, RoundNearest])) # bool for f in [(==), (!=), (<=), (>=), (<), (>)] - @test isequal(f(a, 0), Term{Bool}(f, [a, 0])) - @test isequal(f(0, a), Term{Bool}(f, [0, a])) - @test isequal(f(a, a), Term{Bool}(f, [a, a])) + @test isequal(f(a, 0), _Term(Bool, f, [a, 0])) + @test isequal(f(0, a), _Term(Bool, f, [0, a])) + @test isequal(f(a, a), _Term(Bool, f, [a, a])) end @test symtype(ifelse(true, 4, 5)) == Int @test symtype(ifelse(a < 0, b, w)) == Union{Real, Complex} @test SymbolicUtils.promote_symtype(ifelse, Bool, Int, Bool) == Union{Int, Bool} @test_throws MethodError w < 0 - @test isequal(w == 0, Term{Bool}(==, [w, 0])) + @test isequal(w == 0, _Term(Bool, ==, [w, 0])) @eqtest x // 5 == (1 // 5) * x @eqtest (1//2 * x) / 5 == (1 // 10) * x From 903f95c64140669062866433fa09a8018da7fc40 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 8 Aug 2024 00:20:02 -0700 Subject: [PATCH 078/140] Fix `Term` construction in `ifelse` --- src/methods.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/methods.jl b/src/methods.jl index b0b222677..4128ca3ef 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -194,7 +194,7 @@ end # An ifelse node, ifelse is a built-in unfortunately # So this uses IfElse.jl's ifelse that we imported function ifelse(_if::Symbolic{Bool}, _then, _else) - Term{Union{symtype(_then), symtype(_else)}}(ifelse, Any[_if, _then, _else]) + _Term(Union{symtype(_then), symtype(_else)}, ifelse, Any[_if, _then, _else]) end promote_symtype(::typeof(ifelse), _, ::Type{T}, ::Type{S}) where {T,S} = Union{T, S} From 289ad317ee7c51f6ee08f7012f61663deb8c8b2c Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 8 Aug 2024 00:20:26 -0700 Subject: [PATCH 079/140] Fix `Div` constructor --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 11712ac2f..9b2cc0767 100644 --- a/src/types.jl +++ b/src/types.jl @@ -475,7 +475,7 @@ function _Div(::Type{T}, num, den; kwargs...) where {T} BasicSymbolic{T}(; impl, kwargs...) end function _Div(num, den; kwargs...) - Div{promote_symtype((/), symtype(num), symtype(den))}(num, den; kwargs...) + _Div(promote_symtype((/), symtype(num), symtype(den)), num, den; kwargs...) end @inline function numerators(x) From f73271857b56057dc785061401f84be96662ea25 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 8 Aug 2024 00:49:50 -0700 Subject: [PATCH 080/140] Fix `Div` constructor [skip ci] --- src/types.jl | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/types.jl b/src/types.jl index 9b2cc0767..0cc3f4024 100644 --- a/src/types.jl +++ b/src/types.jl @@ -449,28 +449,35 @@ function _Div(::Type{T}, num, den; kwargs...) where {T} return _Div(T, num.impl.num * den.impl.den, num.impl.den * den.impl.num) elseif isdiv(num) return _Div(T, num.impl.num, num.impl.den * den) - elseif isdiv(d) + elseif isdiv(den) return _Div(T, num * den.impl.den, den.impl.num) end - d isa Number && _isone(-d) && return -1 * n - n isa Rat && d isa Rat && return n // d # maybe called by oblivious code in simplify + if den isa Number && _isone(-den) + return -1 * num + end + if num isa Rat && den isa Rat + return num // den # maybe called by oblivious code in simplify + end # GCD coefficient upon construction - rat, nc = ratcoeff(n) + rat, nc = ratcoeff(num) if rat - rat, dc = ratcoeff(d) + rat, dc = ratcoeff(den) if rat g = gcd(nc, dc) * sign(dc) # make denominator positive invdc = ratio(1, g) - n = maybe_intcoeff(invdc * n) - d = maybe_intcoeff(invdc * d) - if d isa Number - _isone(d) && return n - _isone(-d) && return -1 * n + num = maybe_intcoeff(invdc * num) + den = maybe_intcoeff(invdc * den) + if den isa Number + if _isone(den) + return num + end + if _isone(-den) + return -1 * num + end end end end - impl = Div(; num, den) BasicSymbolic{T}(; impl, kwargs...) end From 1e9c99022d9ca900b6d7e85e6a984414bcc8d5eb Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 8 Aug 2024 16:47:17 -0700 Subject: [PATCH 081/140] Fix `ConstructionBase.setproperties_object` for changing fields [skip ci] --- src/types.jl | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/types.jl b/src/types.jl index 0cc3f4024..e6fed2725 100644 --- a/src/types.jl +++ b/src/types.jl @@ -77,10 +77,22 @@ const SIMPLIFIED = 0x01 << 0 #@inline is_of_type(x::BasicSymbolic, type::UInt8) = (x.bitflags & type) != 0x00 #@inline issimplified(x::BasicSymbolic) = is_of_type(x, SIMPLIFIED) -function ConstructionBase.setproperties_object(obj::BasicSymbolic{T}, patch)::BasicSymbolic{T} where T - nt = getproperties(obj) - nt_new = merge(nt, patch) - BasicSymbolic{T}(; nt_new...) +function ConstructionBase.setproperties_object( + obj::BasicSymbolic{T}, patch)::BasicSymbolic{T} where {T} + nt1 = getproperties(obj) + nt2 = getproperties(obj.impl) + nt1 = merge(nt1, patch) + nt2 = merge(nt2, patch) + metadata = nt1.metadata + @match obj.impl begin + Sym(_...) => _Sym(T, nt2.name; metadata) + Term(_...) => _Term(T, nt2.f, nt2.arguments; metadata) + Add(_...) => _Add(T, nt2.coeff, nt2.dict; metadata) + Mul(_...) => _Mul(T, nt2.coeff, nt2.dict; metadata) + Div(_...) => _Div(T, nt2.num, nt2.den; metadata) + Pow(_...) => _Pow(T, nt2.base, nt2.exp; metadata) + Const(_...) => _Const(nt2.val; metadata) + end end ### From 1cb71a968641952da70356b525d40ad44bfb8ad9 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 10 Aug 2024 22:48:56 -0700 Subject: [PATCH 082/140] Adapt `getproperty` for new class structure --- src/types.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/types.jl b/src/types.jl index e6fed2725..21cc5b3ea 100644 --- a/src/types.jl +++ b/src/types.jl @@ -554,7 +554,7 @@ function makeadd(sign, coeff, xs...) continue end if ismul(x) - k = _Mul(symtype(x), 1, x.dict) + k = _Mul(symtype(x), 1, x.impl.dict) v = sign * x.impl.coeff + get(d, k, 0) else k = x @@ -1173,10 +1173,10 @@ function +(a::SN, b::SN) !issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata if isadd(a) && isadd(b) return _Add( - add_t(a, b), a.coeff + b.coeff, _merge(+, a.dict, b.dict, filter = _iszero)) + add_t(a, b), a.impl.coeff + b.impl.coeff, _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) elseif isadd(a) coeff, dict = makeadd(1, 0, b) - return _Add(add_t(a, b), a.coeff + coeff, _merge(+, a.dict, dict, filter = _iszero)) + return _Add(add_t(a, b), a.impl.coeff + coeff, _merge(+, a.impl.dict, dict, filter = _iszero)) elseif isadd(b) return b + a end From 34b0796d2f5dcadebd56e1428674164b9d900c51 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 10 Aug 2024 22:49:35 -0700 Subject: [PATCH 083/140] Fix `Pow` constructor --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 21cc5b3ea..6552b97eb 100644 --- a/src/types.jl +++ b/src/types.jl @@ -506,7 +506,7 @@ end function _Pow(::Type{T}, base, exp; kwargs...) where {T} _iszero(exp) && return 1 - _isone(exp) && return a + _isone(exp) && return base impl = Pow(; base, exp) BasicSymbolic{T}(; impl, kwargs...) end From 9f6b4262cd2ed87f5abc3f1c24ab28ccb7bea5c6 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 10 Aug 2024 22:57:00 -0700 Subject: [PATCH 084/140] Fix test case due keyword argument name change --- test/basics.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/basics.jl b/test/basics.jl index 7b1627535..e42cced99 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -173,7 +173,7 @@ end @syms a b c @test repr(a+b) == "a + b" @test repr(-a) == "-a" - @test repr(term(-, a; type = Real)) == "-a" + @test repr(term(-, a; T = Real)) == "-a" @test repr(-a + 3) == "3 - a" @test repr(-(a + b)) == "-a - b" @test repr((2a)^(-2a)) == "(2a)^(-2a)" From c3743f6a474959b39e789e607e54004f0d68e91a Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 10 Aug 2024 23:15:16 -0700 Subject: [PATCH 085/140] Adapt `getproperty` to new class structure --- src/polyform.jl | 2 +- src/types.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/polyform.jl b/src/polyform.jl index 7d6bc906e..e5c2b50e8 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -470,7 +470,7 @@ end # ismul(x) function quick_mul(x, y) - if haskey(x.dict, y) && x.dict[y] >= 1 + if haskey(x.impl.dict, y) && x.dict[y] >= 1 d = copy(x.dict) if d[y] > 1 d[y] -= 1 diff --git a/src/types.jl b/src/types.jl index 6552b97eb..012016257 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1200,13 +1200,13 @@ end function -(a::SN) !issafecanon(*, a) && return term(-, a) - isadd(a) ? _Add(sub_t(a), -a.coeff, mapvalues((_, v) -> -v, a.dict)) : + isadd(a) ? _Add(sub_t(a), -a.impl.coeff, mapvalues((_, v) -> -v, a.impl.dict)) : _Add(sub_t(a), makeadd(-1, 0, a)...) end function -(a::SN, b::SN) (!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b) if isadd(a) && isadd(b) - _Add(sub_t(a, b), a.coeff - b.coeff, _merge(-, a.dict, b.dict, filter = _iszero)) + _Add(sub_t(a, b), a.impl.coeff - b.impl.coeff, _merge(-, a.impl.dict, b.impl.dict, filter = _iszero)) else a + (-b) end From 0b1eaaee0c931ae37c22a68947aa769f874da1f0 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 10 Aug 2024 23:18:27 -0700 Subject: [PATCH 086/140] Fix `Mul` printing --- src/types.jl | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/types.jl b/src/types.jl index 012016257..b62699655 100644 --- a/src/types.jl +++ b/src/types.jl @@ -861,15 +861,21 @@ end function show_mul(io, args) length(args) == 1 && return print_arg(io, *, args[1]) - minus = args[1] isa Number && args[1] == -1 - unit = args[1] isa Number && args[1] == 1 + arg1 = args[1] + if isconst(arg1) + arg1 = arg1.impl.val + end + + minus = arg1 isa Number && arg1 == -1 + unit = arg1 isa Number && arg1 == 1 - paren_scalar = (args[1] isa Complex && !_iszero(imag(args[1]))) || + paren_scalar = (arg1 isa Complex && !_iszero(imag(arg1))) || args[1] isa Rational || - (args[1] isa Number && !isfinite(args[1])) + (arg1 isa Number && !isfinite(arg1)) nostar = minus || unit || - (!paren_scalar && args[1] isa Number && !(args[2] isa Number)) + (!paren_scalar && arg1 isa Number && + !(isconst(args[2]) && args[2].impl.val isa Number)) for (i, t) in enumerate(args) if i != 1 From f2234c1e27ee7b7c6671eed6bd0a4523aca2b3cf Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 10 Aug 2024 23:33:39 -0700 Subject: [PATCH 087/140] Fix `get` `coeff` to new class structure --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index b62699655..dc561366d 100644 --- a/src/types.jl +++ b/src/types.jl @@ -428,7 +428,7 @@ const Rat = Union{Rational, Integer} function ratcoeff(x) if ismul(x) - ratcoeff(x.coeff) + ratcoeff(x.impl.coeff) elseif x isa Rat (true, x) else From fb77a933a6cb48e494336bf60185ea59975d6dbf Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 10 Aug 2024 23:34:00 -0700 Subject: [PATCH 088/140] Fix printing negative term in `Add` --- src/types.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/types.jl b/src/types.jl index dc561366d..e6be4b073 100644 --- a/src/types.jl +++ b/src/types.jl @@ -789,6 +789,10 @@ const show_simplified = Ref(false) isnegative(t::Real) = t < 0 function isnegative(t) + if isconst(t) + val = t.impl.val + return isnegative(val) + end if iscall(t) && operation(t) === (*) coeff = first(arguments(t)) return isnegative(coeff) @@ -823,8 +827,12 @@ function remove_minus(t) !iscall(t) && return -t @assert operation(t) == (*) args = arguments(t) - @assert args[1] < 0 - Any[-args[1], args[2:end]...] + arg1 = args[1] + if isconst(arg1) + arg1 = arg1.impl.val + end + @assert arg1 < 0 + Any[-arg1, args[2:end]...] end From 4792348475bfef7955247aa8a8951453ba816d23 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 10 Aug 2024 23:41:43 -0700 Subject: [PATCH 089/140] Adapt `maybe_intcoeff` to new class structure --- src/types.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/types.jl b/src/types.jl index e6be4b073..fac97bffd 100644 --- a/src/types.jl +++ b/src/types.jl @@ -439,13 +439,14 @@ ratio(x::Integer,y::Integer) = iszero(rem(x,y)) ? div(x,y) : x//y ratio(x::Rat,y::Rat) = x//y function maybe_intcoeff(x) if ismul(x) - if x.coeff isa Rational && isone(x.coeff.den) - Mul{symtype(x)}(; coeff=x.coeff.num, dict=x.dict, x.metadata, arguments=[], issorted=RefValue(false)) + coeff = x.impl.coeff + if coeff isa Rational && isone(denominator(coeff)) + _Mul(symtype(x), coeff.num, x.impl.dict; metadata = x.metadata) else x end elseif x isa Rational - isone(x.den) ? x.num : x + isone(denominator(x)) ? numerator(x) : x else x end From 9133330f7e7799560dc305ba53684528fc117315 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 10 Aug 2024 23:50:14 -0700 Subject: [PATCH 090/140] Remove unnecessary imports in test --- test/basics.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/basics.jl b/test/basics.jl index e42cced99..504630449 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -1,4 +1,4 @@ -using SymbolicUtils: Symbolic, Sym, FnType, Term, Add, Mul, Pow, symtype, operation, arguments, issym, isterm, BasicSymbolic, term +using SymbolicUtils: Symbolic, FnType, symtype, operation, arguments, issym, isterm, BasicSymbolic, term using SymbolicUtils using IfElse: ifelse using Setfield From 7ab12c3d7f9dfc7eb12bcab940d6e9fee370b089 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 10 Aug 2024 23:50:33 -0700 Subject: [PATCH 091/140] Fix `Term` construction in tests --- test/basics.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/basics.jl b/test/basics.jl index 504630449..e71fc0ab1 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -179,8 +179,8 @@ end @test repr((2a)^(-2a)) == "(2a)^(-2a)" @test repr(1/2a) == "1 / (2a)" @test repr(2/(2*a)) == "1 / a" - @test repr(Term(*, [1, 1])) == "1" - @test repr(Term(*, [2, 1])) == "2*1" + @test repr(_Term(*, [1, 1])) == "1" + @test repr(_Term(*, [2, 1])) == "2*1" @test repr((a + b) - (b + c)) == "a - c" @test repr(a + -1*(b + c)) == "a - b - c" @test repr(a + -1*b) == "a - b" @@ -251,13 +251,13 @@ end @test symtype(new_expr) == Int64 end -toterm(t) = Term{symtype(t)}(operation(t), arguments(t)) +toterm(t) = _Term(symtype(t), operation(t), arguments(t)) @testset "diffs" begin @syms a b c - @test isequal(toterm(-1c), Term{Number}(*, [-1, c])) - @test isequal(toterm(-1(a+b)), Term{Number}(+, [-1a, -b])) - @test isequal(toterm((a + b) - (b + c)), Term{Number}(+, [a, -1c])) + @test isequal(toterm(-1c), _Term(Number, *, [-1, c])) + @test isequal(toterm(-1(a+b)), _Term(Number, +, [-1a, -b])) + @test isequal(toterm((a + b) - (b + c)), _Term(Number, +, [a, -1c])) end @testset "hash" begin From cf4a6c7cc758b0ac39d76a639c10804c24fa80f3 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 10 Aug 2024 23:54:39 -0700 Subject: [PATCH 092/140] Fix `term` function call due to keyword argument name change --- test/basics.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/basics.jl b/test/basics.jl index e71fc0ab1..10735e781 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -196,7 +196,7 @@ end @test repr(2a+1+3a^2+2b+3b^2+4a*b) == "1 + 2a + 2b + 3(a^2) + 4a*b + 3(b^2)" @syms a b[1:3] c d[1:3] - get(x, i) = term(getindex, x, i, type=Number) + get(x, i) = term(getindex, x, i; T = Number) b1, b3, d1, d2 = get(b,1),get(b,3), get(d,1), get(d,2) @test repr(a + b3 + b1 + d2 + c) == "a + b[1] + b[3] + c + d[2]" @test repr(expand((c + b3 - d1)^3)) == "b[3]^3 + 3(b[3]^2)*c - 3(b[3]^2)*d[1] + 3b[3]*(c^2) - 6b[3]*c*d[1] + 3b[3]*(d[1]^2) + c^3 - 3(c^2)*d[1] + 3c*(d[1]^2) - (d[1]^3)" From 3af7c31dfc06f321fa02eb0b8abf6e4a14bf14fc Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 11 Aug 2024 00:36:13 -0700 Subject: [PATCH 093/140] Fix `PolyForm` construction for `Const` --- src/polyform.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/polyform.jl b/src/polyform.jl index e5c2b50e8..e63ee39cb 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -95,6 +95,9 @@ end _isone(p::PolyForm) = isone(p.p) function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) + if isconst(x) + x = x.impl.val + end if x isa Number return x elseif iscall(x) From 79d19bd9a9109eee1fa691d8399cd72082f37b01 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 11 Aug 2024 00:36:39 -0700 Subject: [PATCH 094/140] Fix `Sym` constructor call in `polyize` [skip ci] --- src/polyform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/polyform.jl b/src/polyform.jl index e63ee39cb..4e85e5f26 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -132,7 +132,7 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) name = Symbol(string(op), "_", hash(y)) @label lookup - sym = Sym{symtype(x)}(name) + sym = _Sym(symtype(x), name) if haskey(sym2term, sym) if isequal(sym2term[sym][1], x) return local_polyize(sym) From 481c8c74fdee3ee60e86da6ed9c9c96b1e725a35 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 12 Aug 2024 22:16:07 -0700 Subject: [PATCH 095/140] Fix getting `exp` field in `*` --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index fac97bffd..324889f77 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1245,7 +1245,7 @@ function *(a::SN, b::SN) _Mul(mul_t(a, b), a.impl.coeff * b.impl.coeff, _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) elseif ismul(a) && ispow(b) - if b.exp isa Number + if b.impl.exp isa Number _Mul(mul_t(a, b), a.impl.coeff, _merge(+, a.impl.dict, Base.ImmutableDict(b.impl.base => b.impl.exp), From 54e16afc8408d978665efb5b04387f085a5be337 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 12 Aug 2024 22:47:02 -0700 Subject: [PATCH 096/140] Fix `inspect` printing for `Const` --- src/inspect.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/inspect.jl b/src/inspect.jl index ab3951725..cef102162 100644 --- a/src/inspect.jl +++ b/src/inspect.jl @@ -6,22 +6,23 @@ function AbstractTrees.nodevalue(x::Symbolic) end function AbstractTrees.nodevalue(x::BasicSymbolic) - str = if !iscall(x) + str = if issym(x) string(exprtype(x), "(", x, ")") + elseif isconst(x) + string(x.impl.val) elseif isadd(x) - string(exprtype(x), - (scalar=x.coeff, coeffs=Tuple(k=>v for (k,v) in x.dict))) + string(exprtype(x), + (scalar = x.impl.coeff, coeffs = Tuple(k => v for (k, v) in x.impl.dict))) elseif ismul(x) string(exprtype(x), - (scalar=x.coeff, powers=Tuple(k=>v for (k,v) in x.dict))) + (scalar = x.impl.coeff, powers = Tuple(k => v for (k, v) in x.impl.dict))) elseif isdiv(x) || ispow(x) string(exprtype(x)) else - string(exprtype(x),"{", operation(x), "}") + string(exprtype(x), "{", operation(x), "}") end - if inspect_metadata[] && !isnothing(metadata(x)) - str *= string(" metadata=", Tuple(k=>v for (k, v) in metadata(x))) + str *= string(" metadata=", Tuple(k => v for (k, v) in metadata(x))) end Text(str) end From 2bd9fc035bb91045ff598eaf72503c0b8b9d8699 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 12 Aug 2024 22:48:33 -0700 Subject: [PATCH 097/140] Fix getting `coeff` `dict` fields --- src/types.jl | 4 ++-- test/basics.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/types.jl b/src/types.jl index 324889f77..a02a31cda 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1296,9 +1296,9 @@ function ^(a::SN, b) elseif b isa Number && b < 0 _Div(1, a^(-b)) elseif ismul(a) && b isa Number - coeff = unstable_pow(a.coeff, b) + coeff = unstable_pow(a.impl.coeff, b) _Mul(promote_symtype(^, symtype(a), symtype(b)), - coeff, mapvalues((k, v) -> b * v, a.dict)) + coeff, mapvalues((k, v) -> b * v, a.impl.dict)) else _Pow(a, b) end diff --git a/test/basics.jl b/test/basics.jl index 10735e781..a3bdecdf6 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -216,7 +216,7 @@ end @testset "maketerm" begin @syms a b c - @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).dict, Dict(a=>1,b=>1,c=>1)) + @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).impl.dict, Dict(a=>1,b=>1,c=>1)) @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], nothing), b) # test that maketerm doesn't hard-code BasicSymbolic subtype From 8120646f6045a989d704a13105d6dbbc540b35c4 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 12 Aug 2024 23:09:00 -0700 Subject: [PATCH 098/140] Fix `Div` construction when `isone(den)` --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index a02a31cda..e957453ac 100644 --- a/src/types.jl +++ b/src/types.jl @@ -457,7 +457,7 @@ function _Div(::Type{T}, num, den; kwargs...) where {T} num, den = quick_cancel(num, den) end _iszero(num) && return zero(typeof(num)) - _isone(den) && return den + _isone(den) && return num if isdiv(num) && isdiv(den) return _Div(T, num.impl.num * den.impl.den, num.impl.den * den.impl.num) elseif isdiv(num) From d5a4b3c6b7e2513f9d48eea5665024a5f6c0ba47 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 12 Aug 2024 23:36:30 -0700 Subject: [PATCH 099/140] Fix `quick_*` functions in polyform --- src/polyform.jl | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/src/polyform.jl b/src/polyform.jl index 4e85e5f26..b90485e36 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -452,20 +452,29 @@ end # ispow(x) case function quick_pow(x, y) - x.exp isa Number || return (x, y) - isequal(x.base, y) && x.exp >= 1 ? (Pow{symtype(x)}(x.base, x.exp - 1),1) : (x, y) + ximpl = x.impl + if !isa(ximpl.exp, Number) + x, y + elseif isequal(ximpl.base, y) && ximpl.exp >= 1 + _Pow(symtype(x), ximpl.base, ximpl.exp - 1), 1 + else + x, y + end end # Double Pow case function quick_powpow(x, y) - if isequal(x.base, y.base) - !(x.exp isa Number && y.exp isa Number) && return (x, y) - if x.exp > y.exp - return Pow{symtype(x)}(x.base, x.exp-y.exp), 1 - elseif x.exp == y.exp + ximpl = x.impl + yimpl = y.impl + if isequal(ximpl.base, yimpl.base) + if !(ximpl.exp isa Number && yimpl.exp isa Number) + return x, y + elseif ximpl.exp > yimpl.exp + return _Pow(symtype(x), ximpl.base, ximpl.exp - yimpl.exp), 1 + elseif ximpl.exp == yimpl.exp return 1, 1 else # x.exp < y.exp - return 1, Pow{symtype(y)}(y.base, y.exp-x.exp) + return 1, _Pow(symtype(y), yimpl.base, yimpl.exp - ximpl.exp) end end return x, y @@ -473,8 +482,10 @@ end # ismul(x) function quick_mul(x, y) - if haskey(x.impl.dict, y) && x.dict[y] >= 1 - d = copy(x.dict) + ximpl = x.impl + xdict = ximpl.dict + if haskey(xdict, y) && xdict[y] >= 1 + d = copy(xdict) if d[y] > 1 d[y] -= 1 elseif d[y] == 1 @@ -482,8 +493,7 @@ function quick_mul(x, y) else error("Can't reach") end - - return Mul(symtype(x), x.coeff, d), 1 + return _Mul(symtype(x), ximpl.coeff, d), 1 else return x, y end @@ -512,8 +522,10 @@ end # Double mul case function quick_mulmul(x, y) - num_dict, den_dict = _merge_div(x.dict, y.dict) - Mul(symtype(x), x.coeff, num_dict), Mul(symtype(y), y.coeff, den_dict) + ximpl = x.impl + yimpl = y.impl + num_dict, den_dict = _merge_div(ximpl.dict, yimpl.dict) + _Mul(symtype(x), ximpl.coeff, num_dict), _Mul(symtype(y), yimpl.coeff, den_dict) end function _merge_div(ndict, ddict) From 0136808b6be81143bf08542da411ccc8f102106b Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 12 Aug 2024 23:36:39 -0700 Subject: [PATCH 100/140] Fix `Sym` construction in subtyping test --- test/basics.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/basics.jl b/test/basics.jl index a3bdecdf6..801baea1a 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -295,8 +295,8 @@ end end @testset "subtyping" begin - T = FnType{Tuple{T,S,Int} where {T,S}, Real} - s = Sym{T}(:t) + T = FnType{Tuple{T, S, Int} where {T, S}, Real} + s = _Sym(T, :t) @syms a b c::Int @test isequal(arguments(s(a, b, c)), [a, b, c]) end From 97def560b8bba4cc05625322f1ea36fe5f82bb32 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 12 Aug 2024 23:40:28 -0700 Subject: [PATCH 101/140] Fix getting fields in div tests --- test/basics.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/test/basics.jl b/test/basics.jl index 801baea1a..610aac7d3 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -303,23 +303,23 @@ end @testset "div" begin @syms x::SafeReal y::Real - @test issym((2x/2y).num) - @test (2x/3y).num.coeff == 2 - @test (2x/3y).den.coeff == 3 - @test (2x/-3x).num.coeff == -2 - @test (2x/-3x).den.coeff == 3 - @test (2.5x/3x).num.coeff == 2.5 - @test (2.5x/3x).den.coeff == 3 - @test (x/3x).den.coeff == 3 + @test issym((2x / 2y).impl.num) + @test (2x / 3y).impl.num.impl.coeff == 2 + @test (2x / 3y).impl.den.impl.coeff == 3 + @test (2x / -3x).impl.num.impl.coeff == -2 + @test (2x / -3x).impl.den.impl.coeff == 3 + @test (2.5x / 3x).impl.num.impl.coeff == 2.5 + @test (2.5x / 3x).impl.den.impl.coeff == 3 + @test (x / 3x).impl.den.impl.coeff == 3 @syms x y - @test issym((2x/2y).num) - @test (2x/3y).num.coeff == 2 - @test (2x/3y).den.coeff == 3 - @test (2x/-3x) == -2//3 - @test (2.5x/3x).num == 2.5 - @test (2.5x/3x).den == 3 - @test (x/3x) == 1//3 + @test issym((2x / 2y).impl.num) + @test (2x / 3y).impl.num.impl.coeff == 2 + @test (2x / 3y).impl.den.impl.coeff == 3 + @test (2x / -3x) == -2 // 3 + @test (2.5x / 3x).impl.num == 2.5 + @test (2.5x / 3x).impl.den == 3 + @test (x / 3x) == 1 // 3 @test isequal(x / 1, x) @test isequal(x / -1, -x) end From 5b5abc261001a73b8f8c03c3f021ec288aaebb9b Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 12 Aug 2024 23:53:36 -0700 Subject: [PATCH 102/140] Fix `Term` construction in ordering test [skip ci] --- test/order.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/order.jl b/test/order.jl index 3b7095e95..a6e070006 100644 --- a/test/order.jl +++ b/test/order.jl @@ -27,8 +27,8 @@ end @test istotal(b*a, a) @test istotal(a, b*a) @test !(b*a <ₑ b+a) -@test Term(^, [1,-1]) <ₑ a -@test istotal(a, Term(^, [1,-1])) +@test _Term(^, [1, -1]) <ₑ a +@test istotal(a, _Term(^, [1, -1])) @testset "operator order" begin fs = (*, -, +) From 25ba87edf6506a7281c76708653d1688a2c563a3 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 15 Aug 2024 14:57:47 -0700 Subject: [PATCH 103/140] Fix ordering for `Const` --- src/ordering.jl | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/ordering.jl b/src/ordering.jl index 332f11cf8..56661c816 100644 --- a/src/ordering.jl +++ b/src/ordering.jl @@ -27,7 +27,7 @@ function get_degrees(expr) elseif iscall(expr) op = operation(expr) args = sorted_arguments(expr) - if op == (^) && args[2] isa Number + if op == (^) && (args[2] isa Number || (isconst(args[2]) && args[2].impl.val isa Number)) return map(get_degrees(args[1])) do (base, pow) (base => pow * args[2]) end @@ -79,12 +79,23 @@ function <ₑ(a::Tuple, b::Tuple) end function <ₑ(a::BasicSymbolic, b::BasicSymbolic) + aisconst = isconst(a) + if aisconst + a = a.impl.val + end + bisconst = isconst(b) + if bisconst + b = b.impl.val + end + if aisconst || bisconst + return a <ₑ b + end da, db = get_degrees(a), get_degrees(b) fw = monomial_lt(da, db) bw = monomial_lt(db, da) if fw === bw && !isequal(a, b) if _arglen(a) == _arglen(b) - return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,) + return (operation(a), arguments(a)...) <ₑ (operation(b), arguments(b)...) else return _arglen(a) < _arglen(b) end From c87e5932f78b4b71abb73aa5b8c235ebe958a0d1 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 15 Aug 2024 15:00:57 -0700 Subject: [PATCH 104/140] Fix `Term` construction in an ordering test [skip ci] --- test/order.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/order.jl b/test/order.jl index a6e070006..0c5a32406 100644 --- a/test/order.jl +++ b/test/order.jl @@ -77,7 +77,7 @@ end @testset "small terms" begin # this failing was a cause of a nasty stackoverflow #82 @syms a - istotal(Term(^, [a, -1]), (a + 2)) + istotal(_Term(^, [a, -1]), (a + 2)) end @testset "transitivity" begin From 6849cf0a59c5f1df14b84b98b4b58b78531ae99c Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 16 Aug 2024 00:50:51 -0700 Subject: [PATCH 105/140] Fix polyform tests --- src/polyform.jl | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/polyform.jl b/src/polyform.jl index b90485e36..3db1dd166 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -265,10 +265,10 @@ end function polyform_factors(d, pvar2sym, sym2term) make(xs) = map(xs) do x - if ispow(x) && x.exp isa Integer && x.exp > 0 + if ispow(x) && x.impl.exp isa Integer && x.impl.exp > 0 # here we do want to recurse one level, that's why it's wrong to just # use Fs = Union{typeof(+), typeof(*)} here. - Pow(PolyForm(x.base, pvar2sym, sym2term), x.exp) + _Pow(PolyForm(x.impl.base, pvar2sym, sym2term), x.impl.exp) else PolyForm(x, pvar2sym, sym2term) end @@ -280,13 +280,13 @@ end _mul(xs...) = all(isempty, xs) ? 1 : *(Iterators.flatten(xs)...) function simplify_div(d) - d.simplified && return d + d.impl.simplified[] && return d ns, ds = polyform_factors(d, get_pvar2sym(), get_sym2term()) ns, ds = rm_gcds(ns, ds) if all(_isone, ds) return isempty(ns) ? 1 : simplify_fractions(_mul(ns)) else - Div(simplify_fractions(_mul(ns)), simplify_fractions(_mul(ds))) + _Div(simplify_fractions(_mul(ns)), simplify_fractions(_mul(ds))) end end @@ -296,11 +296,11 @@ end #add_divs(x, y) = x + y function add_divs(x, y) if isdiv(x) && isdiv(y) - return (x.num * y.den + y.num * x.den) / (x.den * y.den) + return (x.impl.num * y.impl.den + y.impl.num * x.impl.den) / (x.impl.den * y.impl.den) elseif isdiv(x) - return (x.num + y * x.den) / x.den + return (x.impl.num + y * x.impl.den) / x.impl.den elseif isdiv(y) - return (x * y.den + y.num) / y.den + return (x * y.impl.den + y.impl.num) / y.impl.den else x + y end @@ -384,7 +384,7 @@ function fraction_isone(x) end function needs_div_rules(x) - (isdiv(x) && !(x.num isa Number) && !(x.den isa Number)) || + (isdiv(x) && !(x.impl.num isa Number) && !(x.impl.den isa Number)) || (iscall(x) && operation(x) === (+) && count(has_div, arguments(x)) > 1) || (iscall(x) && any(needs_div_rules, arguments(x))) end @@ -416,13 +416,13 @@ But it will simplify `(x - 5)^2*(x - 3) / (x - 5)` to `(x - 5)*(x - 3)`. Has optimized processes for `Mul` and `Pow` terms. """ function quick_cancel(d) - if ispow(d) && isdiv(d.base) - return quick_cancel((d.base.num^d.exp) / (d.base.den^d.exp)) + if ispow(d) && isdiv(d.impl.base) + return quick_cancel((d.impl.base.impl.num^d.impl.exp) / (d.impl.base.impl.den^d.impl.exp)) elseif ismul(d) && any(isdiv, arguments(d)) return prod(arguments(d)) elseif isdiv(d) - num, den = quick_cancel(d.num, d.den) - return Div(num, den) + num, den = quick_cancel(d.impl.num, d.impl.den) + return _Div(num, den) else return d end @@ -501,20 +501,20 @@ end # mul, pow case function quick_mulpow(x, y) - y.exp isa Number || return (x, y) - if haskey(x.dict, y.base) - d = copy(x.dict) - if x.dict[y.base] > y.exp - d[y.base] -= y.exp + y.impl.exp isa Number || return (x, y) + if haskey(x.impl.dict, y.impl.base) + d = copy(x.impl.dict) + if x.impl.dict[y.impl.base] > y.impl.exp + d[y.impl.base] -= y.impl.exp den = 1 - elseif x.dict[y.base] == y.exp - delete!(d, y.base) + elseif x.impl.dict[y.impl.base] == y.impl.exp + delete!(d, y.impl.base) den = 1 else - den = Pow{symtype(y)}(y.base, y.exp-d[y.base]) - delete!(d, y.base) + den = _Pow(symtype(y), y.impl.base, y.impl.exp-d[y.impl.base]) + delete!(d, y.impl.base) end - return Mul(symtype(x), x.coeff, d), den + return _Mul(symtype(x), x.impl.coeff, d), den else return x, y end From 5d13fa747090598666d65478a51de7476a79aa67 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 16 Aug 2024 00:51:32 -0700 Subject: [PATCH 106/140] Fix `Term` construction in `ACRule` [skip ci] --- src/rule.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rule.jl b/src/rule.jl index cf78ed948..f7c8f8157 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -404,7 +404,7 @@ function (acr::ACRule)(term) itr = acr.sets(eachindex(args), acr.arity) for inds in itr - result = r(Term{T}(f, @views args[inds])) + result = r(_Term(T, f, @views args[inds])) if result !== nothing # Assumption: inds are unique length(args) == length(inds) && return result From 7d2fcfd998a220d581c43f18a979b8d768f2e7a7 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 22 Aug 2024 23:50:51 -0700 Subject: [PATCH 107/140] Fix `matcher` for `Const` [skip ci] --- src/matchers.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/matchers.jl b/src/matchers.jl index 7f4dea537..0af7ff4fb 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -6,6 +6,10 @@ # 3. Callback: takes arguments Dictionary × Number of elements matched # function matcher(val::Any) + if isconst(val) + slot = val.impl.val + return matcher(slot) + end iscall(val) && return term_matcher(val) function literal_matcher(next, data, bindings) islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing From 7697b81118b48b0fc9cf33015810459d136a06bd Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 24 Aug 2024 00:25:36 -0700 Subject: [PATCH 108/140] Set `Term` arguments `eltype` to `Symbolic` to avoid wrapping `PolyForm` --- src/types.jl | 10 +++++++--- test/types.jl | 4 ++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/types.jl b/src/types.jl index f60e1e1c3..e9109b65e 100644 --- a/src/types.jl +++ b/src/types.jl @@ -12,7 +12,7 @@ const EMPTY_HASH = UInt(0) end struct Term f::Any - arguments::Vector{BasicSymbolic} + arguments::Vector{Symbolic} end struct Add coeff::Any @@ -358,8 +358,8 @@ function _Sym(::Type{T}, name::Symbol; kwargs...) where {T} end function _Term(::Type{T}, f, args; kwargs...) where {T} - if eltype(args) !== BasicSymbolic - args = convert(Vector{BasicSymbolic}, args) + if eltype(args) !== Symbolic + args = convert(Vector{Symbolic}, args) end impl = Term(f, args) BasicSymbolic{T}(; impl, kwargs...) @@ -373,6 +373,10 @@ function _Const(val::T; kwargs...) where {T} BasicSymbolic{T}(; impl, kwargs...) end +function Base.convert(::Type{Symbolic}, x) + _Const(x) +end + function Base.convert(::Type{BasicSymbolic}, x) _Const(x) end diff --git a/test/types.jl b/test/types.jl index 43bb1f76c..6680095a9 100644 --- a/test/types.jl +++ b/test/types.jl @@ -1,4 +1,4 @@ -using SymbolicUtils: BasicSymbolic, _Sym, _Term, _Const, _Add +using SymbolicUtils: Symbolic, BasicSymbolic, _Sym, _Term, _Const, _Add @testset "Expronicon generated constructors" begin s1 = Sym(:abc) @@ -22,7 +22,7 @@ using SymbolicUtils: BasicSymbolic, _Sym, _Term, _Const, _Add @test typeof(t1) == BasicSymbolicImpl @test t1.f == sin @test isequal(t1.arguments, [bs1]) - @test typeof(t1.arguments) == Vector{BasicSymbolic} + @test typeof(t1.arguments) == Vector{Symbolic} end @testset "Div" begin d1 = Div(num = bs1, den = bs2) From ed167b105d7a61184401cba2223c97e74fcfa962 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sat, 24 Aug 2024 00:26:05 -0700 Subject: [PATCH 109/140] Fix `Term` construction in `expand` test --- test/polyform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/polyform.jl b/test/polyform.jl index 5c68ddced..2b9f35b57 100644 --- a/test/polyform.jl +++ b/test/polyform.jl @@ -38,7 +38,7 @@ end @syms A::Vector{Real} # test that the following works - expand(Term{Real}(getindex, [A, 3]) - 3) + expand(_Term(Real, getindex, [A, 3]) - 3) end @testset "simplify_fractions with quick-cancel" begin From 538e8965e37188d3d55fd52ba1880dc920384d06 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 25 Aug 2024 00:17:31 -0700 Subject: [PATCH 110/140] Fix `*` for `Const` [skip ci] --- src/types.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/types.jl b/src/types.jl index e9109b65e..86d00b758 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1239,6 +1239,12 @@ mul_t(a, b) = promote_symtype(*, symtype(a), symtype(b)) mul_t(a) = promote_symtype(*, symtype(a)) function *(a::SN, b::SN) + if isconst(a) + return a.impl.val * b + end + if isconst(b) + return b.impl.val * a + end # Always make sure Div wraps Mul !issafecanon(*, a, b) && return term(*, a, b) if isdiv(a) && isdiv(b) From b1b855b65823ce21b50a0f5975aef1c8af722a8f Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 25 Aug 2024 00:23:31 -0700 Subject: [PATCH 111/140] Fix calling `term` --- test/rewrite.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/rewrite.jl b/test/rewrite.jl index 3bb2621e3..310e0fcb5 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -38,9 +38,10 @@ end @eqtest @rule(+(~~x) => ~~x)(a + b) == [a,b] @eqtest @rule(+(~~x) => ~~x)(term(+, a, b, c)) == [a,b,c] - @eqtest @rule(+(~~x,~y, ~~x) => (~~x, ~y))(term(+,9,8,9,type=Any)) == ([9,],8) - @eqtest @rule(+(~~x,~y, ~~x) => (~~x, ~y, ~~x))(term(+,9,8,9,9,8,type=Any)) == ([9,8], 9, [9,8]) - @eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, []) + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y))(term(+, 9, 8, 9; T = Any)) == ([9], 8) + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y, ~~x))(term(+, 9, 8, 9, 9, 8; T = Any)) == + ([9, 8], 9, [9, 8]) + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y, ~~x))(term(+, 6; T = Any)) == ([], 6, []) end using SymbolicUtils: @capture From 4ac8d4fedb70d5d94dd6bfec58fcb254bb0a6441 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Sun, 25 Aug 2024 00:37:09 -0700 Subject: [PATCH 112/140] Fix "Equality matching" tests due to `Term` `argument` `eltype` change [skip ci] --- test/rewrite.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/rewrite.jl b/test/rewrite.jl index 310e0fcb5..42ca9f8d9 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -38,10 +38,12 @@ end @eqtest @rule(+(~~x) => ~~x)(a + b) == [a,b] @eqtest @rule(+(~~x) => ~~x)(term(+, a, b, c)) == [a,b,c] - @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y))(term(+, 9, 8, 9; T = Any)) == ([9], 8) + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y))(term(+, 9, 8, 9; T = Any)) == + (Symbolic[9], _Const(8)) @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y, ~~x))(term(+, 9, 8, 9, 9, 8; T = Any)) == - ([9, 8], 9, [9, 8]) - @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y, ~~x))(term(+, 6; T = Any)) == ([], 6, []) + (Symbolic[9, 8], _Const(9), Symbolic[9, 8]) + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y, ~~x))(term(+, 6; T = Any)) == + (Symbolic[], _Const(6), Symbolic[]) end using SymbolicUtils: @capture From d1140d9b0d31f5e94049a400823918fa6f9fde11 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 26 Aug 2024 19:39:40 -0700 Subject: [PATCH 113/140] Fix `Term` construction in "Numeric" testset [skip ci] --- test/rulesets.jl | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/test/rulesets.jl b/test/rulesets.jl index 4730cc102..574ed46b6 100644 --- a/test/rulesets.jl +++ b/test/rulesets.jl @@ -18,10 +18,10 @@ end @testset "Numeric" begin @syms a::Integer b c d x::Real y::Number - @eqtest simplify(Term{Real}(conj, [x])) == x - @eqtest simplify(Term{Real}(real, [x])) == x - @eqtest simplify(Term{Real}(imag, [x])) == 0 - @eqtest simplify(Term{Real}(imag, [y])) == imag(y) + @eqtest simplify(_Term(Real, conj, [x])) == x + @eqtest simplify(_Term(Real, real, [x])) == x + @eqtest simplify(_Term(Real, imag, [x])) == 0 + @eqtest simplify(_Term(Real, imag, [y])) == imag(y) @eqtest simplify(x - y) == x + -1 * y @eqtest simplify(x - sin(y)) == x + -1 * sin(y) @eqtest simplify(-sin(x)) == -1 * sin(x) @@ -44,14 +44,13 @@ end @eqtest simplify(a * b * 1 * c * d) == simplify(a * b * c * d) @eqtest simplify_fractions(x^2.0 / (x * y)^2.0) == simplify_fractions(1 / (y^2.0)) - @test simplify(Term(one, [a])) == 1 - @test simplify(Term(one, [b + 1])) == 1 - @test simplify(Term(one, [x + 2])) == 1 + @test simplify(_Term(one, [a])) == 1 + @test simplify(_Term(one, [b + 1])) == 1 + @test simplify(_Term(one, [x + 2])) == 1 - - @test simplify(Term(zero, [a])) == 0 - @test simplify(Term(zero, [b + 1])) == 0 - @test simplify(Term(zero, [x + 2])) == 0 + @test simplify(_Term(zero, [a])) == 0 + @test simplify(_Term(zero, [b + 1])) == 0 + @test simplify(_Term(zero, [x + 2])) == 0 end @testset "LiteralReal" begin From a024aa13c4eeca83c04ab531e0fbbd4dd4f8ee8b Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 27 Aug 2024 23:46:08 -0700 Subject: [PATCH 114/140] Fix `is_literal_number` for `Const` --- src/utils.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 438a6ad7c..2da9dba69 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -64,8 +64,12 @@ end sym_isa(::Type{T}) where {T} = @nospecialize(x) -> x isa T || symtype(x) <: T -isliteral(::Type{T}) where {T} = x -> x isa T -is_literal_number(x) = isliteral(Number)(x) +function is_literal_number(x) + if isconst(x) + x = x.impl.val + end + x isa Number +end # checking the type directly is faster than dynamic dispatch in type unstable code _iszero(x) = x isa Number && iszero(x) From 76d7b8e948f66c61a9773dbc03303c8f5518aaa2 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 27 Aug 2024 23:46:24 -0700 Subject: [PATCH 115/140] Fix `*` for `Const` [skip ci] --- src/types.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/types.jl b/src/types.jl index 2b6a0e44e..032d9d0eb 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1276,6 +1276,9 @@ function *(a::SN, b::SN) end end function *(a::Number, b::SN) + if isconst(b) + return a * b.impl.val + end !issafecanon(*, b) && return term(*, a, b) if iszero(a) a From d32622aef78370195a4c9e2a20d77172331f2243 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 27 Aug 2024 23:51:52 -0700 Subject: [PATCH 116/140] Fix `_Div` constructor call in `*` [skip ci] --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 032d9d0eb..de061086f 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1285,7 +1285,7 @@ function *(a::Number, b::SN) elseif isone(a) b elseif isdiv(b) - Div(a * b.impl.num, b.impl.den) + _Div(a * b.impl.num, b.impl.den) elseif isone(-a) && isadd(b) # -1(a+b) -> -a - b T = promote_symtype(+, typeof(a), symtype(b)) From 0e9c845f3c6f9130a0c1cbdad577fb1c27c6b535 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 29 Aug 2024 02:32:18 -0700 Subject: [PATCH 117/140] Fix `literal_matcher` for `Const` [skip ci] --- src/matchers.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/matchers.jl b/src/matchers.jl index 0af7ff4fb..5eb1df4d4 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -9,10 +9,20 @@ function matcher(val::Any) if isconst(val) slot = val.impl.val return matcher(slot) + elseif iscall(val) + return term_matcher(val) end - iscall(val) && return term_matcher(val) function literal_matcher(next, data, bindings) - islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing + if islist(data) + cd = car(data) + if isconst(cd) + cd = cd.impl.val + end + if isequal(cd, val) + return next(bindings, 1) + end + end + nothing end end From 678507c5be458e303a6c5adfa52bd2f17a715c01 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 29 Aug 2024 13:45:54 -0700 Subject: [PATCH 118/140] Fix `Term` construction in "boolean" tests --- test/rulesets.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rulesets.jl b/test/rulesets.jl index 574ed46b6..1e6a73312 100644 --- a/test/rulesets.jl +++ b/test/rulesets.jl @@ -76,8 +76,8 @@ end @eqtest simplify(true & (0 < a)) == (0 < a) @eqtest simplify(false & (0 < a)) == false @eqtest simplify((0 < a) & false) == false - @eqtest simplify(Term{Bool}(!, [true])) == false - @eqtest simplify(Term{Bool}(|, [false, true])) == true + @eqtest simplify(_Term(Bool, !, [true])) == false + @eqtest simplify(_Term(Bool, |, [false, true])) == true @eqtest simplify(ifelse(true, a, b)) == a @eqtest simplify(ifelse(false, a, b)) == b From 7ddad9c31f639046bb79c6963bcccde457eeff8d Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 29 Aug 2024 13:46:26 -0700 Subject: [PATCH 119/140] Fix `!` for `Const` [skip ci] --- src/methods.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/methods.jl b/src/methods.jl index 8ed59649e..e31f1f6be 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -186,7 +186,13 @@ end for f in [!, ~] @eval begin promote_symtype(::$(typeof(f)), ::Type{<:Bool}) = Bool - (::$(typeof(f)))(s::Symbolic{Bool}) = _Term(Bool, !, [s]) + function (::$(typeof(f)))(s::Symbolic{Bool}) + if isconst(s) + s = s.impl.val + return !s + end + _Term(Bool, !, [s]) + end end end From ca4be5368144211ac04220bb366cfeaf6effe5cc Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 29 Aug 2024 14:37:31 -0700 Subject: [PATCH 120/140] Unwrap `Const` in `substitute` [skip ci] --- src/substitute.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/substitute.jl b/src/substitute.jl index 8fc980c69..883496e0a 100644 --- a/src/substitute.jl +++ b/src/substitute.jl @@ -22,6 +22,9 @@ function substitute(expr, dict; fold=true) canfold = !(op isa Symbolic) args = map(arguments(expr)) do x x′ = substitute(x, dict; fold=fold) + if isconst(x) + x′ = x′.impl.val + end canfold = canfold && !(x′ isa Symbolic) x′ end From 768f1ce69af20c80814dbf8f589618051c887e4f Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 29 Aug 2024 21:47:56 -0700 Subject: [PATCH 121/140] Fix `toexpr` for `Const` [skip ci] --- src/code.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/code.jl b/src/code.jl index 4128a39fd..458fd5cc2 100644 --- a/src/code.jl +++ b/src/code.jl @@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, import ..SymbolicUtils import ..SymbolicUtils.Rewriters import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, - symtype, sorted_arguments, metadata, isterm, term, maketerm + isconst, symtype, sorted_arguments, metadata, isterm, term, maketerm import SymbolicIndexingInterface: symbolic_type, NotSymbolic ##== state management ==## @@ -182,6 +182,8 @@ function toexpr(O, st) if issym(O) O = substitute_name(O, st) return issym(O) ? nameof(O) : toexpr(O, st) + elseif isconst(O) + return O.impl.val end O = substitute_name(O, st) From 497a7174f5e96acead0230b3fdc2285168c132ce Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 29 Aug 2024 21:52:20 -0700 Subject: [PATCH 122/140] Fix `Term` construction in "Code" tests --- test/code.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/code.jl b/test/code.jl index 7956aa59b..d2609d975 100644 --- a/test/code.jl +++ b/test/code.jl @@ -214,7 +214,7 @@ nanmath_st.rewrites[:nanmath] = true for q ∈ Base.Irrational[Base.MathConstants.catalan, Base.MathConstants.γ, π, Base.MathConstants.φ, ℯ, twoπ] Base.show(io, q) s1 = String(take!(io)) - SymbolicUtils.show_term(io, SymbolicUtils.Term(identity, [q])) + SymbolicUtils.show_term(io, SymbolicUtils._Term(identity, [q])) s2 = String(take!(io)) @test s1 == s2 end From 835aab49c80feaa3da41c7f7a8ed122b35efbec5 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 29 Aug 2024 22:02:16 -0700 Subject: [PATCH 123/140] Revert keyword argument `T` back to `type` in `term` to avoid breaking [skip ci] --- src/methods.jl | 6 +++--- src/rule.jl | 4 ++-- src/types.jl | 8 ++++---- test/basics.jl | 6 +++--- test/rewrite.jl | 6 +++--- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/methods.jl b/src/methods.jl index f15566675..7e5850a4c 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -179,9 +179,9 @@ for (f, Domain) in [(==) => Number, (!=) => Number, xor => Bool] @eval begin promote_symtype(::$(typeof(f)), ::Type{<:$Domain}, ::Type{<:$Domain}) = Bool - (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::$Domain) = term($f, a, b; T = Bool) - (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::Symbolic{<:$Domain}) = term($f, a, b; T = Bool) - (::$(typeof(f)))(a::$Domain, b::Symbolic{<:$Domain}) = term($f, a, b; T = Bool) + (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::$Domain) = term($f, a, b; type = Bool) + (::$(typeof(f)))(a::Symbolic{<:$Domain}, b::Symbolic{<:$Domain}) = term($f, a, b; type = Bool) + (::$(typeof(f)))(a::$Domain, b::Symbolic{<:$Domain}) = term($f, a, b; type = Bool) end end diff --git a/src/rule.jl b/src/rule.jl index aca5b8456..7a6353a6b 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -67,10 +67,10 @@ function makepattern(expr, keys) makeslot(expr.args[2], keys) end else - :(term($(map(x -> makepattern(x, keys), expr.args)...); T = Any)) + :(term($(map(x -> makepattern(x, keys), expr.args)...); type = Any)) end elseif expr.head === :ref - :(term(getindex, $(map(x -> makepattern(x, keys), expr.args)...); T = Any)) + :(term(getindex, $(map(x -> makepattern(x, keys), expr.args)...); type = Any)) elseif expr.head === :$ return esc(expr.args[1]) else diff --git a/src/types.jl b/src/types.jl index de061086f..5ff959066 100644 --- a/src/types.jl +++ b/src/types.jl @@ -609,11 +609,11 @@ function makepow(a, b) base, exp end -function term(f, args...; T = nothing) - if T === nothing - T = _promote_symtype(f, args) +function term(f, args...; type = nothing) + if type === nothing + type = _promote_symtype(f, args) end - _Term(T, f, [args...]) + _Term(type, f, [args...]) end """ diff --git a/test/basics.jl b/test/basics.jl index 2cd34b607..2e6b6b256 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -91,7 +91,7 @@ struct Ctx2 end @test isequal(substitute(1+sqrt(a), Dict(a => 2), fold=false), - 1 + term(sqrt, 2, T=Number)) + 1 + term(sqrt, 2, type = Number)) @test substitute(1+sqrt(a), Dict(a => 2), fold=true) isa Float64 end @@ -173,7 +173,7 @@ end @syms a b c @test repr(a+b) == "a + b" @test repr(-a) == "-a" - @test repr(term(-, a; T = Real)) == "-a" + @test repr(term(-, a; type = Real)) == "-a" @test repr(-a + 3) == "3 - a" @test repr(-(a + b)) == "-a - b" @test repr((2a)^(-2a)) == "(2a)^(-2a)" @@ -196,7 +196,7 @@ end @test repr(2a+1+3a^2+2b+3b^2+4a*b) == "1 + 2a + 2b + 3(a^2) + 4a*b + 3(b^2)" @syms a b[1:3] c d[1:3] - get(x, i) = term(getindex, x, i; T = Number) + get(x, i) = term(getindex, x, i; type = Number) b1, b3, d1, d2 = get(b,1),get(b,3), get(d,1), get(d,2) @test repr(a + b3 + b1 + d2 + c) == "a + b[1] + b[3] + c + d[2]" @test repr(expand((c + b3 - d1)^3)) == "b[3]^3 + 3(b[3]^2)*c - 3(b[3]^2)*d[1] + 3b[3]*(c^2) - 6b[3]*c*d[1] + 3b[3]*(d[1]^2) + c^3 - 3(c^2)*d[1] + 3c*(d[1]^2) - (d[1]^3)" diff --git a/test/rewrite.jl b/test/rewrite.jl index 42ca9f8d9..243823379 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -38,11 +38,11 @@ end @eqtest @rule(+(~~x) => ~~x)(a + b) == [a,b] @eqtest @rule(+(~~x) => ~~x)(term(+, a, b, c)) == [a,b,c] - @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y))(term(+, 9, 8, 9; T = Any)) == + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y))(term(+, 9, 8, 9; type = Any)) == (Symbolic[9], _Const(8)) - @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y, ~~x))(term(+, 9, 8, 9, 9, 8; T = Any)) == + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y, ~~x))(term(+, 9, 8, 9, 9, 8; type = Any)) == (Symbolic[9, 8], _Const(9), Symbolic[9, 8]) - @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y, ~~x))(term(+, 6; T = Any)) == + @eqtest @rule(+(~~x, ~y, ~~x)=>(~~x, ~y, ~~x))(term(+, 6; type = Any)) == (Symbolic[], _Const(6), Symbolic[]) end From ad7a8614248b5759619669e9bcacf324e11762ed Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Tue, 10 Sep 2024 20:22:05 -0400 Subject: [PATCH 124/140] Fix `Expr` generation for `Const` [skip ci] --- src/code.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/code.jl b/src/code.jl index 458fd5cc2..424f4bc1d 100644 --- a/src/code.jl +++ b/src/code.jl @@ -183,7 +183,7 @@ function toexpr(O, st) O = substitute_name(O, st) return issym(O) ? nameof(O) : toexpr(O, st) elseif isconst(O) - return O.impl.val + return toexpr(O.impl.val, st) end O = substitute_name(O, st) From 0f072b41c42e1fb0ead3efab6030b1fc6b07dd6e Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 11 Sep 2024 13:57:33 -0400 Subject: [PATCH 125/140] Fix `Sym` construction in CSE [skip ci] --- src/code.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/code.jl b/src/code.jl index 424f4bc1d..113a6dcb2 100644 --- a/src/code.jl +++ b/src/code.jl @@ -8,7 +8,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, import ..SymbolicUtils import ..SymbolicUtils.Rewriters -import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, +import SymbolicUtils: @matchable, BasicSymbolic, _Sym, Term, iscall, operation, arguments, issym, isconst, symtype, sorted_arguments, metadata, isterm, term, maketerm import SymbolicIndexingInterface: symbolic_type, NotSymbolic @@ -683,7 +683,7 @@ end ### Common subexprssion evaluation -@inline newsym(::Type{T}) where T = Sym{T}(gensym("cse")) +@inline newsym(::Type{T}) where T = _Sym(T, gensym("cse")) function _cse!(mem, expr) iscall(expr) || return expr @@ -747,7 +747,7 @@ function cse_block!(assignments, counter, names, name, state, x) if haskey(names, x) return names[x] else - sym = Sym{symtype(x)}(Symbol(name, counter[])) + sym = _Sym(symtype(x), Symbol(name, counter[])) names[x] = sym push!(assignments, sym ← x) counter[] += 1 From 8870978fb6592545f7b6e6eb9e0c971d470dd555 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 16 Sep 2024 12:16:03 -0400 Subject: [PATCH 126/140] Fix `Sym` getting name --- test/basics.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/basics.jl b/test/basics.jl index 51f25d150..7b78da35f 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -37,14 +37,14 @@ using Test @syms (f::typeof(max))(::Real, ::AbstractFloat)::Number a::Real @test issym(f) - @test f.name == :f + @test f.impl.name == :f @test symtype(f) == FnType{Tuple{Real, AbstractFloat}, Number, typeof(max)} @test isterm(f(a, b)) @test symtype(f(a, b)) == Number @syms g(p, (h::typeof(identity))(q::Real)::Number)::Number @test issym(g) - @test g.name == :g + @test g.impl.name == :g @test symtype(g) == FnType{Tuple{Number, FnType{Tuple{Real}, Number, typeof(identity)}}, Number, Nothing} @test_throws "not a subtype of" g(a, f) @syms (f::typeof(identity))(::Real)::Number From 8d18d9dca6f4abe5464dbebce79bb1ae81cd41ce Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 16 Sep 2024 15:54:05 -0400 Subject: [PATCH 127/140] Fix `length` for `BasicSymbolic` --- src/utils.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 2da9dba69..4508acad2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -183,7 +183,12 @@ Base.length(l::LL) = length(l.v)-l.i+1 @inline car(l::LL) = l.v[l.i] @inline cdr(l::LL) = isempty(l) ? empty(l) : LL(l.v, l.i+1) -Base.length(t::BasicSymbolic) = length(arguments(t)) + 1 # PIRACY +function Base.length(t::BasicSymbolic) + @match t.impl begin + Term(_...) => length(arguments(t)) + 1 # PIRACY + _ => 1 + end +end Base.isempty(t::BasicSymbolic) = false @inline car(t::BasicSymbolic) = operation(t) @inline cdr(t::BasicSymbolic) = arguments(t) From 85efc3137f35099fe6f7b5f965dfd486453e54cd Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 16 Sep 2024 16:01:47 -0400 Subject: [PATCH 128/140] Fix `-(a)` for `Const` --- src/types.jl | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/types.jl b/src/types.jl index 73d7319a2..dec985163 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1231,9 +1231,19 @@ end +(a::SN) = a function -(a::SN) - !issafecanon(*, a) && return term(-, a) - isadd(a) ? _Add(sub_t(a), -a.impl.coeff, mapvalues((_, v) -> -v, a.impl.dict)) : - _Add(sub_t(a), makeadd(-1, 0, a)...) + if isconst(a) + v = a.impl.val + mv = -v + return _Const(mv) + end + if !issafecanon(*, a) + return term(-, a) + end + if isadd(a) + _Add(sub_t(a), -a.impl.coeff, mapvalues((_, v) -> -v, a.impl.dict)) + else + _Add(sub_t(a), makeadd(-1, 0, a)...) + end end function -(a::SN, b::SN) (!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b) From c3b058535aecdd6490d36a7b3b12006394760cf7 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 16 Sep 2024 16:02:00 -0400 Subject: [PATCH 129/140] Fix `show_mul` for `Const` --- src/types.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/types.jl b/src/types.jl index dec985163..bc7bc8de3 100644 --- a/src/types.jl +++ b/src/types.jl @@ -876,6 +876,10 @@ function show_pow(io, args) end function show_mul(io, args) + if isconst(args) + print(io, args.impl.val) + return + end length(args) == 1 && return print_arg(io, *, args[1]) arg1 = args[1] From 1660338762e41897bb3d3b925c80af7ab1887567 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Mon, 16 Sep 2024 16:17:29 -0400 Subject: [PATCH 130/140] Fix `Term` construction in `gen_expr` --- test/fuzzlib.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/fuzzlib.jl b/test/fuzzlib.jl index ae9d9a213..aa0a51286 100644 --- a/test/fuzzlib.jl +++ b/test/fuzzlib.jl @@ -207,7 +207,7 @@ function gen_expr(lvl=5) n = rand(1:5) args = [gen_expr(lvl-1) for i in 1:n] - Term{Number}(f, first.(args)), f(last.(args)...) + _Term(Number, f, first.(args)), f(last.(args)...) else f = rand((-,/)) l = gen_expr(lvl-1) @@ -217,7 +217,7 @@ function gen_expr(lvl=5) end args = [l, r] - Term{Number}(f, first.(args)), f(last.(args)...) + _Term(Number, f, first.(args)), f(last.(args)...) end end From 203afabf5925b7ac9267e5ec212eade1a191aa96 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 18 Sep 2024 13:41:19 -0400 Subject: [PATCH 131/140] Define `get_name` function --- src/types.jl | 14 ++++++++++++-- test/basics.jl | 12 ++++++------ test/types.jl | 6 +++--- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/src/types.jl b/src/types.jl index bc7bc8de3..b8e0e0a56 100644 --- a/src/types.jl +++ b/src/types.jl @@ -64,6 +64,10 @@ function exprtype(x::BasicSymbolic) end end +function get_name(x::BasicSymbolic) + x.impl.name +end + # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") @@ -308,7 +312,13 @@ end Base.one( s::Symbolic) = one( symtype(s)) Base.zero(s::Symbolic) = zero(symtype(s)) -Base.nameof(s::BasicSymbolic) = issym(s) ? s.impl.name : error("None Sym BasicSymbolic doesn't have a name") +function Base.nameof(s::BasicSymbolic) + if issym(s) + get_name(s) + else + error("None Sym BasicSymbolic doesn't have a name") + end +end ## This is much faster than hash of an array of Any hashvec(xs, z) = foldr(hash, xs, init=z) @@ -986,7 +996,7 @@ showraw(t) = showraw(stdout, t) function Base.show(io::IO, v::BasicSymbolic) @match v.impl begin - Sym(_...) => Base.show_unquoted(io, v.impl.name) + Sym(_...) => Base.show_unquoted(io, get_name(v)) Const(_...) => print(io, v.impl.val) _ => show_term(io, v) end diff --git a/test/basics.jl b/test/basics.jl index 7b78da35f..019acfb2a 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -1,4 +1,4 @@ -using SymbolicUtils: Symbolic, FnType, symtype, operation, arguments, issym, isterm, BasicSymbolic, term +using SymbolicUtils: Symbolic, FnType, symtype, operation, arguments, issym, isterm, BasicSymbolic, term, get_name using SymbolicUtils using IfElse: ifelse using Setfield @@ -9,17 +9,17 @@ using Test @syms a b::Float64 f(::Real) g(p, h(q::Real))::Int @test issym(a) && symtype(a) == Number - @test a.impl.name === :a + @test get_name(a) === :a @test issym(b) && symtype(b) == Float64 @test nameof(b) === :b @test issym(f) - @test f.impl.name === :f + @test get_name(f) === :f @test symtype(f) == FnType{Tuple{Real}, Number, Nothing} @test issym(g) - @test g.impl.name === :g + @test get_name(g) === :g @test symtype(g) == FnType{Tuple{Number, FnType{Tuple{Real}, Number, Nothing}}, Int, Nothing} @test isterm(f(b)) @@ -37,14 +37,14 @@ using Test @syms (f::typeof(max))(::Real, ::AbstractFloat)::Number a::Real @test issym(f) - @test f.impl.name == :f + @test get_name(f) == :f @test symtype(f) == FnType{Tuple{Real, AbstractFloat}, Number, typeof(max)} @test isterm(f(a, b)) @test symtype(f(a, b)) == Number @syms g(p, (h::typeof(identity))(q::Real)::Number)::Number @test issym(g) - @test g.impl.name == :g + @test get_name(g) == :g @test symtype(g) == FnType{Tuple{Number, FnType{Tuple{Real}, Number, typeof(identity)}}, Number, Nothing} @test_throws "not a subtype of" g(a, f) @syms (f::typeof(identity))(::Real)::Number diff --git a/test/types.jl b/test/types.jl index 6680095a9..b2d7c345e 100644 --- a/test/types.jl +++ b/test/types.jl @@ -1,4 +1,4 @@ -using SymbolicUtils: Symbolic, BasicSymbolic, _Sym, _Term, _Const, _Add +using SymbolicUtils: Symbolic, BasicSymbolic, _Sym, _Term, _Const, _Add, get_name @testset "Expronicon generated constructors" begin s1 = Sym(:abc) @@ -112,11 +112,11 @@ end @test typeof(s1) == BasicSymbolic{Int64} @test s1.metadata == SymbolicUtils.NO_METADATA @test s1.hash[] == SymbolicUtils.EMPTY_HASH - @test s1.impl.name == :x + @test get_name(s1) == :x @test typeof(s2) == BasicSymbolic{Float64} @test s2.metadata == SymbolicUtils.NO_METADATA @test s2.hash[] == SymbolicUtils.EMPTY_HASH - @test s2.impl.name == :y + @test get_name(s2) == :y end @testset "Term" begin s1 = _Sym(Float64, :x) From 206d21c80aaa513ebdbf5f5f8bd3d643fa8a83c6 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 18 Sep 2024 13:45:56 -0400 Subject: [PATCH 132/140] Define `get_coeff` function --- src/inspect.jl | 4 ++-- src/polyform.jl | 2 +- src/types.jl | 40 ++++++++++++++++++++++------------------ test/basics.jl | 21 +++++++++++---------- 4 files changed, 36 insertions(+), 31 deletions(-) diff --git a/src/inspect.jl b/src/inspect.jl index cef102162..6f22d0ad6 100644 --- a/src/inspect.jl +++ b/src/inspect.jl @@ -12,10 +12,10 @@ function AbstractTrees.nodevalue(x::BasicSymbolic) string(x.impl.val) elseif isadd(x) string(exprtype(x), - (scalar = x.impl.coeff, coeffs = Tuple(k => v for (k, v) in x.impl.dict))) + (scalar = get_coeff(x), coeffs = Tuple(k => v for (k, v) in x.impl.dict))) elseif ismul(x) string(exprtype(x), - (scalar = x.impl.coeff, powers = Tuple(k => v for (k, v) in x.impl.dict))) + (scalar = get_coeff(x), powers = Tuple(k => v for (k, v) in x.impl.dict))) elseif isdiv(x) || ispow(x) string(exprtype(x)) else diff --git a/src/polyform.jl b/src/polyform.jl index 3db1dd166..34989eceb 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -514,7 +514,7 @@ function quick_mulpow(x, y) den = _Pow(symtype(y), y.impl.base, y.impl.exp-d[y.impl.base]) delete!(d, y.impl.base) end - return _Mul(symtype(x), x.impl.coeff, d), den + return _Mul(symtype(x), get_coeff(x), d), den else return x, y end diff --git a/src/types.jl b/src/types.jl index b8e0e0a56..8cae72514 100644 --- a/src/types.jl +++ b/src/types.jl @@ -68,6 +68,10 @@ function get_name(x::BasicSymbolic) x.impl.name end +function get_coeff(x::BasicSymbolic) + x.impl.coeff +end + # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") @@ -293,7 +297,7 @@ function _isequal(a, b, E) if E === SYM nameof(a) === nameof(b) elseif E === ADD || E === MUL - coeff_isequal(a.impl.coeff, b.impl.coeff) && isequal(a.impl.dict, b.impl.dict) + coeff_isequal(get_coeff(a), get_coeff(b)) && isequal(a.impl.dict, b.impl.dict) elseif E === DIV isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den) elseif E === POW @@ -337,7 +341,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt h = s.hash[] !iszero(h) && return h hashoffset = isadd(s) ? ADD_SALT : SUB_SALT - h′ = hash(hashoffset, hash(s.impl.coeff, hash(s.impl.dict, salt))) + h′ = hash(hashoffset, hash(get_coeff(s), hash(s.impl.dict, salt))) s.hash[] = h′ return h′ elseif E === DIV @@ -444,7 +448,7 @@ const Rat = Union{Rational, Integer} function ratcoeff(x) if ismul(x) - ratcoeff(x.impl.coeff) + ratcoeff(get_coeff(x)) elseif x isa Rat (true, x) else @@ -455,7 +459,7 @@ ratio(x::Integer,y::Integer) = iszero(rem(x,y)) ? div(x,y) : x//y ratio(x::Rat,y::Rat) = x//y function maybe_intcoeff(x) if ismul(x) - coeff = x.impl.coeff + coeff = get_coeff(x) if coeff isa Rational && isone(denominator(coeff)) _Mul(symtype(x), coeff.num, x.impl.dict; metadata = x.metadata) else @@ -537,7 +541,7 @@ function toterm(t::BasicSymbolic{T}) where {T} return t elseif E === ADD || E === MUL args = BasicSymbolic[] - push!(args, t.impl.coeff) + push!(args, get_coeff(t)) for (k, coeff) in t.impl.dict push!( args, coeff == 1 ? k : _Term(T, E === MUL ? (^) : (*), [_Const(coeff), k])) @@ -562,7 +566,7 @@ function makeadd(sign, coeff, xs...) d = Dict{BasicSymbolic, Any}() for x in xs if isadd(x) - coeff += x.impl.coeff + coeff += get_coeff(x) _merge!(+, d, x.impl.dict, filter = _iszero) continue end @@ -572,7 +576,7 @@ function makeadd(sign, coeff, xs...) end if ismul(x) k = _Mul(symtype(x), 1, x.impl.dict) - v = sign * x.impl.coeff + get(d, k, 0) + v = sign * get_coeff(x) + get(d, k, 0) else k = x v = sign + get(d, x, 0) @@ -593,7 +597,7 @@ function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}()) elseif x isa Number coeff *= x elseif ismul(x) - coeff *= x.impl.coeff + coeff *= get_coeff(x) _merge!(+, d, x.impl.dict, filter = _iszero) else v = 1 + get(d, x, 0) @@ -1219,10 +1223,10 @@ function +(a::SN, b::SN) !issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata if isadd(a) && isadd(b) return _Add( - add_t(a, b), a.impl.coeff + b.impl.coeff, _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) + add_t(a, b), get_coeff(a) + get_coeff(b), _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) elseif isadd(a) coeff, dict = makeadd(1, 0, b) - return _Add(add_t(a, b), a.impl.coeff + coeff, _merge(+, a.impl.dict, dict, filter = _iszero)) + return _Add(add_t(a, b), get_coeff(a) + coeff, _merge(+, a.impl.dict, dict, filter = _iszero)) elseif isadd(b) return b + a end @@ -1236,7 +1240,7 @@ function +(a::Number, b::SN) !issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata iszero(a) && return b if isadd(b) - _Add(add_t(a, b), a + b.impl.coeff, b.impl.dict) + _Add(add_t(a, b), a + get_coeff(b), b.impl.dict) else _Add(add_t(a, b), makeadd(1, a, b)...) end @@ -1254,7 +1258,7 @@ function -(a::SN) return term(-, a) end if isadd(a) - _Add(sub_t(a), -a.impl.coeff, mapvalues((_, v) -> -v, a.impl.dict)) + _Add(sub_t(a), -get_coeff(a), mapvalues((_, v) -> -v, a.impl.dict)) else _Add(sub_t(a), makeadd(-1, 0, a)...) end @@ -1262,7 +1266,7 @@ end function -(a::SN, b::SN) (!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b) if isadd(a) && isadd(b) - _Add(sub_t(a, b), a.impl.coeff - b.impl.coeff, _merge(-, a.impl.dict, b.impl.dict, filter = _iszero)) + _Add(sub_t(a, b), get_coeff(a) - get_coeff(b), _merge(-, a.impl.dict, b.impl.dict, filter = _iszero)) else a + (-b) end @@ -1289,16 +1293,16 @@ function *(a::SN, b::SN) elseif isdiv(b) _Div(a * b.impl.num, b.impl.den) elseif ismul(a) && ismul(b) - _Mul(mul_t(a, b), a.impl.coeff * b.impl.coeff, + _Mul(mul_t(a, b), get_coeff(a) * get_coeff(b), _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) elseif ismul(a) && ispow(b) if b.impl.exp isa Number _Mul(mul_t(a, b), - a.impl.coeff, + get_coeff(a), _merge(+, a.impl.dict, Base.ImmutableDict(b.impl.base => b.impl.exp), filter = _iszero)) else - _Mul(mul_t(a, b), a.impl.coeff, + _Mul(mul_t(a, b), get_coeff(a), _merge(+, a.impl.dict, Base.ImmutableDict(b => 1), filter = _iszero)) end elseif ispow(a) && ismul(b) @@ -1321,7 +1325,7 @@ function *(a::Number, b::SN) elseif isone(-a) && isadd(b) # -1(a+b) -> -a - b T = promote_symtype(+, typeof(a), symtype(b)) - _Add(T, b.impl.coeff * a, + _Add(T, get_coeff(b) * a, Dict{BasicSymbolic, Any}(k => v * a for (k, v) in b.impl.dict)) else _Mul(mul_t(a, b), makemul(a, b)...) @@ -1346,7 +1350,7 @@ function ^(a::SN, b) elseif b isa Number && b < 0 _Div(1, a^(-b)) elseif ismul(a) && b isa Number - coeff = unstable_pow(a.impl.coeff, b) + coeff = unstable_pow(get_coeff(a), b) _Mul(promote_symtype(^, symtype(a), symtype(b)), coeff, mapvalues((k, v) -> b * v, a.impl.dict)) else diff --git a/test/basics.jl b/test/basics.jl index 019acfb2a..9a70f23b6 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -1,4 +1,5 @@ -using SymbolicUtils: Symbolic, FnType, symtype, operation, arguments, issym, isterm, BasicSymbolic, term, get_name +using SymbolicUtils: Symbolic, FnType, symtype, operation, arguments, issym, isterm, + BasicSymbolic, term, get_name, get_coeff using SymbolicUtils using IfElse: ifelse using Setfield @@ -344,18 +345,18 @@ end @testset "div" begin @syms x::SafeReal y::Real @test issym((2x / 2y).impl.num) - @test (2x / 3y).impl.num.impl.coeff == 2 - @test (2x / 3y).impl.den.impl.coeff == 3 - @test (2x / -3x).impl.num.impl.coeff == -2 - @test (2x / -3x).impl.den.impl.coeff == 3 - @test (2.5x / 3x).impl.num.impl.coeff == 2.5 - @test (2.5x / 3x).impl.den.impl.coeff == 3 - @test (x / 3x).impl.den.impl.coeff == 3 + @test get_coeff((2x / 3y).impl.num) == 2 + @test get_coeff((2x / 3y).impl.den) == 3 + @test get_coeff((2x / -3x).impl.num) == -2 + @test get_coeff((2x / -3x).impl.den) == 3 + @test get_coeff((2.5x / 3x).impl.num) == 2.5 + @test get_coeff((2.5x / 3x).impl.den) == 3 + @test get_coeff((x / 3x).impl.den) == 3 @syms x y @test issym((2x / 2y).impl.num) - @test (2x / 3y).impl.num.impl.coeff == 2 - @test (2x / 3y).impl.den.impl.coeff == 3 + @test get_coeff((2x / 3y).impl.num) == 2 + @test get_coeff((2x / 3y).impl.den) == 3 @test (2x / -3x) == -2 // 3 @test (2.5x / 3x).impl.num == 2.5 @test (2.5x / 3x).impl.den == 3 From 53ca3b82051a92f80783a8572b2b56614a53f8bb Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 18 Sep 2024 19:14:59 -0400 Subject: [PATCH 133/140] Define `get_dict` function --- src/inspect.jl | 4 ++-- src/polyform.jl | 8 ++++---- src/types.jl | 38 +++++++++++++++++++++----------------- test/basics.jl | 4 ++-- 4 files changed, 29 insertions(+), 25 deletions(-) diff --git a/src/inspect.jl b/src/inspect.jl index 6f22d0ad6..8dc7eaf6c 100644 --- a/src/inspect.jl +++ b/src/inspect.jl @@ -12,10 +12,10 @@ function AbstractTrees.nodevalue(x::BasicSymbolic) string(x.impl.val) elseif isadd(x) string(exprtype(x), - (scalar = get_coeff(x), coeffs = Tuple(k => v for (k, v) in x.impl.dict))) + (scalar = get_coeff(x), coeffs = Tuple(k => v for (k, v) in get_dict(x)))) elseif ismul(x) string(exprtype(x), - (scalar = get_coeff(x), powers = Tuple(k => v for (k, v) in x.impl.dict))) + (scalar = get_coeff(x), powers = Tuple(k => v for (k, v) in get_dict(x)))) elseif isdiv(x) || ispow(x) string(exprtype(x)) else diff --git a/src/polyform.jl b/src/polyform.jl index 34989eceb..f2630166e 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -502,12 +502,12 @@ end # mul, pow case function quick_mulpow(x, y) y.impl.exp isa Number || return (x, y) - if haskey(x.impl.dict, y.impl.base) - d = copy(x.impl.dict) - if x.impl.dict[y.impl.base] > y.impl.exp + if haskey(get_dict(x), y.impl.base) + d = copy(get_dict(x)) + if get_dict(x)[y.impl.base] > y.impl.exp d[y.impl.base] -= y.impl.exp den = 1 - elseif x.impl.dict[y.impl.base] == y.impl.exp + elseif get_dict(x)[y.impl.base] == y.impl.exp delete!(d, y.impl.base) den = 1 else diff --git a/src/types.jl b/src/types.jl index 8cae72514..bfbc20a9c 100644 --- a/src/types.jl +++ b/src/types.jl @@ -72,6 +72,10 @@ function get_coeff(x::BasicSymbolic) x.impl.coeff end +function get_dict(x::BasicSymbolic) + x.impl.dict +end + # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") @@ -297,7 +301,7 @@ function _isequal(a, b, E) if E === SYM nameof(a) === nameof(b) elseif E === ADD || E === MUL - coeff_isequal(get_coeff(a), get_coeff(b)) && isequal(a.impl.dict, b.impl.dict) + coeff_isequal(get_coeff(a), get_coeff(b)) && isequal(get_dict(a), get_dict(b)) elseif E === DIV isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den) elseif E === POW @@ -341,7 +345,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt h = s.hash[] !iszero(h) && return h hashoffset = isadd(s) ? ADD_SALT : SUB_SALT - h′ = hash(hashoffset, hash(get_coeff(s), hash(s.impl.dict, salt))) + h′ = hash(hashoffset, hash(get_coeff(s), hash(get_dict(s), salt))) s.hash[] = h′ return h′ elseif E === DIV @@ -461,7 +465,7 @@ function maybe_intcoeff(x) if ismul(x) coeff = get_coeff(x) if coeff isa Rational && isone(denominator(coeff)) - _Mul(symtype(x), coeff.num, x.impl.dict; metadata = x.metadata) + _Mul(symtype(x), coeff.num, get_dict(x); metadata = x.metadata) else x end @@ -542,7 +546,7 @@ function toterm(t::BasicSymbolic{T}) where {T} elseif E === ADD || E === MUL args = BasicSymbolic[] push!(args, get_coeff(t)) - for (k, coeff) in t.impl.dict + for (k, coeff) in get_dict(t) push!( args, coeff == 1 ? k : _Term(T, E === MUL ? (^) : (*), [_Const(coeff), k])) end @@ -567,7 +571,7 @@ function makeadd(sign, coeff, xs...) for x in xs if isadd(x) coeff += get_coeff(x) - _merge!(+, d, x.impl.dict, filter = _iszero) + _merge!(+, d, get_dict(x), filter = _iszero) continue end if x isa Number @@ -575,7 +579,7 @@ function makeadd(sign, coeff, xs...) continue end if ismul(x) - k = _Mul(symtype(x), 1, x.impl.dict) + k = _Mul(symtype(x), 1, get_dict(x)) v = sign * get_coeff(x) + get(d, k, 0) else k = x @@ -598,7 +602,7 @@ function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}()) coeff *= x elseif ismul(x) coeff *= get_coeff(x) - _merge!(+, d, x.impl.dict, filter = _iszero) + _merge!(+, d, get_dict(x), filter = _iszero) else v = 1 + get(d, x, 0) if _iszero(v) @@ -1223,10 +1227,10 @@ function +(a::SN, b::SN) !issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata if isadd(a) && isadd(b) return _Add( - add_t(a, b), get_coeff(a) + get_coeff(b), _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) + add_t(a, b), get_coeff(a) + get_coeff(b), _merge(+, get_dict(a), get_dict(b), filter = _iszero)) elseif isadd(a) coeff, dict = makeadd(1, 0, b) - return _Add(add_t(a, b), get_coeff(a) + coeff, _merge(+, a.impl.dict, dict, filter = _iszero)) + return _Add(add_t(a, b), get_coeff(a) + coeff, _merge(+, get_dict(a), dict, filter = _iszero)) elseif isadd(b) return b + a end @@ -1240,7 +1244,7 @@ function +(a::Number, b::SN) !issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata iszero(a) && return b if isadd(b) - _Add(add_t(a, b), a + get_coeff(b), b.impl.dict) + _Add(add_t(a, b), a + get_coeff(b), get_dict(b)) else _Add(add_t(a, b), makeadd(1, a, b)...) end @@ -1258,7 +1262,7 @@ function -(a::SN) return term(-, a) end if isadd(a) - _Add(sub_t(a), -get_coeff(a), mapvalues((_, v) -> -v, a.impl.dict)) + _Add(sub_t(a), -get_coeff(a), mapvalues((_, v) -> -v, get_dict(a))) else _Add(sub_t(a), makeadd(-1, 0, a)...) end @@ -1266,7 +1270,7 @@ end function -(a::SN, b::SN) (!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b) if isadd(a) && isadd(b) - _Add(sub_t(a, b), get_coeff(a) - get_coeff(b), _merge(-, a.impl.dict, b.impl.dict, filter = _iszero)) + _Add(sub_t(a, b), get_coeff(a) - get_coeff(b), _merge(-, get_dict(a), get_dict(b), filter = _iszero)) else a + (-b) end @@ -1294,16 +1298,16 @@ function *(a::SN, b::SN) _Div(a * b.impl.num, b.impl.den) elseif ismul(a) && ismul(b) _Mul(mul_t(a, b), get_coeff(a) * get_coeff(b), - _merge(+, a.impl.dict, b.impl.dict, filter = _iszero)) + _merge(+, get_dict(a), get_dict(b), filter = _iszero)) elseif ismul(a) && ispow(b) if b.impl.exp isa Number _Mul(mul_t(a, b), get_coeff(a), - _merge(+, a.impl.dict, Base.ImmutableDict(b.impl.base => b.impl.exp), + _merge(+, get_dict(a), Base.ImmutableDict(b.impl.base => b.impl.exp), filter = _iszero)) else _Mul(mul_t(a, b), get_coeff(a), - _merge(+, a.impl.dict, Base.ImmutableDict(b => 1), filter = _iszero)) + _merge(+, get_dict(a), Base.ImmutableDict(b => 1), filter = _iszero)) end elseif ispow(a) && ismul(b) b * a @@ -1326,7 +1330,7 @@ function *(a::Number, b::SN) # -1(a+b) -> -a - b T = promote_symtype(+, typeof(a), symtype(b)) _Add(T, get_coeff(b) * a, - Dict{BasicSymbolic, Any}(k => v * a for (k, v) in b.impl.dict)) + Dict{BasicSymbolic, Any}(k => v * a for (k, v) in get_dict(b))) else _Mul(mul_t(a, b), makemul(a, b)...) end @@ -1352,7 +1356,7 @@ function ^(a::SN, b) elseif ismul(a) && b isa Number coeff = unstable_pow(get_coeff(a), b) _Mul(promote_symtype(^, symtype(a), symtype(b)), - coeff, mapvalues((k, v) -> b * v, a.impl.dict)) + coeff, mapvalues((k, v) -> b * v, get_dict(a))) else _Pow(a, b) end diff --git a/test/basics.jl b/test/basics.jl index 9a70f23b6..a392eddae 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -1,5 +1,5 @@ using SymbolicUtils: Symbolic, FnType, symtype, operation, arguments, issym, isterm, - BasicSymbolic, term, get_name, get_coeff + BasicSymbolic, term, get_name, get_coeff, get_dict using SymbolicUtils using IfElse: ifelse using Setfield @@ -234,7 +234,7 @@ end @testset "maketerm" begin @syms a b c - @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).impl.dict, Dict(a=>1,b=>1,c=>1)) + @test isequal(get_dict(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing)), Dict(a=>1,b=>1,c=>1)) @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], nothing), b) # test that maketerm doesn't hard-code BasicSymbolic subtype From 546b00050439f75daaf26941b6de77f3e9f5645d Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Wed, 18 Sep 2024 20:11:27 -0400 Subject: [PATCH 134/140] Fix `_occursin` for `Const` --- src/substitute.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/substitute.jl b/src/substitute.jl index 54354f31b..025c9dbcc 100644 --- a/src/substitute.jl +++ b/src/substitute.jl @@ -57,10 +57,13 @@ function _occursin(needle, haystack) if iscall(haystack) args = arguments(haystack) for arg in args + if isconst(arg) + arg = arg.impl.val + end if needle isa Integer || needle isa AbstractFloat isequal(needle, arg) && return true else - occursin(needle, arg) && return true + occursin(needle, arg) && return true end end end From e011b04ecb244f46e375f283bb6a52f00202c687 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 27 Sep 2024 12:30:11 -0400 Subject: [PATCH 135/140] Define `get_num` --- src/polyform.jl | 12 ++++++------ src/types.jl | 26 +++++++++++++++----------- test/basics.jl | 16 ++++++++-------- 3 files changed, 29 insertions(+), 25 deletions(-) diff --git a/src/polyform.jl b/src/polyform.jl index f2630166e..d5d12e5b7 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -296,11 +296,11 @@ end #add_divs(x, y) = x + y function add_divs(x, y) if isdiv(x) && isdiv(y) - return (x.impl.num * y.impl.den + y.impl.num * x.impl.den) / (x.impl.den * y.impl.den) + return (get_num(x) * y.impl.den + get_num(y) * x.impl.den) / (x.impl.den * y.impl.den) elseif isdiv(x) - return (x.impl.num + y * x.impl.den) / x.impl.den + return (get_num(x) + y * x.impl.den) / x.impl.den elseif isdiv(y) - return (x * y.impl.den + y.impl.num) / y.impl.den + return (x * y.impl.den + get_num(y)) / y.impl.den else x + y end @@ -384,7 +384,7 @@ function fraction_isone(x) end function needs_div_rules(x) - (isdiv(x) && !(x.impl.num isa Number) && !(x.impl.den isa Number)) || + (isdiv(x) && !(get_num(x) isa Number) && !(x.impl.den isa Number)) || (iscall(x) && operation(x) === (+) && count(has_div, arguments(x)) > 1) || (iscall(x) && any(needs_div_rules, arguments(x))) end @@ -417,11 +417,11 @@ Has optimized processes for `Mul` and `Pow` terms. """ function quick_cancel(d) if ispow(d) && isdiv(d.impl.base) - return quick_cancel((d.impl.base.impl.num^d.impl.exp) / (d.impl.base.impl.den^d.impl.exp)) + return quick_cancel((get_num(d.impl.base)^d.impl.exp) / (d.impl.base.impl.den^d.impl.exp)) elseif ismul(d) && any(isdiv, arguments(d)) return prod(arguments(d)) elseif isdiv(d) - num, den = quick_cancel(d.impl.num, d.impl.den) + num, den = quick_cancel(get_num(d), d.impl.den) return _Div(num, den) else return d diff --git a/src/types.jl b/src/types.jl index bfbc20a9c..67f83a5e6 100644 --- a/src/types.jl +++ b/src/types.jl @@ -76,6 +76,10 @@ function get_dict(x::BasicSymbolic) x.impl.dict end +function get_num(x::BasicSymbolic) + x.impl.num +end + # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") @@ -303,7 +307,7 @@ function _isequal(a, b, E) elseif E === ADD || E === MUL coeff_isequal(get_coeff(a), get_coeff(b)) && isequal(get_dict(a), get_dict(b)) elseif E === DIV - isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den) + isequal(get_num(a), get_num(b)) && isequal(a.impl.den, b.impl.den) elseif E === POW isequal(a.impl.exp, b.impl.exp) && isequal(a.impl.base, b.impl.base) elseif E === TERM @@ -349,7 +353,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt s.hash[] = h′ return h′ elseif E === DIV - return hash(s.impl.num, hash(s.impl.den, salt ⊻ DIV_SALT)) + return hash(get_num(s), hash(s.impl.den, salt ⊻ DIV_SALT)) elseif E === POW hash(s.impl.exp, hash(s.impl.base, salt ⊻ POW_SALT)) elseif E === TERM @@ -483,11 +487,11 @@ function _Div(::Type{T}, num, den; kwargs...) where {T} _iszero(num) && return zero(typeof(num)) _isone(den) && return num if isdiv(num) && isdiv(den) - return _Div(T, num.impl.num * den.impl.den, num.impl.den * den.impl.num) + return _Div(T, get_num(num) * den.impl.den, num.impl.den * get_num(den)) elseif isdiv(num) - return _Div(T, num.impl.num, num.impl.den * den) + return _Div(T, get_num(num), num.impl.den * den) elseif isdiv(den) - return _Div(T, num * den.impl.den, den.impl.num) + return _Div(T, num * den.impl.den, get_num(den)) end if den isa Number && _isone(-den) return -1 * num @@ -523,7 +527,7 @@ function _Div(num, den; kwargs...) end @inline function numerators(x) - isdiv(x) && return numerators(x.impl.num) + isdiv(x) && return numerators(get_num(x)) iscall(x) && operation(x) === (*) ? arguments(x) : Any[x] end @@ -552,7 +556,7 @@ function toterm(t::BasicSymbolic{T}) where {T} end _Term(T, operation(t), args) elseif E === DIV - _Term(T, /, [t.impl.num, t.impl.den]) + _Term(T, /, [get_num(t), t.impl.den]) elseif E === POW _Term(T, ^, [t.impl.base, t.impl.exp]) else @@ -1291,11 +1295,11 @@ function *(a::SN, b::SN) # Always make sure Div wraps Mul !issafecanon(*, a, b) && return term(*, a, b) if isdiv(a) && isdiv(b) - _Div(a.impl.num * b.impl.num, a.impl.den * b.impl.den) + _Div(get_num(a) * get_num(b), a.impl.den * b.impl.den) elseif isdiv(a) - _Div(a.impl.num * b, a.impl.den) + _Div(get_num(a) * b, a.impl.den) elseif isdiv(b) - _Div(a * b.impl.num, b.impl.den) + _Div(a * get_num(b), b.impl.den) elseif ismul(a) && ismul(b) _Mul(mul_t(a, b), get_coeff(a) * get_coeff(b), _merge(+, get_dict(a), get_dict(b), filter = _iszero)) @@ -1325,7 +1329,7 @@ function *(a::Number, b::SN) elseif isone(a) b elseif isdiv(b) - _Div(a * b.impl.num, b.impl.den) + _Div(a * get_num(b), b.impl.den) elseif isone(-a) && isadd(b) # -1(a+b) -> -a - b T = promote_symtype(+, typeof(a), symtype(b)) diff --git a/test/basics.jl b/test/basics.jl index d6fbfc5de..40b84ffa6 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -1,5 +1,5 @@ using SymbolicUtils: Symbolic, FnType, symtype, operation, arguments, issym, isterm, - BasicSymbolic, term, get_name, get_coeff, get_dict + BasicSymbolic, term, get_name, get_coeff, get_dict, get_num using SymbolicUtils using IfElse: ifelse using Setfield @@ -346,21 +346,21 @@ end @testset "div" begin @syms x::SafeReal y::Real - @test issym((2x / 2y).impl.num) - @test get_coeff((2x / 3y).impl.num) == 2 + @test issym(get_num(2x / 2y)) + @test get_coeff(get_num(2x / 3y)) == 2 @test get_coeff((2x / 3y).impl.den) == 3 - @test get_coeff((2x / -3x).impl.num) == -2 + @test get_coeff(get_num(2x / -3x)) == -2 @test get_coeff((2x / -3x).impl.den) == 3 - @test get_coeff((2.5x / 3x).impl.num) == 2.5 + @test get_coeff(get_num(2.5x / 3x)) == 2.5 @test get_coeff((2.5x / 3x).impl.den) == 3 @test get_coeff((x / 3x).impl.den) == 3 @syms x y - @test issym((2x / 2y).impl.num) - @test get_coeff((2x / 3y).impl.num) == 2 + @test issym(get_num(2x / 2y)) + @test get_coeff(get_num(2x / 3y)) == 2 @test get_coeff((2x / 3y).impl.den) == 3 @test (2x / -3x) == -2 // 3 - @test (2.5x / 3x).impl.num == 2.5 + @test get_num(2.5x / 3x) == 2.5 @test (2.5x / 3x).impl.den == 3 @test (x / 3x) == 1 // 3 @test isequal(x / 1, x) From f044cecf4aefe76970703a3e45f013e0afb4dc6e Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 27 Sep 2024 12:35:12 -0400 Subject: [PATCH 136/140] Define `get_den` --- src/polyform.jl | 12 ++++++------ src/types.jl | 26 +++++++++++++++----------- test/basics.jl | 14 +++++++------- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/src/polyform.jl b/src/polyform.jl index d5d12e5b7..34817b1a2 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -296,11 +296,11 @@ end #add_divs(x, y) = x + y function add_divs(x, y) if isdiv(x) && isdiv(y) - return (get_num(x) * y.impl.den + get_num(y) * x.impl.den) / (x.impl.den * y.impl.den) + return (get_num(x) * get_den(y) + get_num(y) * get_den(x)) / (get_den(x) * get_den(y)) elseif isdiv(x) - return (get_num(x) + y * x.impl.den) / x.impl.den + return (get_num(x) + y * get_den(x)) / get_den(x) elseif isdiv(y) - return (x * y.impl.den + get_num(y)) / y.impl.den + return (x * get_den(y) + get_num(y)) / get_den(y) else x + y end @@ -384,7 +384,7 @@ function fraction_isone(x) end function needs_div_rules(x) - (isdiv(x) && !(get_num(x) isa Number) && !(x.impl.den isa Number)) || + (isdiv(x) && !(get_num(x) isa Number) && !(get_den(x) isa Number)) || (iscall(x) && operation(x) === (+) && count(has_div, arguments(x)) > 1) || (iscall(x) && any(needs_div_rules, arguments(x))) end @@ -417,11 +417,11 @@ Has optimized processes for `Mul` and `Pow` terms. """ function quick_cancel(d) if ispow(d) && isdiv(d.impl.base) - return quick_cancel((get_num(d.impl.base)^d.impl.exp) / (d.impl.base.impl.den^d.impl.exp)) + return quick_cancel((get_num(d.impl.base)^d.impl.exp) / (get_den(d.impl.base)^d.impl.exp)) elseif ismul(d) && any(isdiv, arguments(d)) return prod(arguments(d)) elseif isdiv(d) - num, den = quick_cancel(get_num(d), d.impl.den) + num, den = quick_cancel(get_num(d), get_den(d)) return _Div(num, den) else return d diff --git a/src/types.jl b/src/types.jl index 67f83a5e6..bb903e7ca 100644 --- a/src/types.jl +++ b/src/types.jl @@ -80,6 +80,10 @@ function get_num(x::BasicSymbolic) x.impl.num end +function get_den(x::BasicSymbolic) + x.impl.den +end + # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") @@ -307,7 +311,7 @@ function _isequal(a, b, E) elseif E === ADD || E === MUL coeff_isequal(get_coeff(a), get_coeff(b)) && isequal(get_dict(a), get_dict(b)) elseif E === DIV - isequal(get_num(a), get_num(b)) && isequal(a.impl.den, b.impl.den) + isequal(get_num(a), get_num(b)) && isequal(get_den(a), get_den(b)) elseif E === POW isequal(a.impl.exp, b.impl.exp) && isequal(a.impl.base, b.impl.base) elseif E === TERM @@ -353,7 +357,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt s.hash[] = h′ return h′ elseif E === DIV - return hash(get_num(s), hash(s.impl.den, salt ⊻ DIV_SALT)) + return hash(get_num(s), hash(get_den(s), salt ⊻ DIV_SALT)) elseif E === POW hash(s.impl.exp, hash(s.impl.base, salt ⊻ POW_SALT)) elseif E === TERM @@ -487,11 +491,11 @@ function _Div(::Type{T}, num, den; kwargs...) where {T} _iszero(num) && return zero(typeof(num)) _isone(den) && return num if isdiv(num) && isdiv(den) - return _Div(T, get_num(num) * den.impl.den, num.impl.den * get_num(den)) + return _Div(T, get_num(num) * get_den(den), get_den(num) * get_num(den)) elseif isdiv(num) - return _Div(T, get_num(num), num.impl.den * den) + return _Div(T, get_num(num), get_den(num) * den) elseif isdiv(den) - return _Div(T, num * den.impl.den, get_num(den)) + return _Div(T, num * get_den(den), get_num(den)) end if den isa Number && _isone(-den) return -1 * num @@ -531,7 +535,7 @@ end iscall(x) && operation(x) === (*) ? arguments(x) : Any[x] end -@inline denominators(x) = isdiv(x) ? numerators(x.impl.den) : Any[1] +@inline denominators(x) = isdiv(x) ? numerators(get_den(x)) : Any[1] function _Pow(::Type{T}, base, exp; kwargs...) where {T} _iszero(exp) && return 1 @@ -556,7 +560,7 @@ function toterm(t::BasicSymbolic{T}) where {T} end _Term(T, operation(t), args) elseif E === DIV - _Term(T, /, [get_num(t), t.impl.den]) + _Term(T, /, [get_num(t), get_den(t)]) elseif E === POW _Term(T, ^, [t.impl.base, t.impl.exp]) else @@ -1295,11 +1299,11 @@ function *(a::SN, b::SN) # Always make sure Div wraps Mul !issafecanon(*, a, b) && return term(*, a, b) if isdiv(a) && isdiv(b) - _Div(get_num(a) * get_num(b), a.impl.den * b.impl.den) + _Div(get_num(a) * get_num(b), get_den(a) * get_den(b)) elseif isdiv(a) - _Div(get_num(a) * b, a.impl.den) + _Div(get_num(a) * b, get_den(a)) elseif isdiv(b) - _Div(a * get_num(b), b.impl.den) + _Div(a * get_num(b), get_den(b)) elseif ismul(a) && ismul(b) _Mul(mul_t(a, b), get_coeff(a) * get_coeff(b), _merge(+, get_dict(a), get_dict(b), filter = _iszero)) @@ -1329,7 +1333,7 @@ function *(a::Number, b::SN) elseif isone(a) b elseif isdiv(b) - _Div(a * get_num(b), b.impl.den) + _Div(a * get_num(b), get_den(b)) elseif isone(-a) && isadd(b) # -1(a+b) -> -a - b T = promote_symtype(+, typeof(a), symtype(b)) diff --git a/test/basics.jl b/test/basics.jl index 40b84ffa6..2161f1648 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -1,5 +1,5 @@ using SymbolicUtils: Symbolic, FnType, symtype, operation, arguments, issym, isterm, - BasicSymbolic, term, get_name, get_coeff, get_dict, get_num + BasicSymbolic, term, get_name, get_coeff, get_dict, get_num, get_den using SymbolicUtils using IfElse: ifelse using Setfield @@ -348,20 +348,20 @@ end @syms x::SafeReal y::Real @test issym(get_num(2x / 2y)) @test get_coeff(get_num(2x / 3y)) == 2 - @test get_coeff((2x / 3y).impl.den) == 3 + @test get_coeff(get_den(2x / 3y)) == 3 @test get_coeff(get_num(2x / -3x)) == -2 - @test get_coeff((2x / -3x).impl.den) == 3 + @test get_coeff(get_den(2x / -3x)) == 3 @test get_coeff(get_num(2.5x / 3x)) == 2.5 - @test get_coeff((2.5x / 3x).impl.den) == 3 - @test get_coeff((x / 3x).impl.den) == 3 + @test get_coeff(get_den(2.5x / 3x)) == 3 + @test get_coeff(get_den(x / 3x)) == 3 @syms x y @test issym(get_num(2x / 2y)) @test get_coeff(get_num(2x / 3y)) == 2 - @test get_coeff((2x / 3y).impl.den) == 3 + @test get_coeff(get_den(2x / 3y)) == 3 @test (2x / -3x) == -2 // 3 @test get_num(2.5x / 3x) == 2.5 - @test (2.5x / 3x).impl.den == 3 + @test get_den(2.5x / 3x) == 3 @test (x / 3x) == 1 // 3 @test isequal(x / 1, x) @test isequal(x / -1, -x) From 8e733138de10a576e86ea7ad87f8e09047119632 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 27 Sep 2024 13:03:25 -0400 Subject: [PATCH 137/140] Define `get_base` --- src/polyform.jl | 20 ++++++++++---------- src/types.jl | 16 ++++++++++------ 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/polyform.jl b/src/polyform.jl index 34817b1a2..cf34a189a 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -268,7 +268,7 @@ function polyform_factors(d, pvar2sym, sym2term) if ispow(x) && x.impl.exp isa Integer && x.impl.exp > 0 # here we do want to recurse one level, that's why it's wrong to just # use Fs = Union{typeof(+), typeof(*)} here. - _Pow(PolyForm(x.impl.base, pvar2sym, sym2term), x.impl.exp) + _Pow(PolyForm(get_base(x), pvar2sym, sym2term), x.impl.exp) else PolyForm(x, pvar2sym, sym2term) end @@ -416,8 +416,8 @@ But it will simplify `(x - 5)^2*(x - 3) / (x - 5)` to `(x - 5)*(x - 3)`. Has optimized processes for `Mul` and `Pow` terms. """ function quick_cancel(d) - if ispow(d) && isdiv(d.impl.base) - return quick_cancel((get_num(d.impl.base)^d.impl.exp) / (get_den(d.impl.base)^d.impl.exp)) + if ispow(d) && isdiv(get_base(d)) + return quick_cancel((get_num(get_base(d))^d.impl.exp) / (get_den(get_base(d))^d.impl.exp)) elseif ismul(d) && any(isdiv, arguments(d)) return prod(arguments(d)) elseif isdiv(d) @@ -502,17 +502,17 @@ end # mul, pow case function quick_mulpow(x, y) y.impl.exp isa Number || return (x, y) - if haskey(get_dict(x), y.impl.base) + if haskey(get_dict(x), get_base(y)) d = copy(get_dict(x)) - if get_dict(x)[y.impl.base] > y.impl.exp - d[y.impl.base] -= y.impl.exp + if get_dict(x)[get_base(y)] > y.impl.exp + d[get_base(y)] -= y.impl.exp den = 1 - elseif get_dict(x)[y.impl.base] == y.impl.exp - delete!(d, y.impl.base) + elseif get_dict(x)[get_base(y)] == y.impl.exp + delete!(d, get_base(y)) den = 1 else - den = _Pow(symtype(y), y.impl.base, y.impl.exp-d[y.impl.base]) - delete!(d, y.impl.base) + den = _Pow(symtype(y), get_base(y), y.impl.exp-d[get_base(y)]) + delete!(d, get_base(y)) end return _Mul(symtype(x), get_coeff(x), d), den else diff --git a/src/types.jl b/src/types.jl index bb903e7ca..d24005edb 100644 --- a/src/types.jl +++ b/src/types.jl @@ -84,6 +84,10 @@ function get_den(x::BasicSymbolic) x.impl.den end +function get_base(x::BasicSymbolic) + x.impl.base +end + # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") @@ -313,7 +317,7 @@ function _isequal(a, b, E) elseif E === DIV isequal(get_num(a), get_num(b)) && isequal(get_den(a), get_den(b)) elseif E === POW - isequal(a.impl.exp, b.impl.exp) && isequal(a.impl.base, b.impl.base) + isequal(a.impl.exp, b.impl.exp) && isequal(get_base(a), get_base(b)) elseif E === TERM a1 = arguments(a) a2 = arguments(b) @@ -359,7 +363,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt elseif E === DIV return hash(get_num(s), hash(get_den(s), salt ⊻ DIV_SALT)) elseif E === POW - hash(s.impl.exp, hash(s.impl.base, salt ⊻ POW_SALT)) + hash(s.impl.exp, hash(get_base(s), salt ⊻ POW_SALT)) elseif E === TERM !iszero(salt) && return hash(hash(s, zero(UInt)), salt) h = s.hash[] @@ -562,7 +566,7 @@ function toterm(t::BasicSymbolic{T}) where {T} elseif E === DIV _Term(T, /, [get_num(t), get_den(t)]) elseif E === POW - _Term(T, ^, [t.impl.base, t.impl.exp]) + _Term(T, ^, [get_base(t), t.impl.exp]) else error_on_type() end @@ -605,7 +609,7 @@ end function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}()) for x in xs if ispow(x) && x.impl.exp isa Number - d[x.impl.base] = x.impl.exp + get(d, x.impl.base, 0) + d[get_base(x)] = x.impl.exp + get(d, get_base(x), 0) elseif x isa Number coeff *= x elseif ismul(x) @@ -629,7 +633,7 @@ function makepow(a, b) base = a exp = b if ispow(a) - base = a.impl.base + base = get_base(a) exp = a.impl.exp * b end base, exp @@ -1311,7 +1315,7 @@ function *(a::SN, b::SN) if b.impl.exp isa Number _Mul(mul_t(a, b), get_coeff(a), - _merge(+, get_dict(a), Base.ImmutableDict(b.impl.base => b.impl.exp), + _merge(+, get_dict(a), Base.ImmutableDict(get_base(b) => b.impl.exp), filter = _iszero)) else _Mul(mul_t(a, b), get_coeff(a), From f7a77d6c6832e3bf61fc7224c3e4164514324d77 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 27 Sep 2024 13:07:26 -0400 Subject: [PATCH 138/140] Define `get_exp` --- src/polyform.jl | 16 ++++++++-------- src/types.jl | 20 ++++++++++++-------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/polyform.jl b/src/polyform.jl index cf34a189a..04e035781 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -265,10 +265,10 @@ end function polyform_factors(d, pvar2sym, sym2term) make(xs) = map(xs) do x - if ispow(x) && x.impl.exp isa Integer && x.impl.exp > 0 + if ispow(x) && get_exp(x) isa Integer && get_exp(x) > 0 # here we do want to recurse one level, that's why it's wrong to just # use Fs = Union{typeof(+), typeof(*)} here. - _Pow(PolyForm(get_base(x), pvar2sym, sym2term), x.impl.exp) + _Pow(PolyForm(get_base(x), pvar2sym, sym2term), get_exp(x)) else PolyForm(x, pvar2sym, sym2term) end @@ -417,7 +417,7 @@ Has optimized processes for `Mul` and `Pow` terms. """ function quick_cancel(d) if ispow(d) && isdiv(get_base(d)) - return quick_cancel((get_num(get_base(d))^d.impl.exp) / (get_den(get_base(d))^d.impl.exp)) + return quick_cancel((get_num(get_base(d))^get_exp(d)) / (get_den(get_base(d))^get_exp(d))) elseif ismul(d) && any(isdiv, arguments(d)) return prod(arguments(d)) elseif isdiv(d) @@ -501,17 +501,17 @@ end # mul, pow case function quick_mulpow(x, y) - y.impl.exp isa Number || return (x, y) + get_exp(y) isa Number || return (x, y) if haskey(get_dict(x), get_base(y)) d = copy(get_dict(x)) - if get_dict(x)[get_base(y)] > y.impl.exp - d[get_base(y)] -= y.impl.exp + if get_dict(x)[get_base(y)] > get_exp(y) + d[get_base(y)] -= get_exp(y) den = 1 - elseif get_dict(x)[get_base(y)] == y.impl.exp + elseif get_dict(x)[get_base(y)] == get_exp(y) delete!(d, get_base(y)) den = 1 else - den = _Pow(symtype(y), get_base(y), y.impl.exp-d[get_base(y)]) + den = _Pow(symtype(y), get_base(y), get_exp(y)-d[get_base(y)]) delete!(d, get_base(y)) end return _Mul(symtype(x), get_coeff(x), d), den diff --git a/src/types.jl b/src/types.jl index d24005edb..30e8f4e65 100644 --- a/src/types.jl +++ b/src/types.jl @@ -88,6 +88,10 @@ function get_base(x::BasicSymbolic) x.impl.base end +function get_exp(x::BasicSymbolic) + x.impl.exp +end + # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") @@ -317,7 +321,7 @@ function _isequal(a, b, E) elseif E === DIV isequal(get_num(a), get_num(b)) && isequal(get_den(a), get_den(b)) elseif E === POW - isequal(a.impl.exp, b.impl.exp) && isequal(get_base(a), get_base(b)) + isequal(get_exp(a), get_exp(b)) && isequal(get_base(a), get_base(b)) elseif E === TERM a1 = arguments(a) a2 = arguments(b) @@ -363,7 +367,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt elseif E === DIV return hash(get_num(s), hash(get_den(s), salt ⊻ DIV_SALT)) elseif E === POW - hash(s.impl.exp, hash(get_base(s), salt ⊻ POW_SALT)) + hash(get_exp(s), hash(get_base(s), salt ⊻ POW_SALT)) elseif E === TERM !iszero(salt) && return hash(hash(s, zero(UInt)), salt) h = s.hash[] @@ -566,7 +570,7 @@ function toterm(t::BasicSymbolic{T}) where {T} elseif E === DIV _Term(T, /, [get_num(t), get_den(t)]) elseif E === POW - _Term(T, ^, [get_base(t), t.impl.exp]) + _Term(T, ^, [get_base(t), get_exp(t)]) else error_on_type() end @@ -608,8 +612,8 @@ end function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}()) for x in xs - if ispow(x) && x.impl.exp isa Number - d[get_base(x)] = x.impl.exp + get(d, get_base(x), 0) + if ispow(x) && get_exp(x) isa Number + d[get_base(x)] = get_exp(x) + get(d, get_base(x), 0) elseif x isa Number coeff *= x elseif ismul(x) @@ -634,7 +638,7 @@ function makepow(a, b) exp = b if ispow(a) base = get_base(a) - exp = a.impl.exp * b + exp = get_exp(a) * b end base, exp end @@ -1312,10 +1316,10 @@ function *(a::SN, b::SN) _Mul(mul_t(a, b), get_coeff(a) * get_coeff(b), _merge(+, get_dict(a), get_dict(b), filter = _iszero)) elseif ismul(a) && ispow(b) - if b.impl.exp isa Number + if get_exp(b) isa Number _Mul(mul_t(a, b), get_coeff(a), - _merge(+, get_dict(a), Base.ImmutableDict(get_base(b) => b.impl.exp), + _merge(+, get_dict(a), Base.ImmutableDict(get_base(b) => get_exp(b)), filter = _iszero)) else _Mul(mul_t(a, b), get_coeff(a), From b0c5de1ad33cb6e131bd1afaf4c027638adb7571 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 27 Sep 2024 13:32:13 -0400 Subject: [PATCH 139/140] Define `get_val` --- src/code.jl | 4 ++-- src/inspect.jl | 2 +- src/matchers.jl | 4 ++-- src/methods.jl | 2 +- src/ordering.jl | 6 +++--- src/polyform.jl | 2 +- src/substitute.jl | 4 ++-- src/types.jl | 38 +++++++++++++++++++++----------------- src/utils.jl | 2 +- test/types.jl | 8 ++++---- 10 files changed, 38 insertions(+), 34 deletions(-) diff --git a/src/code.jl b/src/code.jl index 113a6dcb2..b9c8ec9d1 100644 --- a/src/code.jl +++ b/src/code.jl @@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, import ..SymbolicUtils import ..SymbolicUtils.Rewriters import SymbolicUtils: @matchable, BasicSymbolic, _Sym, Term, iscall, operation, arguments, issym, - isconst, symtype, sorted_arguments, metadata, isterm, term, maketerm + isconst, symtype, sorted_arguments, metadata, isterm, term, maketerm, get_val import SymbolicIndexingInterface: symbolic_type, NotSymbolic ##== state management ==## @@ -183,7 +183,7 @@ function toexpr(O, st) O = substitute_name(O, st) return issym(O) ? nameof(O) : toexpr(O, st) elseif isconst(O) - return toexpr(O.impl.val, st) + return toexpr(get_val(O), st) end O = substitute_name(O, st) diff --git a/src/inspect.jl b/src/inspect.jl index 8dc7eaf6c..715082496 100644 --- a/src/inspect.jl +++ b/src/inspect.jl @@ -9,7 +9,7 @@ function AbstractTrees.nodevalue(x::BasicSymbolic) str = if issym(x) string(exprtype(x), "(", x, ")") elseif isconst(x) - string(x.impl.val) + string(get_val(x)) elseif isadd(x) string(exprtype(x), (scalar = get_coeff(x), coeffs = Tuple(k => v for (k, v) in get_dict(x)))) diff --git a/src/matchers.jl b/src/matchers.jl index 5eb1df4d4..03d7b33f3 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -7,7 +7,7 @@ # function matcher(val::Any) if isconst(val) - slot = val.impl.val + slot = get_val(val) return matcher(slot) elseif iscall(val) return term_matcher(val) @@ -16,7 +16,7 @@ function matcher(val::Any) if islist(data) cd = car(data) if isconst(cd) - cd = cd.impl.val + cd = get_val(cd) end if isequal(cd, val) return next(bindings, 1) diff --git a/src/methods.jl b/src/methods.jl index 7e5850a4c..f6c90ae52 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -190,7 +190,7 @@ for f in [!, ~] promote_symtype(::$(typeof(f)), ::Type{<:Bool}) = Bool function (::$(typeof(f)))(s::Symbolic{Bool}) if isconst(s) - s = s.impl.val + s = get_val(s) return !s end _Term(Bool, !, [s]) diff --git a/src/ordering.jl b/src/ordering.jl index 56661c816..623b0cecd 100644 --- a/src/ordering.jl +++ b/src/ordering.jl @@ -27,7 +27,7 @@ function get_degrees(expr) elseif iscall(expr) op = operation(expr) args = sorted_arguments(expr) - if op == (^) && (args[2] isa Number || (isconst(args[2]) && args[2].impl.val isa Number)) + if op == (^) && (args[2] isa Number || (isconst(args[2]) && get_val(args[2]) isa Number)) return map(get_degrees(args[1])) do (base, pow) (base => pow * args[2]) end @@ -81,11 +81,11 @@ end function <ₑ(a::BasicSymbolic, b::BasicSymbolic) aisconst = isconst(a) if aisconst - a = a.impl.val + a = get_val(a) end bisconst = isconst(b) if bisconst - b = b.impl.val + b = get_val(b) end if aisconst || bisconst return a <ₑ b diff --git a/src/polyform.jl b/src/polyform.jl index 04e035781..b18d6bd47 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -96,7 +96,7 @@ _isone(p::PolyForm) = isone(p.p) function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) if isconst(x) - x = x.impl.val + x = get_val(x) end if x isa Number return x diff --git a/src/substitute.jl b/src/substitute.jl index 025c9dbcc..3df62cf37 100644 --- a/src/substitute.jl +++ b/src/substitute.jl @@ -23,7 +23,7 @@ function substitute(expr, dict; fold=true) args = map(arguments(expr)) do x x′ = substitute(x, dict; fold=fold) if isconst(x) - x′ = x′.impl.val + x′ = get_val(x′) end canfold = canfold && !(x′ isa Symbolic) x′ @@ -58,7 +58,7 @@ function _occursin(needle, haystack) args = arguments(haystack) for arg in args if isconst(arg) - arg = arg.impl.val + arg = get_val(arg) end if needle isa Integer || needle isa AbstractFloat isequal(needle, arg) && return true diff --git a/src/types.jl b/src/types.jl index 30e8f4e65..779a259a4 100644 --- a/src/types.jl +++ b/src/types.jl @@ -92,6 +92,10 @@ function get_exp(x::BasicSymbolic) x.impl.exp end +function get_val(x::BasicSymbolic) + x.impl.val +end + # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") @@ -327,7 +331,7 @@ function _isequal(a, b, E) a2 = arguments(b) isequal(operation(a), operation(b)) && _allarequal(a1, a2) elseif E === CONST - isequal(a.impl.val, b.impl.val) + isequal(get_val(a), get_val(b)) else error_on_type() end @@ -378,7 +382,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt s.hash[] = h′ return h′ elseif E === CONST - return hash(s.impl.val, salt ⊻ COS_SALT) + return hash(get_val(s), salt ⊻ COS_SALT) else error_on_type() end @@ -452,14 +456,14 @@ end function _iszero(x::BasicSymbolic) @match x.impl begin - Const(_...) => iszero(x.impl.val) + Const(_...) => iszero(get_val(x)) _ => false end end function _isone(x::BasicSymbolic) @match x.impl begin - Const(_...) => isone(x.impl.val) + Const(_...) => isone(get_val(x)) _ => false end end @@ -833,7 +837,7 @@ const show_simplified = Ref(false) isnegative(t::Real) = t < 0 function isnegative(t) if isconst(t) - val = t.impl.val + val = get_val(t) return isnegative(val) end if iscall(t) && operation(t) === (*) @@ -872,7 +876,7 @@ function remove_minus(t) args = arguments(t) arg1 = args[1] if isconst(arg1) - arg1 = arg1.impl.val + arg1 = get_val(arg1) end @assert arg1 < 0 Any[-arg1, args[2:end]...] @@ -911,14 +915,14 @@ end function show_mul(io, args) if isconst(args) - print(io, args.impl.val) + print(io, get_val(args)) return end length(args) == 1 && return print_arg(io, *, args[1]) arg1 = args[1] if isconst(arg1) - arg1 = arg1.impl.val + arg1 = get_val(arg1) end minus = arg1 isa Number && arg1 == -1 @@ -930,7 +934,7 @@ function show_mul(io, args) nostar = minus || unit || (!paren_scalar && arg1 isa Number && - !(isconst(args[2]) && args[2].impl.val isa Number)) + !(isconst(args[2]) && get_val(args[2]) isa Number)) for (i, t) in enumerate(args) if i != 1 @@ -1021,7 +1025,7 @@ showraw(t) = showraw(stdout, t) function Base.show(io::IO, v::BasicSymbolic) @match v.impl begin Sym(_...) => Base.show_unquoted(io, get_name(v)) - Const(_...) => print(io, v.impl.val) + Const(_...) => print(io, get_val(v)) _ => show_term(io, v) end end @@ -1235,10 +1239,10 @@ sub_t(a) = promote_symtype(-, symtype(a)) import Base: (+), (-), (*), (//), (/), (\), (^) function +(a::SN, b::SN) if isconst(a) - return a.impl.val + b + return get_val(a) + b end if isconst(b) - return b.impl.val + a + return get_val(b) + a end !issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata if isadd(a) && isadd(b) @@ -1255,7 +1259,7 @@ function +(a::SN, b::SN) end function +(a::Number, b::SN) if isconst(b) - return a + b.impl.val + return a + get_val(b) end !issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata iszero(a) && return b @@ -1270,7 +1274,7 @@ end function -(a::SN) if isconst(a) - v = a.impl.val + v = get_val(a) mv = -v return _Const(mv) end @@ -1299,10 +1303,10 @@ mul_t(a) = promote_symtype(*, symtype(a)) function *(a::SN, b::SN) if isconst(a) - return a.impl.val * b + return get_val(a) * b end if isconst(b) - return b.impl.val * a + return get_val(b) * a end # Always make sure Div wraps Mul !issafecanon(*, a, b) && return term(*, a, b) @@ -1333,7 +1337,7 @@ function *(a::SN, b::SN) end function *(a::Number, b::SN) if isconst(b) - return a * b.impl.val + return a * get_val(b) end !issafecanon(*, b) && return term(*, a, b) if iszero(a) diff --git a/src/utils.jl b/src/utils.jl index 4508acad2..3ced418da 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -66,7 +66,7 @@ sym_isa(::Type{T}) where {T} = @nospecialize(x) -> x isa T || symtype(x) <: T function is_literal_number(x) if isconst(x) - x = x.impl.val + x = get_val(x) end x isa Number end diff --git a/test/types.jl b/test/types.jl index b2d7c345e..6bcd510ba 100644 --- a/test/types.jl +++ b/test/types.jl @@ -1,4 +1,4 @@ -using SymbolicUtils: Symbolic, BasicSymbolic, _Sym, _Term, _Const, _Add, get_name +using SymbolicUtils: Symbolic, BasicSymbolic, _Sym, _Term, _Const, _Add, get_name, get_val @testset "Expronicon generated constructors" begin s1 = Sym(:abc) @@ -133,17 +133,17 @@ end @test typeof(c1) == BasicSymbolic{Float64} @test c1.metadata == SymbolicUtils.NO_METADATA @test c1.hash[] == SymbolicUtils.EMPTY_HASH - @test c1.impl.val == 1.0 + @test get_val(c1) == 1.0 c2 = _Const(big"123456789012345678901234567890") @test typeof(c2) == BasicSymbolic{BigInt} @test c2.metadata == SymbolicUtils.NO_METADATA @test c2.hash[] == SymbolicUtils.EMPTY_HASH - @test c2.impl.val == big"123456789012345678901234567890" + @test get_val(c2) == big"123456789012345678901234567890" c3 = _Const(big"1.23456789012345678901") @test typeof(c3) == BasicSymbolic{BigFloat} @test c3.metadata == SymbolicUtils.NO_METADATA @test c3.hash[] == SymbolicUtils.EMPTY_HASH - @test c3.impl.val == big"1.23456789012345678901" + @test get_val(c3) == big"1.23456789012345678901" end end From 73805e226a1dd9e7527360358106300cc297b7fd Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Fri, 27 Sep 2024 13:39:09 -0400 Subject: [PATCH 140/140] Fix `eltype` in doctest --- docs/src/manual/rewrite.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/manual/rewrite.md b/docs/src/manual/rewrite.md index 047bef71a..feac062ec 100644 --- a/docs/src/manual/rewrite.md +++ b/docs/src/manual/rewrite.md @@ -71,7 +71,7 @@ If you want to match a variable number of subexpressions at once, you will need @rule(+(~~xs) => ~~xs)(x + y + z) # output -3-element view(::Vector{Any}, 1:3) with eltype Any: +3-element view(::Vector{SymbolicUtils.BasicSymbolic}, 1:3) with eltype SymbolicUtils.BasicSymbolic: z y x