Skip to content

Refactor BasicSymbolic struct to separate metadata from hash consing #704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 26 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
2ab70b3
Create `BasicSymbolicImpl` struct to separate metadata from hash consing
bowenszhu Feb 13, 2025
17408e8
Rename `BasicSymbolicImpl` field from `impl` to `expr`
bowenszhu Feb 13, 2025
295a1df
Adapt `BasicSymbolic` hash consing constructor
bowenszhu Feb 13, 2025
1fe09e5
Adapt `setproperties` with new `BasicSymbolic` struct
bowenszhu Feb 13, 2025
b4932e6
Call custom constructor of `Mul` in `maybe_intcoeff`
bowenszhu Feb 13, 2025
9da1a9e
Fix: `hash2` of `BasicSymbolic` equals its `BasicSymbolicImpl`
bowenszhu Feb 13, 2025
56c990c
Operate CSE `topological_sort` `dfs` on `BasicSymbolicImpl`
bowenszhu Feb 13, 2025
4fadeab
Add `isexpr` & `iscall` methods for `BasicSymbolicImpl`
bowenszhu Feb 13, 2025
d7737dd
Remove `metadata` from `hash2`
bowenszhu Feb 13, 2025
65cc88f
Change `Base.isequal` for `BasicSymbolicImpl`
bowenszhu Feb 14, 2025
df49641
Modify flyweight factory for `BasicSymbolicImpl`
bowenszhu Feb 14, 2025
db30dc8
Create `MetadataImpl` struct to keep track of metadata tree
bowenszhu Feb 17, 2025
580ad33
Modify `getproperty(::BasicSymbolic)` for metadata
bowenszhu Feb 17, 2025
831f8fc
Add `isequal` method for `MetadataImpl`
bowenszhu Feb 17, 2025
29a6eb4
Modify `isequal_with_metadata` with new `BasicSymbolic` structure
bowenszhu Feb 17, 2025
f8d0e78
Add `getmetadata` methods bc `metadata` kwarg takes outer-scope function
bowenszhu Feb 17, 2025
e9f4ff8
Modify `BasicSymbolic` constructors with new struct structure
bowenszhu Feb 17, 2025
5293e98
Add `metadata_children` function for accessing metadata tree
bowenszhu Feb 17, 2025
4373146
Modify hash consing tests with new `BasicSymbolic` struct
bowenszhu Feb 17, 2025
4a99619
Modify rewrite metadata tests
bowenszhu Feb 17, 2025
440c17b
Refactor `-(::SN, ::SN)` for easier debugging
bowenszhu Feb 20, 2025
4d00726
Make `BasicSymbolicImpl` children `BasicSymbolicImpl`
bowenszhu Feb 20, 2025
2fcac04
Make `getproperty` return `BasicSymbolic` if applicable
bowenszhu Feb 20, 2025
d03e0ef
`arguments` wraps `BasicSymbolicImpl` and `MetadataImpl`
bowenszhu Feb 20, 2025
2eae851
Revert "Add `metadata_children` function for accessing metadata tree"
bowenszhu Feb 20, 2025
cc5e273
Revert "Modify rewrite metadata tests"
bowenszhu Feb 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,9 @@ function topological_sort(graph)
visited = IdDict()

function dfs(node)
if node isa BasicSymbolic
node = node.expr
end
if haskey(visited, node)
return visited[node]
end
Expand Down
164 changes: 119 additions & 45 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,53 +24,64 @@ const EMPTY_DICT_T = typeof(EMPTY_DICT)
const ENABLE_HASHCONSING = Ref(true)

