From fe3110b85e512a1e21152247b46c336b57330d22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Fri, 16 May 2025 11:17:02 +0200 Subject: [PATCH 1/2] Implement mutable addition --- Project.toml | 2 ++ src/SymbolicUtils.jl | 2 ++ src/mutable_arithmetics.jl | 16 ++++++++++++++++ 3 files changed, 20 insertions(+) create mode 100644 src/mutable_arithmetics.jl diff --git a/Project.toml b/Project.toml index 87ad76cf..afb8eeef 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,7 @@ ExproniconLite = "55351af7-c7e9-48d6-89ff-24e801d99491" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3" +MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -50,6 +51,7 @@ DynamicPolynomials = "0.5, 0.6" ExproniconLite = "0.10.14" LabelledArrays = "1.5" MultivariatePolynomials = "0.5" +MutableArithmetics = "1.6.4" NaNMath = "0.3, 1.1.2" OhMyThreads = "0.7" ReverseDiff = "1" diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index 3c930f66..b55adfcb 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -82,5 +82,7 @@ include("code.jl") # Adjoints include("adjoints.jl") +# Mutable Arithmetics +include("mutable_arithmetics.jl") end # module diff --git a/src/mutable_arithmetics.jl b/src/mutable_arithmetics.jl new file mode 100644 index 00000000..1ac58e8d --- /dev/null +++ b/src/mutable_arithmetics.jl @@ -0,0 +1,16 @@ +import MutableArithmetics as MA + +function MA.operate!!(::typeof(+), a::BasicSymbolic, b::BasicSymbolic) + if SymbolicUtils.isadd(a) + if SymbolicUtils.isadd(b) + for (k, v) in b.dict + a.dict[k] = MA.add!!(get(a.dict, k, 0), v) + end + else + a.dict[b] = MA.add!!(get(a.dict, b, 0), 1) + end + return a + else + return a + b + end +end From f6cfdc9de0f9333215cea8eeb5bda038a9fc6457 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Fri, 16 May 2025 11:30:37 +0200 Subject: [PATCH 2/2] Add tests --- test/mutable_arithmetics.jl | 17 +++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 18 insertions(+) create mode 100644 test/mutable_arithmetics.jl diff --git a/test/mutable_arithmetics.jl b/test/mutable_arithmetics.jl new file mode 100644 index 00000000..b73a5b59 --- /dev/null +++ b/test/mutable_arithmetics.jl @@ -0,0 +1,17 @@ +using Test +using MutableArithmetics +using SymbolicUtils + +@syms x::Real y::Real +v = repeat([x, y], 10) +s = sum(v, init = 0) +@test s.dict[x] == 10 +@test s.dict[y] == 10 +@test isequal(s, 10x + 10y) # ???? + +a = x + y +b = 2x + y +c = add!!(a, b) +@test c.dict[x] == 3 +@test c.dict[y] == 2 +@test isequal(c, 3x + 2y) diff --git a/test/runtests.jl b/test/runtests.jl index 0d58ad20..884ca893 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,5 +18,6 @@ using Pkg, Test, SafeTestsets @safetestset "Adjoints" begin include("adjoints.jl") end @safetestset "Hash Consing" begin include("hash_consing.jl") end @safetestset "Cache macro" begin include("cache_macro.jl") end + @safetestset "Cache macro" begin include("mutable_arithmetics.jl") end end end