Skip to content

fix: remove mutation of BasicSymbolic #1480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/latexify_recipes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ function _toexpr(O)
while num isa Term && num.f isa Differential
deg += 1
den *= num.f.x
num = num.arguments[1]
num = first(arguments(num))
end
return :(_derivative($(_toexpr(num)), $den, $deg))
elseif op isa Integral
Expand Down
19 changes: 12 additions & 7 deletions src/solver/ia_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
for i in eachindex(lhs_roots)
for j in eachindex(rhs)
if iscall(lhs_roots[i]) && operation(lhs_roots[i]) == RootsOf
lhs_roots[i].arguments[1] = substitute(lhs_roots[i].arguments[1], Dict(new_var=>rhs[j]), fold=false)
_args = copy(parent(arguments(lhs_roots[i])))
_args[1] = substitute(_args[1], Dict(new_var => rhs[j]), fold = false)
T = typeof(lhs_roots[i])
_op = operation(lhs_roots[i])
_meta = metadata(lhs_roots[i])
lhs_roots[i] = maketerm(T, _op, _args, _meta)
push!(roots, lhs_roots[i])
else
push!(roots, substitute(lhs_roots[i], Dict(new_var=>rhs[j]), fold=false))
Expand Down Expand Up @@ -86,8 +91,9 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
end