@compactify show_methods=false begin
@abstract mutable struct BasicSymbolic{T} <: Symbolic{T}
metadata::Metadata = NO_METADATA
end
mutable struct Sym{T} <: BasicSymbolic{T}
@abstract mutable struct BasicSymbolicImpl{T} end
mutable struct Sym{T} <: BasicSymbolicImpl{T}
name::Symbol = :OOF
end
mutable struct Term{T} <: BasicSymbolic{T}
mutable struct Term{T} <: BasicSymbolicImpl{T}
f::Any = identity # base/num if Pow; issorted if Add/Dict
arguments::Vector{Any} = EMPTY_ARGS
hash::RefValue{UInt} = EMPTY_HASH
hash2::RefValue{UInt} = EMPTY_HASH
end
mutable struct Mul{T} <: BasicSymbolic{T}
mutable struct Mul{T} <: BasicSymbolicImpl{T}
coeff::Any = 0 # exp/den if Pow
dict::EMPTY_DICT_T = EMPTY_DICT
hash::RefValue{UInt} = EMPTY_HASH
hash2::RefValue{UInt} = EMPTY_HASH
arguments::Vector{Any} = EMPTY_ARGS
issorted::RefValue{Bool} = NOT_SORTED
end
mutable struct Add{T} <: BasicSymbolic{T}
mutable struct Add{T} <: BasicSymbolicImpl{T}
coeff::Any = 0 # exp/den if Pow
dict::EMPTY_DICT_T = EMPTY_DICT
hash::RefValue{UInt} = EMPTY_HASH
hash2::RefValue{UInt} = EMPTY_HASH
arguments::Vector{Any} = EMPTY_ARGS
issorted::RefValue{Bool} = NOT_SORTED
end
mutable struct Div{T} <: BasicSymbolic{T}
mutable struct Div{T} <: BasicSymbolicImpl{T}
num::Any = 1
den::Any = 1
simplified::Bool = false
arguments::Vector{Any} = EMPTY_ARGS
end
mutable struct Pow{T} <: BasicSymbolic{T}
mutable struct Pow{T} <: BasicSymbolicImpl{T}
base::Any = 1
exp::Any = 1
arguments::Vector{Any} = EMPTY_ARGS
end
end

struct MetadataImpl
this::Metadata
children::Vector{Any}
end

@kwdef struct BasicSymbolic{T} <: Symbolic{T}
expr::BasicSymbolicImpl{T}
meta::MetadataImpl
end

function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic})
ScalarSymbolic()
end

function exprtype(x::BasicSymbolic)
@compactified x::BasicSymbolic begin
exprtype(x.expr)
end
function exprtype(expr::BasicSymbolicImpl)
@compactified expr::BasicSymbolicImpl begin
Term => TERM
Add => ADD
Mul => MUL
Expand All @@ -81,7 +92,17 @@ function exprtype(x::BasicSymbolic)
end
end

const wvd = WeakValueDict{UInt, BasicSymbolic}()
function Base.getproperty(x::BasicSymbolic, sym::Symbol)
if sym === :metadata
return getfield(x, :meta).this
elseif sym === :expr || sym === :meta
return getfield(x, sym)
else
return getproperty(x.expr, sym)
end
end

const wvd = WeakValueDict{UInt, BasicSymbolicImpl}()

# Same but different error messages
@noinline error_on_type() = error("Internal error: unreachable reached!")
Expand All @@ -96,10 +117,11 @@ const SIMPLIFIED = 0x01 << 0
#@inline issimplified(x::BasicSymbolic) = is_of_type(x, SIMPLIFIED)

function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple)::BasicSymbolic{T} where T
nt = getproperties(obj)
nt_new = merge(nt, patch)
expr = obj.expr
nt = getproperties(expr)
nt_new = merge(nt, (metadata = obj.metadata,), patch)
# Call outer constructor because hash consing cannot be applied in inner constructor
@compactified obj::BasicSymbolic begin
@compactified expr::BasicSymbolicImpl begin
Sym => Sym{T}(nt_new.name; nt_new...)
Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
Expand Down Expand Up @@ -128,9 +150,12 @@ symtype(x) = typeof(x)
@inline symtype(::Type{<:Symbolic{T}}) where T = T

# We're returning a function pointer
@inline function operation(x::BasicSymbolic)
@compactified x::BasicSymbolic begin
Term => x.f
function operation(x::BasicSymbolic)
operation(x.expr)
end
@inline function operation(expr::BasicSymbolicImpl)
@compactified expr::BasicSymbolicImpl begin
Term => expr.f
Add => (+)
Mul => (*)
Div => (/)
Expand All @@ -144,7 +169,7 @@ end

function TermInterface.sorted_arguments(x::BasicSymbolic)
args = arguments(x)
@compactified x::BasicSymbolic begin
@compactified x.expr::BasicSymbolicImpl begin
Add => @goto ADD
Mul => @goto MUL
_ => return args
Expand All @@ -169,7 +194,10 @@ end
TermInterface.children(x::BasicSymbolic) = arguments(x)
TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x)
function TermInterface.arguments(x::BasicSymbolic)
@compactified x::BasicSymbolic begin
arguments(x.expr)
end
function TermInterface.arguments(x::BasicSymbolicImpl)
@compactified x::BasicSymbolicImpl begin
Term => return x.arguments
Add => @goto ADDMUL
Mul => @goto ADDMUL
Expand Down Expand Up @@ -216,10 +244,20 @@ function TermInterface.arguments(x::BasicSymbolic)
return args
end

