From 6cddb1ffb942a1e2b3249637cd91ba994dfb9090 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 19 Mar 2025 12:16:29 +0530 Subject: [PATCH 1/6] build: add `ReadOnlyArrays`, `ReadOnlyDicts` --- Project.toml | 4 ++++ src/SymbolicUtils.jl | 2 ++ 2 files changed, 6 insertions(+) diff --git a/Project.toml b/Project.toml index cbc04c758..3ec83ad17 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,8 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +ReadOnlyArrays = "988b38a3-91fc-5605-94a2-ee2116b3bd83" +ReadOnlyDicts = "795d4caa-f5a7-4580-b5d8-c01d53451803" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -52,6 +54,8 @@ LabelledArrays = "1.5" MultivariatePolynomials = "0.5" NaNMath = "0.3, 1.1.2" OhMyThreads = "0.7" +ReadOnlyArrays = "0.2.0" +ReadOnlyDicts = "1.0.0" ReverseDiff = "1" Setfield = "0.7, 0.8, 1" SpecialFunctions = "0.10, 1.0, 2" diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index f965be206..0bfc30a8c 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -23,6 +23,8 @@ import ArrayInterface using WeakValueDicts: WeakValueDict import ExproniconLite as EL import TaskLocalValues: TaskLocalValue +using ReadOnlyArrays +using ReadOnlyDicts include("cache.jl") Base.@deprecate istree iscall From 1d663ebdbbb959fa46bda3c40f792c05f673e0ff Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 19 Mar 2025 12:16:48 +0530 Subject: [PATCH 2/6] refactor: store `dict` as `ReadOnlyDict` --- src/types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types.jl b/src/types.jl index 152a97290..35db9d497 100644 --- a/src/types.jl +++ b/src/types.jl @@ -18,7 +18,7 @@ sdict(kv...) = Dict{Any, Any}(kv...) using Base: RefValue const EMPTY_ARGS = [] const EMPTY_HASH = RefValue(UInt(0)) -const EMPTY_DICT = sdict() +const EMPTY_DICT = ReadOnlyDict(sdict()) const EMPTY_DICT_T = typeof(EMPTY_DICT) const ENABLE_HASHCONSING = Ref(true) From 08860742b1e086f172dc9da2ec9759a57fe62be0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 19 Mar 2025 12:17:02 +0530 Subject: [PATCH 3/6] refactor: better error message for `setproperty!` --- src/types.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/types.jl b/src/types.jl index 35db9d497..852833fb2 100644 --- a/src/types.jl +++ b/src/types.jl @@ -62,6 +62,10 @@ const ENABLE_HASHCONSING = Ref(true) end end +function Base.setproperty!(x::BasicSymbolic, sym::Symbol, v) + error("Mutating `BasicSymbolic` is not allowed") +end + function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) ScalarSymbolic() end From f4a69770060770c7248442d08b87cf571dc9a210 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 19 Mar 2025 12:17:25 +0530 Subject: [PATCH 4/6] refactor: return `ReadOnlyVector` from `arguments`/`sorted_arguments` --- src/types.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/types.jl b/src/types.jl index 852833fb2..1f25c073a 100644 --- a/src/types.jl +++ b/src/types.jl @@ -151,8 +151,8 @@ end @inline head(x::BasicSymbolic) = operation(x) -@cache function TermInterface.sorted_arguments(x::BasicSymbolic)::Vector{Any} - args = copy(arguments(x)) +@cache function TermInterface.sorted_arguments(x::BasicSymbolic)::ReadOnlyVector{Any} + args = copy(parent(arguments(x))) @compactified x::BasicSymbolic begin Add => @goto ADD Mul => @goto MUL @@ -171,7 +171,7 @@ end TermInterface.children(x::BasicSymbolic) = arguments(x) TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x) -function TermInterface.arguments(x::BasicSymbolic) +function TermInterface.arguments(x::BasicSymbolic)::ReadOnlyVector{Any} @compactified x::BasicSymbolic begin Term => return x.arguments Add => @goto ADDMUL From 1be91d113bfcc384e28e71894ad58d9143d6b030 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 19 Mar 2025 12:17:41 +0530 Subject: [PATCH 5/6] refactor: remove `unwrap_arr!`, `unwrap_dict` --- src/types.jl | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/types.jl b/src/types.jl index 1f25c073a..7cae07140 100644 --- a/src/types.jl +++ b/src/types.jl @@ -558,17 +558,10 @@ function Sym{T}(name::Symbol; kw...) where {T} BasicSymbolic(s) end -function unwrap_arr!(arr) - for i in eachindex(arr) - arr[i] = unwrap(arr[i]) - end -end - function Term{T}(f, args; kw...) where T if eltype(args) !== Any args = convert(Vector{Any}, args) end - unwrap_arr!(args) s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...) BasicSymbolic(s) @@ -578,16 +571,8 @@ function Term(f, args; metadata=NO_METADATA) Term{_promote_symtype(f, args)}(f, args, metadata=metadata) end -function unwrap_dict(dict) - if any(k -> unwrap(k) !== k, keys(dict)) - return typeof(dict)(unwrap(k) => v for (k, v) in dict) - end - return dict -end - function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T coeff = unwrap(coeff) - dict = unwrap_dict(dict) if isempty(dict) return coeff elseif _iszero(coeff) && length(dict) == 1 @@ -606,7 +591,6 @@ end function Mul(T, a, b; metadata=NO_METADATA, kw...) a = unwrap(a) - b = unwrap_dict(b) isempty(b) && return a if _isone(a) && length(b) == 1 pair = first(b) From bffa87e5f5194a8fb69c466441187fd664285b2d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 19 Mar 2025 12:17:59 +0530 Subject: [PATCH 6/6] fix: handle new read-only results --- docs/src/manual/rewrite.md | 2 +- src/polyform.jl | 6 +++--- src/types.jl | 4 ++++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/src/manual/rewrite.md b/docs/src/manual/rewrite.md index 047bef71a..05d18e82a 100644 --- a/docs/src/manual/rewrite.md +++ b/docs/src/manual/rewrite.md @@ -71,7 +71,7 @@ If you want to match a variable number of subexpressions at once, you will need @rule(+(~~xs) => ~~xs)(x + y + z) # output -3-element view(::Vector{Any}, 1:3) with eltype Any: +3-element view(::ReadOnlyArrays.ReadOnlyVector{Any, Vector{Any}}, 1:3) with eltype Any: z y x diff --git a/src/polyform.jl b/src/polyform.jl index 7d6bc906e..d41fa2070 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -471,7 +471,7 @@ end # ismul(x) function quick_mul(x, y) if haskey(x.dict, y) && x.dict[y] >= 1 - d = copy(x.dict) + d = copy(parent(x.dict)) if d[y] > 1 d[y] -= 1 elseif d[y] == 1 @@ -490,7 +490,7 @@ end function quick_mulpow(x, y) y.exp isa Number || return (x, y) if haskey(x.dict, y.base) - d = copy(x.dict) + d = copy(parent(x.dict)) if x.dict[y.base] > y.exp d[y.base] -= y.exp den = 1 @@ -509,7 +509,7 @@ end # Double mul case function quick_mulmul(x, y) - num_dict, den_dict = _merge_div(x.dict, y.dict) + num_dict, den_dict = _merge_div(parent(x.dict), parent(y.dict)) Mul(symtype(x), x.coeff, num_dict), Mul(symtype(y), y.coeff, den_dict) end diff --git a/src/types.jl b/src/types.jl index 7cae07140..3b1d23276 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1330,6 +1330,10 @@ function _merge!(f::F, d, others...; filter=x->false) where F acc end +function mapvalues(f, d1::ReadOnlyDict) + mapvalues(f, parent(d1)) +end + function mapvalues(f, d1::AbstractDict) d = copy(d1) for (k, v) in d