Skip to content

Commit 8c518c2

Browse files
authored
Merge pull request #1348 from AayushSabharwal/as/ia-solve-inverses
2 parents 8998fa8 + 1d71dd4 commit 8c518c2

File tree

5 files changed

+183
-67
lines changed

5 files changed

+183
-67
lines changed

docs/src/manual/solver.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ to `solve_univar`. We can see that essentially, `solve_univar` is the building b
3232
it to `ia_solve`, which attempts solving by attraction and isolation [^2]. This only works when the input is a single expression
3333
and the user wants the answer in terms of a single variable. Say `log(x) - a == 0` gives us `[e^a]`.
3434

35+
```@docs
36+
Symbolics.solve_univar
37+
Symbolics.solve_multivar
38+
Symbolics.ia_solve
39+
Symbolics.ia_conditions!
40+
Symbolics.is_periodic
41+
Symbolics.fundamental_period
42+
```
43+
3544
#### Nice examples
3645

3746
```@example solver

src/inverse.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ inverse(::typeof(NaNMath.log10)) = inverse(log10)
157157
inverse(::typeof(NaNMath.log1p)) = inverse(log1p)
158158
inverse(::typeof(NaNMath.log2)) = inverse(log2)
159159
left_inverse(::typeof(NaNMath.sqrt)) = left_inverse(sqrt)
160+
# inverses of solve helpers
161+
left_inverse(::typeof(ssqrt)) = left_inverse(sqrt)
162+
left_inverse(::typeof(scbrt)) = left_inverse(cbrt)
163+
left_inverse(::typeof(slog)) = left_inverse(log)
160164

161165
function inverse(f::ComposedFunction)
162166
return inverse(f.inner) ∘ inverse(f.outer)

src/solver/ia_helpers.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,85 @@ function find_logandexpon(arg, var, oper, poly_index)
140140
!isequal(oper_term, 0) && !isequal(constant_term, 0) && return true
141141
return false
142142
end
143+
144+
"""
145+
ia_conditions!(f, lhs, rhs::Vector{Any}, conditions::Vector{Tuple})
146+
147+
If `f` is a left-invertible function, `lhs` and `rhs[i]` are univariate functions and
148+
`f(lhs) ~ rhs[i]` for all `i in eachindex(rhs)`, push to `conditions` all the relevant
149+
conditions on `lhs` or `rhs[i]`. Each condition is of the form `(sym, op)` where `sym`
150+
is an expression involving `lhs` and/or `rhs[i]` and `op` is a binary relational operator.
151+
The condition `op(sym, 0)` is then required to be true for the equation `f(lhs) ~ rhs[i]`
152+
to be valid.
153+
154+
For example, if `f = log`, `lhs = x` and `rhs = [y, z]` then the condition `x > 0` must
155+
be true. Thus, `(lhs, >)` is pushed to `conditions`. Similarly, if `f = sqrt`, `rhs[i] >= 0`
156+
must be true for all `i`, and so `(y, >=)` and `(z, >=)` will be appended to `conditions`.
157+
"""
158+
function ia_conditions!(args...; kwargs...) end
159+
160+
for fn in [log, log2, log10, NaNMath.log, NaNMath.log2, NaNMath.log10, slog]
161+
@eval function ia_conditions!(::typeof($fn), lhs, rhs, conditions)
162+
push!(conditions, (lhs, >))
163+
end
164+
end
165+
166+
for fn in [log1p, NaNMath.log1p]
167+
@eval function ia_conditions!(::typeof($fn), lhs, rhs, conditions)
168+
push!(conditions, (lhs + 1, >))
169+
end
170+
end
171+
172+
for fn in [sqrt, NaNMath.sqrt, ssqrt]
173+
@eval function ia_conditions!(::typeof($fn), lhs, rhs, conditions)
174+
for r in rhs
175+
push!(conditions, (r, >=))
176+
end
177+
end
178+
end
179+
180+
"""
181+
is_periodic(f)
182+
183+
Return `true` if `f` is a single-input single-output periodic function. Return `false` by
184+
default. If `is_periodic(f) == true`, then `fundamental_period(f)` must also be defined.
185+
186+
See also: [`fundamental_period`](@ref)
187+
"""
188+
is_periodic(f) = false
189+
190+
for fn in [
191+
sin, cos, tan, csc, sec, cot, NaNMath.sin, NaNMath.cos, NaNMath.tan, sind, cosd, tand,
192+
cscd, secd, cotd, cospi
193+
]
194+
@eval is_periodic(::typeof($fn)) = true
195+
end
196+
197+
"""
198+
fundamental_period(f)
199+
200+
Return the fundamental period of periodic function `f`. Must only be called if
201+
`is_periodic(f) == true`.
202+
203+
See also: [`is_periodic`](@ref)
204+
"""
205+
function fundamental_period end
206+
207+
for fn in [sin, cos, csc, sec, NaNMath.sin, NaNMath.cos]
208+
@eval fundamental_period(::typeof($fn)) = 2pi
209+
end
210+
211+
for fn in [sind, cosd, cscd, secd]
212+
@eval fundamental_period(::typeof($fn)) = 360.0
213+
end
214+
215+
fundamental_period(::typeof(cospi)) = 2.0
216+
217+
for fn in [tand, cotd]
218+
@eval fundamental_period(::typeof($fn)) = 180.0
219+
end
220+
221+
for fn in [tan, cot, NaNMath.tan]
222+
# `1pi isa Float64` whereas `pi isa Irrational{:π}`
223+
@eval fundamental_period(::typeof($fn)) = 1pi
224+
end

src/solver/ia_main.jl

Lines changed: 61 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using Symbolics
22

3-
function isolate(lhs, var; warns=true, conditions=[])
3+
const SAFE_ALTERNATIVES = Dict(log => slog, sqrt => ssqrt, cbrt => scbrt)
4+
5+
function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, periodic_roots = true)
46
rhs = Vector{Any}([0])
57
original_lhs = deepcopy(lhs)
68
lhs = unwrap(lhs)
@@ -72,12 +74,21 @@ function isolate(lhs, var; warns=true, conditions=[])
7274
power = args[2]
7375
new_roots = []
7476

75-
for i in eachindex(rhs)
76-
for k in 0:(args[2] - 1)
77-
r = wrap(term(^, rhs[i], (1 // power)))
78-
c = wrap(term(*, 2 * (k), pi)) * im / power
79-
root = r * Base.MathConstants.e^c
80-
push!(new_roots, root)
77+
if complex_roots
78+
for i in eachindex(rhs)
79+
for k in 0:(args[2] - 1)
80+
r = term(^, rhs[i], (1 // power))
81+
c = term(*, 2 * (k), pi) * im / power
82+
root = r * Base.MathConstants.e^c
83+
push!(new_roots, root)
84+
end
85+
end
86+
else
87+
for i in eachindex(rhs)
88+
push!(new_roots, term(^, rhs[i], (1 // power)))
89+
if iseven(power)
90+
push!(new_roots, term(-, new_roots[end]))
91+
end
8192
end
8293
end
8394
rhs = []
@@ -90,57 +101,23 @@ function isolate(lhs, var; warns=true, conditions=[])
90101
lhs = args[2]
91102
rhs = map(sol -> term(/, term(slog, sol), term(slog, args[1])), rhs)
92103
end
93-
94-
elseif oper === (log) || oper === (slog)
95-
lhs = args[1]
96-
rhs = map(sol -> term(^, Base.MathConstants.e, sol), rhs)
97-
push!(conditions, (args[1], >))
98-
99-
elseif oper === (log2)
100-
lhs = args[1]
101-
rhs = map(sol -> term(^, 2, sol), rhs)
102-
push!(conditions, (args[1], >))
103-
104-
elseif oper === (log10)
104+
elseif has_left_inverse(oper)
105105
lhs = args[1]
106-
rhs = map(sol -> term(^, 10, sol), rhs)
107-
push!(conditions, (args[1], >))
108-
109-
elseif oper === (sqrt)
110-
lhs = args[1]
111-
append!(conditions, [(r, >=) for r in rhs])
112-
rhs = map(sol -> term(^, sol, 2), rhs)
113-
114-
elseif oper === (cbrt)
115-
lhs = args[1]
116-
rhs = map(sol -> term(^, sol, 3), rhs)
117-
118-
elseif oper === (sin) || oper === (cos) || oper === (tan)
119-
rev_oper = Dict(sin => asin, cos => acos, tan => atan)
120-
lhs = args[1]
121-
# make this global somehow so the user doesnt need to declare it on his own
122-
new_var = gensym()
123-
new_var = (@variables $new_var)[1]
124-
rhs = map(
125-
sol -> term(rev_oper[oper], sol) +
126-
term(*, Base.MathConstants.pi, new_var),
127-
rhs)
128-
@info string(new_var) * " ϵ" * " Ζ"
129-
130-
elseif oper === (asin)
131-
lhs = args[1]
132-
rhs = map(sol -> term(sin, sol), rhs)
133-
134-
elseif oper === (acos)
135-
lhs = args[1]
136-
rhs = map(sol -> term(cos, sol), rhs)
137-
138-
elseif oper === (atan)
139-
lhs = args[1]
140-
rhs = map(sol -> term(tan, sol), rhs)
141-
elseif oper === (exp)
142-
lhs = args[1]
143-
rhs = map(sol -> term(slog, sol), rhs)
106+
ia_conditions!(oper, lhs, rhs, conditions)
107+
invop = left_inverse(oper)
108+
invop = get(SAFE_ALTERNATIVES, invop, invop)
109+
if is_periodic(oper) && periodic_roots
110+
new_var = gensym()
111+
new_var = (@variables $new_var)[1]
112+
period = fundamental_period(oper)
113+
rhs = map(
114+
sol -> term(invop, sol) +
115+
term(*, period, new_var),
116+
rhs)
117+
@info string(new_var) * " ϵ" * " Ζ"
118+
else
119+
rhs = map(sol -> term(invop, sol), rhs)
120+
end
144121
end
145122

146123
lhs = simplify(lhs)
@@ -149,7 +126,7 @@ function isolate(lhs, var; warns=true, conditions=[])
149126
return rhs, conditions
150127
end
151128

152-
function attract(lhs, var; warns = true)
129+
function attract(lhs, var; warns = true, complex_roots = true, periodic_roots = true)
153130
if n_func_occ(simplify(lhs), var) <= n_func_occ(lhs, var)
154131
lhs = simplify(lhs)
155132
end
@@ -164,7 +141,9 @@ function attract(lhs, var; warns = true)
164141
end
165142
lhs = attract_trig(lhs, var)
166143

167-
n_func_occ(lhs, var) == 1 && return isolate(lhs, var, warns = warns, conditions=conditions)
144+
if n_func_occ(lhs, var) == 1
145+
return isolate(lhs, var; warns, conditions, complex_roots, periodic_roots)
146+
end
168147

169148
lhs, sub = turn_to_poly(lhs, var)
170149

@@ -182,12 +161,12 @@ function attract(lhs, var; warns = true)
182161
new_var = collect(keys(sub))[1]
183162
new_var_val = collect(values(sub))[1]
184163

185-
roots, new_conds = isolate(lhs, new_var, warns = warns)
164+
roots, new_conds = isolate(lhs, new_var; warns = warns, complex_roots, periodic_roots)
186165
append!(conditions, new_conds)
187166
new_roots = []
188167

189168
for root in roots
190-
new_sol, new_conds = isolate(new_var_val - root, var, warns = warns)
169+
new_sol, new_conds = isolate(new_var_val - root, var; warns = warns, complex_roots, periodic_roots)
191170
append!(conditions, new_conds)
192171
push!(new_roots, new_sol)
193172
end
@@ -197,7 +176,7 @@ function attract(lhs, var; warns = true)
197176
end
198177

199178
"""
200-
ia_solve(lhs, var)
179+
ia_solve(lhs, var; kwargs...)
201180
This function attempts to solve transcendental functions by first checking
202181
the "smart" number of occurrences in the input LHS. By smart here we mean
203182
that polynomials are counted as 1 occurrence. For example `x^2 + 2x` is 1
@@ -226,6 +205,13 @@ we throw an error to tell the user that this is currently unsolvable by our cove
226205
- lhs: a Num/SymbolicUtils.BasicSymbolic
227206
- var: variable to solve for.
228207
208+
# Keyword arguments
209+
- `warns = true`: Whether to emit warnings for unsolvable expressions.
210+
- `complex_roots = true`: Whether to consider complex roots of `x ^ n ~ y`, where `n` is an integer.
211+
- `periodic_roots = true`: If `true`, isolate `f(x) ~ y` as `x ~ finv(y) + n * period` where
212+
`is_periodic(f) == true`, `finv = left_inverse(f)` and `period = fundamental_period(f)`. `n`
213+
is a new anonymous symbolic variable.
214+
229215
# Examples
230216
```jldoctest
231217
julia> solve(a*x^b + c, x)
@@ -256,20 +242,30 @@ julia> RootFinding.ia_solve(expr, x)
256242
-2 + π*2var"##230" + asin((1//2)*(-1 + RootFinding.ssqrt(-39)))
257243
-2 + π*2var"##234" + asin((1//2)*(-1 - RootFinding.ssqrt(-39)))
258244
```
245+
246+
All transcendental functions for which `left_inverse` is defined are supported.
247+
To enable `ia_solve` to handle custom transcendental functions, define an inverse or
248+
left inverse. If the function is periodic, `is_periodic` and `fundamental_period` must
249+
be defined. If the function imposes certain conditions on its input or output (for
250+
example, `log` requires that its input be positive) define `ia_conditions!`.
251+
252+
See also: [`left_inverse`](@ref), [`inverse`](@ref), [`is_periodic`](@ref),
253+
[`fundamental_period`](@ref), [`ia_conditions!`](@ref).
254+
259255
# References
260256
[^1]: [R. W. Hamming, Coding and Information Theory, ScienceDirect, 1980](https://www.sciencedirect.com/science/article/pii/S0747717189800070).
261257
"""
262-
function ia_solve(lhs, var; warns = true)
258+
function ia_solve(lhs, var; warns = true, complex_roots = true, periodic_roots = true)
263259
nx = n_func_occ(lhs, var)
264260
sols = []
265261
conditions = []
266262
if nx == 0
267263
warns && @warn("Var not present in given expression")
268264
return []
269265
elseif nx == 1
270-
sols, conditions = isolate(lhs, var, warns = warns)
266+
sols, conditions = isolate(lhs, var; warns = warns, complex_roots, periodic_roots)
271267
elseif nx > 1
272-
sols, conditions = attract(lhs, var, warns = warns)
268+
sols, conditions = attract(lhs, var; warns = warns, complex_roots, periodic_roots)
273269
end
274270

275271
isequal(sols, nothing) && return nothing

test/solver.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ end
411411
#@test isequal(lhs, rhs)
412412

413413
lhs = symbolic_solve(log(a*x)-b,x)[1]
414-
@test isequal(Symbolics.arguments(Symbolics.unwrap(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))))[1], E)
414+
@test isequal(Symbolics.unwrap(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))), 1E)
415415

416416
expr = x + 2
417417
lhs = eval.(Symbolics.toexpr.(ia_solve(expr, x)))
@@ -431,7 +431,7 @@ end
431431
@test isapprox(eval(Symbolics.toexpr(symbolic_solve(expr, x)[1])), sqrt(2), atol=1e-6)
432432

433433
expr = 2^(x+1) + 5^(x+3)
434-
lhs = eval.(Symbolics.toexpr.(ia_solve(expr, x)))
434+
lhs = ComplexF64.(eval.(Symbolics.toexpr.(ia_solve(expr, x))))
435435
lhs_solve = eval.(Symbolics.toexpr.(symbolic_solve(expr, x)))
436436
rhs = [(-im*Base.MathConstants.pi - log(2) + 3log(5))/(log(2) - log(5))]
437437
@test lhs[1] ≈ rhs[1]
@@ -488,6 +488,31 @@ end
488488

489489
@test all(lhs .≈ rhs)
490490
@test all(lhs_solve .≈ rhs)
491+
492+
@testset "Keyword arguments" begin
493+
expr = sec(x ^ 2 + 4x + 4) ^ 3 - 3
494+
roots = ia_solve(expr, x)
495+
@test length(roots) == 6 # 2 quadratic roots * 3 roots from cbrt(3)
496+
@test length(Symbolics.get_variables(roots[1])) == 1
497+
_n = only(Symbolics.get_variables(roots[1]))
498+
vals = substitute.(roots, (Dict(_n => 0),))
499+
@test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals)
500+
501+
roots = ia_solve(expr, x; complex_roots = false)
502+
@test length(roots) == 2
503+
# the `n` in `θ + n * 2π`
504+
@test length(Symbolics.get_variables(roots[1])) == 1
505+
_n = only(Symbolics.get_variables(roots[1]))
506+
vals = substitute.(roots, (Dict(_n => 0),))
507+
@test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals)
508+
509+
roots = ia_solve(expr, x; complex_roots = false, periodic_roots = false)
510+
@test length(roots) == 2
511+
@test length(Symbolics.get_variables(roots[1])) == 0
512+
@test length(Symbolics.get_variables(roots[2])) == 0
513+
vals = eval.(Symbolics.toexpr.(roots))
514+
@test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals)
515+
end
491516
end
492517

493518
@testset "Sqrt case poly" begin

0 commit comments

Comments
 (0)