isexpr(s::BasicSymbolic) = !issym(s)
iscall(s::BasicSymbolic) = isexpr(s)
isexpr(s::BasicSymbolic) = isexpr(s.expr)
isexpr(expr::BasicSymbolicImpl) = !issym(expr)
iscall(s::BasicSymbolic) = iscall(s.expr)
iscall(expr::BasicSymbolicImpl) = isexpr(expr)

@inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false
@inline function isa_SymType(T::Val{S}, x) where {S}
if x isa BasicSymbolic
Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolicImpl), T, x.expr)
elseif x isa BasicSymbolicImpl
Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolicImpl), T, x)
else
false
end
end

"""
issym(x)
Expand Down Expand Up @@ -253,7 +291,10 @@ function _allarequal(xs, ys; comparator = isequal)::Bool
return true
end

function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S}
function Base.isequal(a::BasicSymbolic, b::BasicSymbolic)
isequal(a.expr, b.expr)
end
function Base.isequal(a::BasicSymbolicImpl{T}, b::BasicSymbolicImpl{S}) where {T,S}
a === b && return true

E = exprtype(a)
Expand All @@ -262,6 +303,12 @@ function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S}
T === S || return false
return _isequal(a, b, E)::Bool
end
function Base.isequal(a::MetadataImpl, b::MetadataImpl)
(a === b) ||
(isequal_with_metadata(a.this, b.this) &&
isequal_with_metadata(a.children, b.children))
end

function _isequal(a, b, E; comparator = isequal)
if E === SYM
nameof(a) === nameof(b)
Expand Down Expand Up @@ -303,7 +350,7 @@ function isequal_with_metadata(a::BasicSymbolic{T}, b::BasicSymbolic{S})::Bool w
E === exprtype(b) || return false

T === S || return false
_isequal(a, b, E; comparator = isequal_with_metadata)::Bool && isequal_with_metadata(metadata(a), metadata(b)) || return false
_isequal(a, b, E; comparator = isequal_with_metadata)::Bool && isequal(a.meta, b.meta) || return false
end

"""
Expand Down Expand Up @@ -395,7 +442,14 @@ end
Base.one( s::Symbolic) = one( symtype(s))
Base.zero(s::Symbolic) = zero(symtype(s))

Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymbolic doesn't have a name")
Base.nameof(s::BasicSymbolic) = nameof(s.expr)
function Base.nameof(s::BasicSymbolicImpl)
if issym(s)
s.name
else
error("None Sym BasicSymbolic doesn't have a name")
end
end

