Skip to content

Commit 5cec229

Browse files
feat: use Moshi.jl instead of Unityper.jl
1 parent 837effd commit 5cec229

File tree

7 files changed

+854
-613
lines changed

7 files changed

+854
-613
lines changed

Project.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,15 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1313
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1414
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1515
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
16+
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1617
ExproniconLite = "55351af7-c7e9-48d6-89ff-24e801d99491"
1718
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1819
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
20+
Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d"
1921
MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3"
2022
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
23+
ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83"
24+
ReadOnlyDicts = "795d4caa-f5a7-4580-b5d8-c01d53451803"
2125
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2226
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2327
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -26,7 +30,6 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
2630
TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34"
2731
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
2832
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
29-
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
3033
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"
3134

3235
[weakdeps]
@@ -47,11 +50,15 @@ ConstructionBase = "1.5.7"
4750
DataStructures = "0.18"
4851
DocStringExtensions = "0.8, 0.9"
4952
DynamicPolynomials = "0.5, 0.6"
53+
EnumX = "1.0.5"
5054
ExproniconLite = "0.10.14"
5155
LabelledArrays = "1.5"
56+
Moshi = "0.3.6"
5257
MultivariatePolynomials = "0.5"
5358
NaNMath = "0.3, 1.1.2"
5459
OhMyThreads = "0.7"
60+
ReadOnlyArrays = "0.2.0"
61+
ReadOnlyDicts = "1.0.0"
5562
ReverseDiff = "1"
5663
RuntimeGeneratedFunctions = "0.5.13"
5764
Setfield = "0.7, 0.8, 1"
@@ -61,7 +68,6 @@ SymbolicIndexingInterface = "0.3"
6168
TaskLocalValues = "0.1.2"
6269
TermInterface = "2.0"
6370
TimerOutputs = "0.5"
64-
Unityper = "0.1.2"
6571
WeakValueDicts = "0.1.0"
6672
julia = "1.10"
6773

src/SymbolicUtils.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@ using DocStringExtensions
77

88
export @syms, term, showraw, hasmetadata, getmetadata, setmetadata
99

10-
using Unityper
10+
using Moshi.Data: @data
11+
import Moshi.Data as MData
12+
using Moshi.Match: @match
13+
using ReadOnlyArrays
14+
using ReadOnlyDicts
15+
using EnumX: @enumx
1116
using TermInterface
1217
using DataStructures
1318
using Setfield
@@ -23,6 +28,7 @@ import ArrayInterface
2328
import ExproniconLite as EL
2429
import TaskLocalValues: TaskLocalValue
2530
using WeakValueDicts: WeakValueDict
31+
using Base: RefValue
2632

2733
# include("WeakCacheSets.jl")
2834

src/inspect.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,21 @@ function AbstractTrees.nodevalue(x::Symbolic)
55
iscall(x) ? operation(x) : isexpr(x) ? head(x) : x
66
end
77

8-
function AbstractTrees.nodevalue(x::BasicSymbolic)
8+
AbstractTrees.nodevalue(x::BasicSymbolic) = AbstractTrees.nodevalue(_unwrap_internal(x))
9+
function AbstractTrees.nodevalue(x::BSImpl.Type)
10+
T = nameof(MData.variant_type(x))
911
str = if !iscall(x)
10-
string(exprtype(x), "(", x, ")")
12+
string(T, "(", x, ")")
1113
elseif isadd(x)
12-
string(exprtype(x),
13-
(scalar=x.coeff, coeffs=Tuple(k=>v for (k,v) in x.dict)))
14+
string(T,
15+
(variant=string(x.variant), scalar=x.coeff, coeffs=Tuple(k=>v for (k,v) in x.dict)))
1416
elseif ismul(x)
15-
string(exprtype(x),
16-
(scalar=x.coeff, powers=Tuple(k=>v for (k,v) in x.dict)))
17+
string(T,
18+
(variant=string(x.variant), scalar=x.coeff, powers=Tuple(k=>v for (k,v) in x.dict)))
1719
elseif isdiv(x) || ispow(x)
18-
string(exprtype(x))
20+
string(T)
1921
else
20-
string(exprtype(x),"{", operation(x), "}")
22+
string(T,"{", operation(x), "}")
2123
end
2224

2325
if inspect_metadata[] && !isnothing(metadata(x))

