Skip to content

Commit 10b5997

Browse files
feat: add additional keywords to ia_solve
1 parent c8b0fa3 commit 10b5997

File tree

1 file changed

+34
-16
lines changed

1 file changed

+34
-16
lines changed

src/solver/ia_main.jl

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using Symbolics
22

33
const SAFE_ALTERNATIVES = Dict(log => slog, sqrt => ssqrt, cbrt => scbrt)
44

5-
function isolate(lhs, var; warns=true, conditions=[])
5+
function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, periodic_roots = true)
66
rhs = Vector{Any}([0])
77
original_lhs = deepcopy(lhs)
88
lhs = unwrap(lhs)
@@ -74,12 +74,21 @@ function isolate(lhs, var; warns=true, conditions=[])
7474
power = args[2]
7575
new_roots = []
7676

77-
for i in eachindex(rhs)
78-
for k in 0:(args[2] - 1)
79-
r = wrap(term(^, rhs[i], (1 // power)))
80-
c = wrap(term(*, 2 * (k), pi)) * im / power
81-
root = r * Base.MathConstants.e^c
82-
push!(new_roots, root)
77+
if complex_roots
78+
for i in eachindex(rhs)
79+
for k in 0:(args[2] - 1)
80+
r = wrap(term(^, rhs[i], (1 // power)))
81+
c = wrap(term(*, 2 * (k), pi)) * im / power
82+
root = r * Base.MathConstants.e^c
83+
push!(new_roots, root)
84+
end
85+
end
86+
else
87+
for i in eachindex(rhs)
88+
push!(new_roots, wrap(term(^, rhs[i], (1 // power))))
89+
if iseven(power)
90+
push!(new_roots, wrap(term(-, new_roots[end])))
91+
end
8392
end
8493
end
8594
rhs = []
@@ -97,7 +106,7 @@ function isolate(lhs, var; warns=true, conditions=[])
97106
ia_conditions!(oper, lhs, rhs, conditions)
98107
invop = left_inverse(oper)
99108
invop = get(SAFE_ALTERNATIVES, invop, invop)
100-
if is_periodic(oper)
109+
if is_periodic(oper) && periodic_roots
101110
# make this global somehow so the user doesnt need to declare it on his own
102111
new_var = gensym()
103112
new_var = (@variables $new_var)[1]
@@ -118,7 +127,7 @@ function isolate(lhs, var; warns=true, conditions=[])
118127
return rhs, conditions
119128
end
120129

121-
function attract(lhs, var; warns = true)
130+
function attract(lhs, var; warns = true, complex_roots = true, periodic_roots = true)
122131
if n_func_occ(simplify(lhs), var) <= n_func_occ(lhs, var)
123132
lhs = simplify(lhs)
124133
end
@@ -133,7 +142,9 @@ function attract(lhs, var; warns = true)
133142
end
134143
lhs = attract_trig(lhs, var)
135144

136-
n_func_occ(lhs, var) == 1 && return isolate(lhs, var, warns = warns, conditions=conditions)
145+
if n_func_occ(lhs, var) == 1
146+
return isolate(lhs, var; warns, conditions, complex_roots, periodic_roots)
147+
end
137148

138149
lhs, sub = turn_to_poly(lhs, var)
139150

@@ -151,12 +162,12 @@ function attract(lhs, var; warns = true)
151162
new_var = collect(keys(sub))[1]
152163
new_var_val = collect(values(sub))[1]
153164

154-
roots, new_conds = isolate(lhs, new_var, warns = warns)
165+
roots, new_conds = isolate(lhs, new_var; warns = warns, complex_roots, periodic_roots)
155166
append!(conditions, new_conds)
156167
new_roots = []
157168

158169
for root in roots
159-
new_sol, new_conds = isolate(new_var_val - root, var, warns = warns)
170+
new_sol, new_conds = isolate(new_var_val - root, var; warns = warns, complex_roots, periodic_roots)
160171
append!(conditions, new_conds)
161172
push!(new_roots, new_sol)
162173
end
@@ -166,7 +177,7 @@ function attract(lhs, var; warns = true)
166177
end
167178

168179
"""
169-
ia_solve(lhs, var)
180+
ia_solve(lhs, var; kwargs...)
170181
This function attempts to solve transcendental functions by first checking
171182
the "smart" number of occurrences in the input LHS. By smart here we mean
172183
that polynomials are counted as 1 occurrence. for example `x^2 + 2x` is 1
@@ -195,6 +206,13 @@ we throw an error to tell the user that this is currently unsolvable by our cove
195206
- lhs: a Num/SymbolicUtils.BasicSymbolic
196207
- var: variable to solve for.
197208
209+
# Keyword arguments
210+
- `warns = true`: Whether to emit warnings for unsolvable expressions.
211+
- `complex_roots = true`: Whether to consider complex roots of `x ^ n ~ y`, where `n` is an integer.
212+
- `periodic_roots = true`: If `true`, isolate `f(x) ~ y` as `x ~ finv(y) + n * period` where
213+
`is_periodic(f) == true`, `finv = left_inverse(f)` and `period = fundamental_period(f)`. `n`
214+
is a new anonymous symbolic variable.
215+
198216
# Examples
199217
```jldoctest
200218
julia> solve(a*x^b + c, x)
@@ -238,17 +256,17 @@ See also: [`left_inverse`](@ref), [`inverse`](@ref), [`is_periodic`](@ref),
238256
# References
239257
[^1]: [R. W. Hamming, Coding and Information Theory, ScienceDirect, 1980](https://www.sciencedirect.com/science/article/pii/S0747717189800070).
240258
"""
241-
function ia_solve(lhs, var; warns = true)
259+
function ia_solve(lhs, var; warns = true, complex_roots = true, periodic_roots = true)
242260
nx = n_func_occ(lhs, var)
243261
sols = []
244262
conditions = []
245263
if nx == 0
246264
warns && @warn("Var not present in given expression")
247265
return []
248266
elseif nx == 1
249-
sols, conditions = isolate(lhs, var, warns = warns)
267+
sols, conditions = isolate(lhs, var; warns = warns, complex_roots, periodic_roots)
250268
elseif nx > 1
251-
sols, conditions = attract(lhs, var, warns = warns)
269+
sols, conditions = attract(lhs, var; warns = warns, complex_roots, periodic_roots)
252270
end
253271

254272
isequal(sols, nothing) && return nothing

0 commit comments

Comments
 (0)