## This is much faster than hash of an array of Any
hashvec(xs, z) = foldr(hash, xs, init=z)
Expand Down Expand Up @@ -457,8 +511,9 @@ hash2(s, salt::UInt) = hash(s, salt)
function hash2(n::T, salt::UInt) where {T <: Number}
hash(T, hash(n, salt))
end
hash2(s::BasicSymbolic) = hash2(s, zero(UInt))
function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T}
hash2(s::BasicSymbolic) = hash2(s.expr, zero(UInt))
hash2(s::BasicSymbolicImpl) = hash2(s, zero(UInt))
function hash2(s::BasicSymbolicImpl{T}, salt::UInt)::UInt where {T}
E = exprtype(s)
h::UInt = 0
if E === SYM
Expand Down Expand Up @@ -488,7 +543,7 @@ function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T}
else
error_on_type()
end
h = hash(metadata(s), hash(T, h))
h = hash(T, h)
if hasproperty(s, :hash2)
s.hash2[] = h
end
Expand Down Expand Up @@ -520,31 +575,35 @@ Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.h
`Base.isequal` to accommodate metadata without disrupting existing tests reliant on the
original behavior of those functions.
"""
function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
function BasicSymbolicImpl(s::BasicSymbolicImpl)::BasicSymbolicImpl
if !ENABLE_HASHCONSING[]
return s
end
h = hash2(s)
t = get!(wvd, h, s)
if t === s || isequal_with_metadata(t, s)
if t === s || isequal(t, s)
return t
else
return s
end
end

function Sym{T}(name::Symbol; kw...) where {T}
function Sym{T}(name::Symbol; metadata = NO_METADATA, kw...) where {T}
s = Sym{T}(; name, kw...)
BasicSymbolic(s)
bsi = BasicSymbolicImpl(s)
mdi = MetadataImpl(metadata, Vector())
BasicSymbolic(bsi, mdi)
end

function Term{T}(f, args; kw...) where T
function Term{T}(f, args; metadata = NO_METADATA, kw...) where T
if eltype(args) !== Any
args = convert(Vector{Any}, args)
end

s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...)
BasicSymbolic(s)
bsi = BasicSymbolicImpl(s)
mdi = MetadataImpl(metadata, getmetadata.(args))
BasicSymbolic(bsi, mdi)
end

function Term(f, args; metadata=NO_METADATA)
Expand All @@ -564,8 +623,10 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
end
end

s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
BasicSymbolic(s)
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments=[], issorted=RefValue(false), kw...)
bsi = BasicSymbolicImpl(s)
mdi = MetadataImpl(metadata, getmetadata.(arguments(s)))
BasicSymbolic(bsi, mdi)
end

function Mul(T, a, b; metadata=NO_METADATA, kw...)
Expand All @@ -580,8 +641,10 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...)
else
coeff = a
dict = b
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
BasicSymbolic(s)
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), arguments=[], issorted=RefValue(false), kw...)
bsi = BasicSymbolicImpl(s)
mdi = MetadataImpl(metadata, getmetadata.(arguments(s)))
BasicSymbolic(bsi, mdi)
end
end

Expand All @@ -601,7 +664,7 @@ ratio(x::Rat,y::Rat) = x//y
function maybe_intcoeff(x)
if ismul(x)
if x.coeff isa Rational && isone(x.coeff.den)
Mul{symtype(x)}(; coeff=x.coeff.num, dict=x.dict, x.metadata, arguments=[], issorted=RefValue(false))
Mul(symtype(x), x.coeff.num, x.dict; x.metadata)
else
x
end
Expand All @@ -612,7 +675,7 @@ function maybe_intcoeff(x)
end
end

function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T}
function Div{T}(n, d, simplified=false; metadata=NO_METADATA, kwargs...) where {T}
if T<:Number && !(T<:SafeReal)
n, d = quick_cancel(n, d)
end
Expand Down Expand Up @@ -646,8 +709,10 @@ function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T}
end
end

s = Div{T}(; num=n, den=d, simplified, arguments=[], metadata)
BasicSymbolic(s)
s = Div{T}(; num=n, den=d, simplified, arguments=[])
bsi = BasicSymbolicImpl(s)
mdi = MetadataImpl(metadata, getmetadata.(arguments(s)))
BasicSymbolic(bsi, mdi)
end

function Div(n,d, simplified=false; kw...)
Expand All @@ -664,8 +729,10 @@ end
function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T}
_iszero(b) && return 1
_isone(b) && return a
s = Pow{T}(; base=a, exp=b, arguments=[], metadata)
BasicSymbolic(s)
s = Pow{T}(; base=a, exp=b, arguments=[])
bsi = BasicSymbolicImpl(s)
mdi = MetadataImpl(metadata, getmetadata.(arguments(s)))
BasicSymbolic(bsi, mdi)
end

function Pow(a, b; metadata = NO_METADATA, kwargs...)
Expand Down Expand Up @@ -856,6 +923,10 @@ end
metadata(s::Symbolic) = s.metadata
metadata(s::Symbolic, meta) = Setfield.@set! s.metadata = meta

function metadata_children(s::BasicSymbolic)
s.meta.children
end

function hasmetadata(s::Symbolic, ctx)
metadata(s) isa AbstractDict && haskey(metadata(s), ctx)
end
Expand All @@ -874,6 +945,7 @@ _issafecanon(::typeof(^), s) = !iscall(s) || !(operation(s) in (*, ^))

issafecanon(f, ss...) = all(x->issafecanon(f, x), ss)

getmetadata(s) = metadata(s)
function getmetadata(s::Symbolic, ctx)
md = metadata(s)
if md isa AbstractDict
Expand All @@ -882,11 +954,13 @@ function getmetadata(s::Symbolic, ctx)
throw(ArgumentError("$s does not have metadata for $ctx"))
end
end

function getmetadata(s::Symbolic, ctx, default)
md = metadata(s)
md isa AbstractDict ? get(md, ctx, default) : default
end
function getmetadata(d::AbstractDict, ctx)
d[ctx]
end

# pirated for Setfield purposes:
using Base: ImmutableDict
Expand Down
Loading
Loading