Skip to content

[WIP] refactor: move to Moshi #754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 24 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
460b647
feat: use Moshi.jl instead of Unityper.jl
AayushSabharwal Jun 27, 2025
15eed8f
test: update tests
AayushSabharwal Jun 27, 2025
051837e
refactor: prefer `map`ing over `parent(arguments(x))`
AayushSabharwal Jul 2, 2025
225af42
TEMP COMMIT: add bench
AayushSabharwal Jul 2, 2025
cc02e9d
refactor: remove `@timer`
AayushSabharwal Jul 4, 2025
ed8d81f
feat: implement `map` for `SmallVec`
AayushSabharwal Jul 4, 2025
2b95de7
refactor: optimize `ACRule`
AayushSabharwal Jul 4, 2025
38a33a1
feat: add generated `Chain` method
AayushSabharwal Jul 4, 2025
32959fc
refactor: use optimized `Chain` in `simplify` rules
AayushSabharwal Jul 4, 2025
d5da477
refactor: optimize `term_matcher_contructor` a bit
AayushSabharwal Jul 4, 2025
b3aeca0
TEMP COMMIT add deved packages to sources
AayushSabharwal Jul 4, 2025
99a52f7
ci: don't fail-fast in tests
AayushSabharwal Jul 4, 2025
e8709ce
refactor: chain can be massive so make it mutable
AayushSabharwal Jul 4, 2025
eef10a4
fix: make `simplify` constprop better
AayushSabharwal Jul 4, 2025
2923a38
ci: make `polyform/isone` benchmark actually do something
AayushSabharwal Jul 4, 2025
ba39e65
fix: fix cache macro
AayushSabharwal Jul 5, 2025
770e9ca
fix: fix CSE
AayushSabharwal Jul 5, 2025
a884109
refactor: use `mul_worker` instead of `reduce(*, ...)`
AayushSabharwal Jul 5, 2025
6203b2e
refactor: make temporarily disabling hashconsing safer
AayushSabharwal Jul 5, 2025
7cdce5b
fix: handle complex numbers properly in `get_degrees`
AayushSabharwal Jul 7, 2025
b9e62bd
fix: better handle integer demotion in `maybe_integer`
AayushSabharwal Jul 7, 2025
7304427
fix: float-ify in `polyize`
AayushSabharwal Jul 7, 2025
bf27722
HACK: or heresy?
AayushSabharwal Jul 7, 2025
fbf7315
fixup! test: update tests
AayushSabharwal Jul 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
version:
- 'min'
- '1'
fail-fast: false
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
17 changes: 15 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
ExproniconLite = "55351af7-c7e9-48d6-89ff-24e801d99491"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d"
MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83"
ReadOnlyDicts = "795d4caa-f5a7-4580-b5d8-c01d53451803"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Expand All @@ -26,7 +31,7 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
WeakCacheSets = "d30d5f5c-d141-4870-aa07-aabb0f5fe7d5"
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"

[weakdeps]
Expand All @@ -47,11 +52,16 @@ ConstructionBase = "1.5.7"
DataStructures = "0.18"
DocStringExtensions = "0.8, 0.9"
DynamicPolynomials = "0.5, 0.6"
EnumX = "1.0.5"
ExproniconLite = "0.10.14"
LabelledArrays = "1.5"
MacroTools = "0.5.16"
Moshi = "0.3.6"
MultivariatePolynomials = "0.5"
NaNMath = "0.3, 1.1.2"
OhMyThreads = "0.7"
ReadOnlyArrays = "0.2.0"
ReadOnlyDicts = "1.0.0"
ReverseDiff = "1"
RuntimeGeneratedFunctions = "0.5.13"
Setfield = "0.7, 0.8, 1"
Expand All @@ -61,7 +71,6 @@ SymbolicIndexingInterface = "0.3"
TaskLocalValues = "0.1.2"
TermInterface = "2.0"
TimerOutputs = "0.5"
Unityper = "0.1.2"
WeakValueDicts = "0.1.0"
julia = "1.10"

Expand All @@ -82,3 +91,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "ReverseDiff", "SafeTestsets", "Test", "Zygote", "OhMyThreads", "RuntimeGeneratedFunctions"]