src/methods.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,19 @@ const basic_diadic = [+, -, *, /, //, \, ^]
2929
export SafeReal, LiteralReal
3030

3131
# ideally the relationship should be the other way around
32-
abstract type SafeReal <: Real end
32+
abstract type SafeRealImpl <: Number end
33+
const SafeReal = Union{SafeRealImpl, Real}
34+
Base.one(::Type{SafeReal}) = true
35+
Base.zero(::Type{SafeReal}) = false
36+
Base.convert(::Type{<:SafeRealImpl}, x::Number) = convert(Real, x)
3337

3438
################### LiteralReal #######################
3539

36-
abstract type LiteralReal <: Real end
40+
abstract type LiteralRealImpl <: Number end
41+
const LiteralReal = Union{LiteralRealImpl, Real}
42+
Base.one(::Type{LiteralReal}) = true
43+
Base.zero(::Type{LiteralReal}) = false
44+
Base.convert(::Type{<:LiteralRealImpl}, x::Number) = convert(Real, x)
3745

3846
#######################################################
3947

@@ -101,13 +109,13 @@ macro number_methods(T, rhs1, rhs2, options=nothing)
101109
end
102110

103111
@number_methods(BasicSymbolic{<:Number}, term(f, a), term(f, a, b), skipbasics)
104-
@number_methods(BasicSymbolic{<:LiteralReal}, term(f, a), term(f, a, b), onlybasics)
112+
@number_methods(BasicSymbolic{LiteralReal}, term(f, a), term(f, a, b), onlybasics)
105113

106114
for f in vcat(diadic, [+, -, *, \, /, ^])
107115
@eval promote_symtype(::$(typeof(f)),
108116
T::Type{<:Number},
109117
S::Type{<:Number}) = promote_type(T, S)
110-
for R in [SafeReal, LiteralReal]
118+
for R in [SafeRealImpl, LiteralRealImpl]
111119
@eval function promote_symtype(::$(typeof(f)),
112120
T::Type{<:$R},
113121
S::Type{<:Real})
@@ -153,8 +161,8 @@ end
153161
promote_symtype(::Any, T) = promote_type(T, Real)
154162
for f in monadic
155163
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = promote_type(T, Real)
156-
@eval promote_symtype(::$(typeof(f)), T::Type{<:SafeReal}) = SafeReal
157-
@eval promote_symtype(::$(typeof(f)), T::Type{<:LiteralReal}) = LiteralReal
164+
@eval promote_symtype(::$(typeof(f)), T::Type{<:SafeRealImpl}) = SafeReal
165+
@eval promote_symtype(::$(typeof(f)), T::Type{<:LiteralRealImpl}) = LiteralReal
158166
end
159167

160168
Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a)

src/polyform.jl

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,10 @@ function TermInterface.arguments(x::PolyForm{T}) where {T}
213213
m = MP.monomial(t)
214214

215215
if !isone(c)
216-
[c, (unstable_pow(resolve(v), pow)
216+
[c, (^(resolve(v), pow)
217217
for (v, pow) in MP.powers(m) if !iszero(pow))...]
218218
else
219-
[unstable_pow(resolve(v), pow)
219+
[^(resolve(v), pow)
220220
for (v, pow) in MP.powers(m) if !iszero(pow)]
221221
end
222222
elseif MP.nterms(x.p) == 0
@@ -283,7 +283,7 @@ function simplify_div(d)
283283
if all(_isone, ds)
284284
return isempty(ns) ? 1 : simplify_fractions(_mul(ns))
285285
else
286-
Div(simplify_fractions(_mul(ns)), simplify_fractions(_mul(ds)))
286+
Div(simplify_fractions(_mul(ns)), simplify_fractions(_mul(ds)), false)
287287
end
288288
end
289289

@@ -343,15 +343,20 @@ function simplify_fractions(x; polyform=false)
343343
end
344344

345345
function add_with_div(x, flatten=true)
346-
(!iscall(x) || operation(x) != (+)) && return x
346+
(!iscall(x) || operation(x) != (+)) && return nothing
347347
aa = arguments(x)
348-
!any(a->isdiv(a), aa) && return x # no rewrite necessary
348+
!any(a->isdiv(a), aa) && return nothing # no rewrite necessary
349349

350-
divs = filter(a->isdiv(a), aa)
351-
nondivs = filter(a->!(isdiv(a)), aa)
352-
nds = isempty(nondivs) ? 0 : +(nondivs...)
353-
d = reduce(quick_canceladd_divs, divs)
354-
flatten ? quick_cancel(add_divs(d, nds)) : d + nds
350+
nondiv_result = 0
351+
div_result = 0
352+
for a in aa
353+
if isdiv(a)
354+
div_result = quick_cancel(add_divs(div_result, a))
355+
else
356+
nondiv_result += a
357+
end
358+
end
359+
flatten ? quick_cancel(add_divs(div_result, nondiv_result)) : div_result + nondiv_result
355360
end
356361
"""
357362
flatten_fractions(x)
@@ -364,7 +369,7 @@ julia> flatten_fractions((1+(1+1/a)/a)/a)
364369
```
365370
"""
366371
function flatten_fractions(x)
367-
Fixpoint(Postwalk(add_with_div))(x)
372+
Fixpoint(Postwalk(PassThrough(add_with_div)))(x)
368373
end
369374

370375
function fraction_iszero(x)
@@ -419,7 +424,7 @@ function quick_cancel(d)
419424
return prod(arguments(d))
420425
elseif isdiv(d)
421426
num, den = quick_cancel(d.num, d.den)
422-
return Div(num, den)
427+
return Div(num, den, false)
423428
else
424429
return d
425430
end
@@ -471,7 +476,7 @@ end
471476
# ismul(x)
472477
function quick_mul(x, y)
473478
if haskey(x.dict, y) && x.dict[y] >= 1
474-
d = copy(x.dict)
479+
d = copy(parent(x.dict))
475480
if d[y] > 1
476481
d[y] -= 1
477482
elseif d[y] == 1
@@ -490,7 +495,7 @@ end
490495
function quick_mulpow(x, y)
491496
y.exp isa Number || return (x, y)
492497
if haskey(x.dict, y.base)
493-
d = copy(x.dict)
498+
d = copy(parent(x.dict))
494499
if x.dict[y.base] > y.exp
495500
d[y.base] -= y.exp
496501
den = 1
@@ -514,6 +519,12 @@ function quick_mulmul(x, y)
514519
end
515520

516521
function _merge_div(ndict, ddict)
522+
if ndict isa ReadOnlyDict
523+
ndict = parent(ndict)
524+
end
525+
if ddict isa ReadOnlyDict
526+
ddict = parent(ddict)
527+
end
517528
num = copy(ndict)
518529
den = copy(ddict)
519530
for (k, v) in den

0 commit comments

Comments
 (0)