elseif oper === (^)
if any(isequal(x, var) for x in get_variables(args[1])) &&
n_occurrences(args[2], var) == 0 && args[2] isa Integer
var_in_base = any(isequal(x, var) for x in get_variables(args[1]))
var_in_pow = n_occurrences(args[2], var) != 0
if var_in_base && !var_in_pow && args[2] isa Integer
lhs = args[1]
power = args[2]
new_roots = []
Expand All @@ -111,11 +117,10 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri
end
rhs = []
append!(rhs, new_roots)
elseif any(isequal(x, var) for x in get_variables(args[1])) &&
n_occurrences(args[2], var) == 0
elseif var_in_base && !var_in_pow
lhs = args[1]
s, args[2] = filter_stuff(args[2])
rhs = map(sol -> term(^, sol, 1 // args[2]), rhs)
s, power = filter_stuff(args[2])
rhs = map(sol -> term(^, sol, 1 // power), rhs)
else
lhs = args[2]
rhs = map(sol -> term(/, term(slog, sol), term(slog, args[1])), rhs)
Expand Down
53 changes: 27 additions & 26 deletions src/solver/polynomialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function turn_to_poly(expr, var)
expr = unwrap(expr)
!iscall(expr) && return (expr, Dict())

args = arguments(expr)
args = copy(parent(arguments(expr)))

sub = 0
broken = Ref(false)
Expand All @@ -53,12 +53,12 @@ function turn_to_poly(expr, var)
arg_oper = operation(arg)

if arg_oper === (^)
tp = trav_pow(args, i, var, broken, sub)
args[i], tp = trav_pow(args[i], var, broken, sub)
sub = isequal(tp, false) ? sub : tp
continue
end
if arg_oper === (*)
sub = trav_mult(arg, var, broken, sub)
args[i], sub = trav_mult(arg, var, broken, sub)
continue
end
isequal(add_sub(sub, arg, var, broken), false) && continue
Expand All @@ -77,16 +77,17 @@ function turn_to_poly(expr, var)

new_var = gensym()
new_var = (@variables $new_var)[1]
expr = maketerm(typeof(expr), operation(expr), args, metadata(expr))
return ssubs(expr, Dict(sub => new_var)), Dict{Any, Any}(new_var => sub)
end

"""
trav_pow(args, index, var, broken, sub)
trav_pow(arg, var, broken, sub)

Traverses an argument passed from ``turn_to_poly`` if it
satisfies ``oper === (^)``. Returns sub if changed from 0
to a new transcendental function or its value is
kept the same, and false if these 2 cases do not occur.
Traverses an argument `arg` passed from ``turn_to_poly`` if it satisfies
``oper === (^)``. Returns the new `arg` and `sub` if `sub` is changed from 0 to a new
transcendental function or its value is kept the same, or else `false` if these 2 cases
do not occur.

# Arguments
- args: The original arguments array of the expression passed to ``turn_to_poly``
Expand All @@ -97,20 +98,20 @@ kept the same, and false if these 2 cases do not occur.

# Examples
```jldoctest
julia> trav_pow([unwrap(9^x)], 1, x, Ref(false), 3^x)
3^x
julia> trav_pow(unwrap(9^x), x, Ref(false), 3^x)
(9^x, 3^x)

julia> trav_pow([unwrap(x^2)], 1, x, Ref(false), 3^x)
false
julia> trav_pow(unwrap(x^2), x, Ref(false), 3^x)
(x^2, false)
```
"""
function trav_pow(args, index, var, broken, sub)
args_arg = arguments(args[index])
function trav_pow(arg, var, broken, sub)
args_arg = arguments(arg)
base = args_arg[1]
power = args_arg[2]

# case 1: log(x)^2 .... 9^x = 3^2^x = 3^2x = (3^x)^2
!isequal(add_sub(sub, base, var, broken), false) && power isa Integer && return base
!isequal(add_sub(sub, base, var, broken), false) && power isa Integer && return arg, base

# case 2: int^f(x)
# n_func_occ may not be strictly 1, we could attempt attracting it after solving
Expand All @@ -122,21 +123,20 @@ function trav_pow(args, index, var, broken, sub)
sub = isequal(sub, 0) ? new_b : sub
if !isequal(sub, new_b)
broken[] = true
return false
return arg, false
end
new_b = term(^, new_b, p)
args[index] = new_b
return sub
return new_b, sub
end

return false
return arg, false
end

"""
trav_mult(arg, var, broken, sub)

Traverses an argument passed from ``turn_to_poly`` if it
satisfies ``oper === (*)``. Returns sub whether its changed from 0
satisfies ``oper === (*)``. Returns the new `arg` and `sub` if its changed from 0
to a new transcendental function or its value is
kept the same, but changes broken if these 2 cases do not occur. It
traverses the * argument by sub_arg and compares it to sub using
Expand All @@ -151,32 +151,33 @@ the function ``add_sub``
# Examples
```jldoctest
julia> trav_mult(unwrap(9*log(x)), x, Ref(false), log(x))
log(x)
(9log(x), log(x))

julia> trav_mult(unwrap(9*log(x)^2), x, Ref(false), log(x))
log(x)
(9(log(x)^2), log(x))

# value of broken is changed here to true
julia> trav_mult(unwrap(9*log(x+1)), x, Ref(false), log(x))
log(x)
(9log(x + 1), log(x))
```
"""
function trav_mult(arg, var, broken, sub)
args_arg = arguments(arg)
args_arg = copy(parent(arguments(arg)))
for (i, arg2) in enumerate(args_arg)
!iscall(arg2) && continue

oper = operation(arg2)
if oper === (^)
tp = trav_pow(args_arg, i, var, broken, sub)
args_arg[i], tp = trav_pow(args_arg[i], var, broken, sub)
sub = isequal(tp, false) ? sub : tp
continue
end

isequal(add_sub(sub, arg2, var, broken), false) && continue
sub = arg2
end
return sub
arg = maketerm(typeof(arg), operation(arg), args_arg, metadata(arg))
return arg, sub
end

"""
Expand Down
10 changes: 6 additions & 4 deletions src/solver/preprocess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ function _filter_poly(expr, var)
return filter_stuff(expr)
end

args = arguments(expr)
args = copy(parent(arguments(expr)))
if expr isa ComplexTerm
subs1, subs2 = Dict(), Dict()
expr1, expr2 = 0, 0
Expand Down Expand Up @@ -165,7 +165,7 @@ function _filter_poly(expr, var)
end

oper = operation(arg)
monomial = arguments(arg)
monomial = copy(parent(arguments(arg)))
if oper === (^)
if any(arg -> isequal(arg, var), monomial)
continue
Expand All @@ -175,6 +175,7 @@ function _filter_poly(expr, var)
subs2, monomial[2] = _filter_poly(monomial[2], var)

merge!(subs, merge(subs1, subs2))
args[i] = maketerm(typeof(arg), oper, monomial, metadata(arg))
continue
end

Expand All @@ -196,6 +197,7 @@ function _filter_poly(expr, var)
merge!(subs_of_monom, new_subs)
end
merge!(subs, subs_of_monom)
args[i] = maketerm(typeof(arg), oper, monomial, metadata(arg))
continue
end

Expand All @@ -208,9 +210,9 @@ function _filter_poly(expr, var)
end
end

args = map(unwrap, arguments(expr))
args = map(unwrap, args)
oper = operation(expr)
expr = term(oper, args...)
expr = maketerm(typeof(expr), oper, args, metadata(expr))
return subs, expr
end

Expand Down
3 changes: 2 additions & 1 deletion src/solver/solve_helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,11 @@ function bigify(n)

if n isa SymbolicUtils.BasicSymbolic
!iscall(n) && return n
args = arguments(n)
args = copy(parent(arguments(n)))
for i in eachindex(args)
args[i] = bigify(args[i])
end
n = maketerm(typeof(n), operation(n), args, metadata(n))
return n
end

Expand Down
Loading