[sources]
WeakCacheSets = {url="https://github.com/JuliaCollections/WeakCacheSets.jl"}
Moshi = {url="https://github.com/AayushSabharwal/Moshi.jl", rev="as/mutable-adt"}
9 changes: 9 additions & 0 deletions bench.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using SymbolicUtils, BenchmarkTools

@syms a b c d e f g h i
ex = (f + ((((g*(c^2)*(e^2)) / d - e*h*(c^2)) / b + (-c*e*f*g) / d + c*e*i) /
(i + ((c*e*g) / d - c*h) / b + (-f*g) / d) - c*e) / b +
((g*(f^2)) / d + ((-c*e*f*g) / d + c*f*h) / b - f*i) /
(i + ((c*e*g) / d - c*h) / b + (-f*g) / d)) / d

@benchmark SymbolicUtils.fraction_iszero($ex)
4 changes: 4 additions & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
[deps]

[sources]
WeakCacheSets = {url="https://github.com/JuliaCollections/WeakCacheSets.jl"}
Moshi = {url="https://github.com/AayushSabharwal/Moshi.jl", rev="as/mutable-adt"}
4 changes: 3 additions & 1 deletion benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ let
(-f*(g + (-d*g) / d)) / (i + (-c*(h + (-e*g) / d)) / b + (-f*g) / d)) / d
pform["simplify_fractions"] = @benchmarkable simplify_fractions($ex)
pform["iszero"] = @benchmarkable SymbolicUtils.fraction_iszero($ex)
pform["isone"] = @benchmarkable SymbolicUtils.fraction_isone($o)
pform["isone"] = @benchmarkable SymbolicUtils.fraction_isone($ex)
pform["isone:noop"] = @benchmarkable SymbolicUtils.fraction_isone($o)
pform["iszero:noop"] = @benchmarkable SymbolicUtils.fraction_iszero($o)
pform["easy_iszero"] = @benchmarkable SymbolicUtils.fraction_iszero($((b*(h + (-e*g) / d)) / b + (e*g) / d - h))
end
12 changes: 6 additions & 6 deletions docs/src/manual/rewrite.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ r1 = @rule sin(2(~x)) => 2sin(~x)*cos(~x)
r1(sin(2z))

# output
2sin(z)*cos(z)
2cos(z)*sin(z)
```

The `@rule` macro takes a pair of patterns -- the _matcher_ and the _consequent_ (`@rule matcher => consequent`). If an expression matches the matcher pattern, it is rewritten to the consequent pattern. `@rule` returns a callable object that applies the rule to an expression.
Expand All @@ -41,7 +41,7 @@ Slot variable (matcher) is not necessary a single variable
r1(sin(2*(w-z)))

# output
2cos(w - z)*sin(w - z)
2sin(w - z)*cos(w - z)
```

but it must be a single expression
Expand All @@ -61,7 +61,7 @@ r2 = @rule sin(~x + ~y) => sin(~x)*cos(~y) + cos(~x)*sin(~y);
r2(sin(α+β))

# output
sin(β)*cos(α) + cos(β)*sin(α)
cos(β)*sin(α) + sin(β)*cos(α)
```

If you want to match a variable number of subexpressions at once, you will need a **segment variable**. `~~xs` in the following example is a segment variable:
Expand All @@ -71,10 +71,10 @@ If you want to match a variable number of subexpressions at once, you will need
@rule(+(~~xs) => ~~xs)(x + y + z)

# output
3-element view(::SymbolicUtils.SmallVec{Any, Vector{Any}}, 1:3) with eltype Any:
z
y
3-element view(::ReadOnlyArrays.ReadOnlyVector{Any, SymbolicUtils.SmallVec{Any, Vector{Any}}}, 1:3) with eltype Any:
x
y
z
```

`~~xs` is a vector of subexpressions matched. You can use it to construct something more useful:
Expand Down
10 changes: 9 additions & 1 deletion src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ using DocStringExtensions

export @syms, term, showraw, hasmetadata, getmetadata, setmetadata

using Unityper
using Moshi.Data: @data
import Moshi.Data as MData
using Moshi.Match: @match
using ReadOnlyArrays
using ReadOnlyDicts
using EnumX: @enumx
using TermInterface
using DataStructures
using Setfield
Expand All @@ -23,6 +28,9 @@ import ArrayInterface
import ExproniconLite as EL
import TaskLocalValues: TaskLocalValue
using WeakValueDicts: WeakValueDict
using WeakCacheSets: WeakCacheSet, getkey!
using Base: RefValue
import MacroTools

# include("WeakCacheSets.jl")

Expand Down
2 changes: 1 addition & 1 deletion src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ The key stored in the cache for a particular value. Returns a `SymbolicKey` for
# can't dispatch because `BasicSymbolic` isn't defined here
function get_cache_key(x)
if x isa BasicSymbolic
id = x.id[]
id = x.id
if id === nothing
return CacheSentinel()
end
Expand Down
7 changes: 4 additions & 3 deletions src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ struct CSEState
"""
A mapping of symbolic expression to the LHS in `sorted_exprs` that computes it.
"""
visited::IdDict{Any, Any}
visited::IdDict{Union{SymbolicUtils.IDType, AbstractArray, Tuple}, BasicSymbolic}
"""
Integer counter, used to generate unique names for intermediate variables.
"""
Expand Down Expand Up @@ -870,8 +870,9 @@ function cse! end

indextype(::AbstractSparseArray{Tv, Ti}) where {Tv, Ti} = Ti

function cse!(expr::Symbolic, state::CSEState)
get!(state.visited, expr) do

function cse!(expr::BasicSymbolic, state::CSEState)
get!(state.visited, expr.id) do
iscall(expr) || return expr

op = operation(expr)
Expand Down
17 changes: 9 additions & 8 deletions src/inspect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@ function AbstractTrees.nodevalue(x::Symbolic)
iscall(x) ? operation(x) : isexpr(x) ? head(x) : x
end

function AbstractTrees.nodevalue(x::BasicSymbolic)
function AbstractTrees.nodevalue(x::BSImpl.Type)
T = nameof(MData.variant_type(x))
str = if !iscall(x)
string(exprtype(x), "(", x, ")")
string(T, "(", x, ")")
elseif isadd(x)
string(exprtype(x),
(scalar=x.coeff, coeffs=Tuple(k=>v for (k,v) in x.dict)))
string(T,
(variant=string(x.variant), scalar=x.coeff, coeffs=Tuple(k=>v for (k,v) in x.dict)))
elseif ismul(x)
string(exprtype(x),
(scalar=x.coeff, powers=Tuple(k=>v for (k,v) in x.dict)))
string(T,
(variant=string(x.variant), scalar=x.coeff, powers=Tuple(k=>v for (k,v) in x.dict)))
elseif isdiv(x) || ispow(x)
string(exprtype(x))
string(T)
else
string(exprtype(x),"{", operation(x), "}")
string(T,"{", operation(x), "}")
end

if inspect_metadata[] && !isnothing(metadata(x))
Expand Down
53 changes: 29 additions & 24 deletions src/matchers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ function matcher(val::Any)
# if val is a call (like an operation) creates a term matcher or term matcher with defslot
if iscall(val)
# if has two arguments and one of them is a DefSlot, create a term matcher with defslot
if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val))
args = parent(arguments(val))
if length(args) == 2 && any(x -> isa(x, DefSlot), args)
return defslot_term_matcher_constructor(val)
# else return a normal term matcher
else
Expand All @@ -35,7 +36,9 @@ function matcher(slot::Slot)
end
# elseif the first element of data matches the slot predicate, add it to bindings and call next
elseif slot.predicate(car(data))
next(assoc(bindings, slot.name, car(data)), 1)
rest = car(data)
binds = assoc(bindings, slot.name, rest)
next(binds, 1)
end
end
end
Expand Down Expand Up @@ -104,32 +107,34 @@ function matcher(segment::Segment)
end

function term_matcher_constructor(term)
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
matchers = vcat([matcher(operation(term))], map(matcher, parent(arguments(term))))

function term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
let matchers = matchers
function term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
!iscall(car(data)) && return nothing # if first element is not a call, return nothing

function loop(term, bindings′, matchers′) # Get it to compile faster
if !islist(matchers′)
if !islist(term)
return success(bindings′, 1)
function loop(term, bindings′, matchers′) # Get it to compile faster
if !islist(matchers′)
if !islist(term)
return success(bindings′, 1)
end
return nothing
end
return nothing
end
car(matchers′)(term, bindings′) do b, n
loop(drop_n(term, n), b, cdr(matchers′))
car(matchers′)(term, bindings′) do b, n
loop(drop_n(term, n), b, cdr(matchers′))
end
# explenation of above 3 lines:
# car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′)
# <------ next(b,n) ---------------------------->
# car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list
# Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term
# Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses
# the length of the list, is considered empty
end
# explenation of above 3 lines:
# car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′)
# <------ next(b,n) ---------------------------->
# car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list
# Calls the first matcher, with the "next" function being loop again but with n terms dropepd from term
# Term is a linked list (a list and a index). drop n advances the index. when the index sorpasses
# the length of the list, is considered empty
end

loop(car(data), bindings, matchers) # Try to eat exactly one term
loop(car(data), bindings, matchers) # Try to eat exactly one term
end
end
end

Expand All @@ -146,7 +151,7 @@ end
# calls the success function like term_matcher would do

function defslot_term_matcher_constructor(term)
a = arguments(term) # lenght two bc defslot term matcher is allowed only with +,* and ^, that accept two arguments
a = parent(arguments(term)) # lenght two bc defslot term matcher is allowed only with +,* and ^, that accept two arguments
matchers = (matcher(operation(term)), map(matcher, a)...) # create matchers for the operation and the two arguments of the term

defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term
Expand Down
30 changes: 16 additions & 14 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,6 @@ const previously_declared_for = Set([])

const basic_monadic = [-, +]
const basic_diadic = [+, -, *, /, //, \, ^]
#################### SafeReal #########################
export SafeReal, LiteralReal

# ideally the relationship should be the other way around
abstract type SafeReal <: Real end

################### LiteralReal #######################

abstract type LiteralReal <: Real end

#######################################################

assert_like(f, T) = nothing
Expand Down Expand Up @@ -101,13 +91,25 @@ macro number_methods(T, rhs1, rhs2, options=nothing)
end

@number_methods(BasicSymbolic{<:Number}, term(f, a), term(f, a, b), skipbasics)
@number_methods(BasicSymbolic{<:LiteralReal}, term(f, a), term(f, a, b), onlybasics)
@number_methods(BasicSymbolic{LiteralReal}, term(f, a), term(f, a, b), onlybasics)

for f in vcat(diadic, [+, -, *, \, /, ^])
@eval promote_symtype(::$(typeof(f)),
T::Type{<:Number},
S::Type{<:Number}) = promote_type(T, S)
for R in [SafeReal, LiteralReal]
@eval promote_symtype(::$(typeof(f)),
T::Type{<:Rational},
S::Type{Integer}) = Rational
@eval promote_symtype(::$(typeof(f)),
T::Type{Integer},
S::Type{<:Rational}) = Rational
@eval promote_symtype(::$(typeof(f)),
T::Type{<:Complex{<:Rational}},
S::Type{Integer}) = Complex{Rational}
@eval promote_symtype(::$(typeof(f)),
T::Type{Integer},
S::Type{<:Complex{<:Rational}}) = Complex{Rational}
for R in [SafeRealImpl, LiteralRealImpl]
@eval function promote_symtype(::$(typeof(f)),
T::Type{<:$R},
S::Type{<:Real})
Expand Down Expand Up @@ -153,8 +155,8 @@ end
promote_symtype(::Any, T) = promote_type(T, Real)
for f in monadic
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = promote_type(T, Real)
@eval promote_symtype(::$(typeof(f)), T::Type{<:SafeReal}) = SafeReal
@eval promote_symtype(::$(typeof(f)), T::Type{<:LiteralReal}) = LiteralReal
@eval promote_symtype(::$(typeof(f)), T::Type{<:SafeRealImpl}) = SafeReal
@eval promote_symtype(::$(typeof(f)), T::Type{<:LiteralRealImpl}) = LiteralReal
end

Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a)
Expand Down
2 changes: 1 addition & 1 deletion src/ordering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ end

function _get_degrees(::typeof(^), expr, degs_cache)
base_expr, pow_expr = arguments(expr)
if pow_expr isa Number
if pow_expr isa Real
@inbounds degs = map(_get_degrees(base_expr, degs_cache)) do (base, pow)
(base => pow * pow_expr)
end
Expand Down
Loading
Loading