Skip to content

Commit bcf5b3d

Browse files
feat: make ia_solve modular
1 parent 5aa7cab commit bcf5b3d

File tree

2 files changed

+111
-50
lines changed

2 files changed

+111
-50
lines changed

src/solver/ia_helpers.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,85 @@ function find_logandexpon(arg, var, oper, poly_index)
140140
!isequal(oper_term, 0) && !isequal(constant_term, 0) && return true
141141
return false
142142
end
143+
144+
"""
145+
ia_conditions!(f, lhs, rhs::Vector{Any}, conditions::Vector{Tuple})
146+
147+
If `f` is a left-invertible function, `lhs` and `rhs[i]` are univariate functions and
148+
`f(lhs) ~ rhs[i]` for all `i in eachindex(rhss)`, push to `conditions` all the relevant
149+
conditions on `lhs` or `rhs[i]`. Each condition is of the form `(sym, op)` where `sym`
150+
is an expression involving `lhs` and/or `rhs[i]` and `op` is a binary relational operator.
151+
The condition `op(sym, 0)` is then required to be true for the equation `f(lhs) ~ rhs[i]`
152+
to be valid.
153+
154+
For example, if `f = log`, `lhs = x` and `rhss = [y, z]` then the condition `x > 0` must
155+
be true. Thus, `(lhs, >)` is pushed to `conditions`. Similarly, if `f = sqrt`, `rhs[i] >= 0`
156+
must be true for all `i`, and so `(y, >=)` and `(z, >=)` will be appended to `conditions`.
157+
"""
158+
function ia_conditions!(args...; kwargs...) end
159+
160+
for fn in [log, log2, log10, NaNMath.log, NaNMath.log2, NaNMath.log10, slog]
161+
@eval function ia_conditions!(::typeof($fn), lhs, rhs, conditions)
162+
push!(conditions, (lhs, >))
163+
end
164+
end
165+
166+
for fn in [log1p, NaNMath.log1p]
167+
@eval function ia_conditions!(::typeof($fn), lhs, rhs, conditions)
168+
push!(conditions, (lhs - 1, >))
169+
end
170+
end
171+
172+
for fn in [sqrt, NaNMath.sqrt, ssqrt]
173+
@eval function ia_conditions!(::typeof($fn), lhs, rhs, conditions)
174+
for r in rhs
175+
push!(conditions, (r, >=))
176+
end
177+
end
178+
end
179+
180+
"""
181+
is_periodic(f)
182+
183+
Return `true` if `f` is a single-input single-output periodic function. Return `false` by
184+
default. If `is_periodic(f) == true`, then `fundamental_period(f)` must also be defined.
185+
186+
See also: [`fundamental_period`](@ref)
187+
"""
188+
is_periodic(f) = false
189+
190+
for fn in [
191+
sin, cos, tan, csc, sec, cot, NaNMath.sin, NaNMath.cos, NaNMath.tan, sind, cosd, tand,
192+
cscd, secd, cotd, cospi
193+
]
194+
@eval is_periodic(::typeof($fn)) = true
195+
end
196+
197+
"""
198+
fundamental_period(f)
199+
200+
Return the fundamental period of periodic function `f`. Must only be called if
201+
`is_periodic(f) == true`.
202+
203+
see also: [`is_periodic`](@ref)
204+
"""
205+
function fundamental_period end
206+
207+
for fn in [sin, cos, csc, sec, NaNMath.sin, NaNMath.cos]
208+
@eval fundamental_period(::typeof($fn)) = 2pi
209+
end
210+
211+
for fn in [sind, cosd, cscd, secd]
212+
@eval fundamental_period(::typeof($fn)) = 360.0
213+
end
214+
215+
fundamental_period(::typeof(cospi)) = 2.0
216+
217+
for fn in [tand, cotd]
218+
@eval fundamental_period(::typeof($fn)) = 180.0
219+
end
220+
221+
for fn in [tan, cot, NaNMath.tan]
222+
# `1pi isa Float64` whereas `pi isa Irrational{:π}`
223+
@eval fundamental_period(::typeof($fn)) = 1pi
224+
end

src/solver/ia_main.jl

Lines changed: 29 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using Symbolics
22

3+
const SAFE_ALTERNATIVES = Dict(log => slog, sqrt => ssqrt, cbrt => scbrt)
4+
35
function isolate(lhs, var; warns=true, conditions=[])
46
rhs = Vector{Any}([0])
57
original_lhs = deepcopy(lhs)
@@ -90,57 +92,24 @@ function isolate(lhs, var; warns=true, conditions=[])
9092
lhs = args[2]
9193
rhs = map(sol -> term(/, term(slog, sol), term(slog, args[1])), rhs)
9294
end
93-
94-
elseif oper === (log) || oper === (slog)
95-
lhs = args[1]
96-
rhs = map(sol -> term(^, Base.MathConstants.e, sol), rhs)
97-
push!(conditions, (args[1], >))
98-
99-
elseif oper === (log2)
100-
lhs = args[1]
101-
rhs = map(sol -> term(^, 2, sol), rhs)
102-
push!(conditions, (args[1], >))
103-
104-
elseif oper === (log10)
105-
lhs = args[1]
106-
rhs = map(sol -> term(^, 10, sol), rhs)
107-
push!(conditions, (args[1], >))
108-
109-
elseif oper === (sqrt)
110-
lhs = args[1]
111-
append!(conditions, [(r, >=) for r in rhs])
112-
rhs = map(sol -> term(^, sol, 2), rhs)
113-
114-
elseif oper === (cbrt)
115-
lhs = args[1]
116-
rhs = map(sol -> term(^, sol, 3), rhs)
117-
118-
elseif oper === (sin) || oper === (cos) || oper === (tan)
119-
rev_oper = Dict(sin => asin, cos => acos, tan => atan)
120-
lhs = args[1]
121-
# make this global somehow so the user doesnt need to declare it on his own
122-
new_var = gensym()
123-
new_var = (@variables $new_var)[1]
124-
rhs = map(
125-
sol -> term(rev_oper[oper], sol) +
126-
term(*, Base.MathConstants.pi, new_var),
127-
rhs)
128-
@info string(new_var) * " ϵ" * " Ζ"
129-
130-
elseif oper === (asin)
131-
lhs = args[1]
132-
rhs = map(sol -> term(sin, sol), rhs)
133-
134-
elseif oper === (acos)
135-
lhs = args[1]
136-
rhs = map(sol -> term(cos, sol), rhs)
137-
138-
elseif oper === (atan)
95+
elseif has_left_inverse(oper)
13996
lhs = args[1]
140-
rhs = map(sol -> term(tan, sol), rhs)
141-
elseif oper === (exp)
142-
lhs = args[1]
143-
rhs = map(sol -> term(slog, sol), rhs)
97+
ia_conditions!(oper, lhs, rhs, conditions)
98+
invop = left_inverse(oper)
99+
invop = get(SAFE_ALTERNATIVES, invop, invop)
100+
if is_periodic(oper)
101+
# make this global somehow so the user doesnt need to declare it on his own
102+
new_var = gensym()
103+
new_var = (@variables $new_var)[1]
104+
period = fundamental_period(oper)
105+
rhs = map(
106+
sol -> term(invop, sol) +
107+
term(*, period, new_var),
108+
rhs)
109+
@info string(new_var) * " ϵ" * " Ζ"
110+
else
111+
rhs = map(sol -> term(invop, sol), rhs)
112+
end
144113
end
145114

146115
lhs = simplify(lhs)
@@ -256,6 +225,16 @@ julia> RootFinding.ia_solve(expr, x)
256225
-2 + π*2var"##230" + asin((1//2)*(-1 + RootFinding.ssqrt(-39)))
257226
-2 + π*2var"##234" + asin((1//2)*(-1 - RootFinding.ssqrt(-39)))
258227
```
228+
229+
All transcendental functions for which `left_inverse` is defined are supported.
230+
To enable `ia_solve` to handle custom transcendental functions, define an inverse or
231+
left inverse. If the function is periodic, `is_periodic` and `fundamental_period` must
232+
be defined. If the function imposes certain conditions on its input or output (for
233+
example, `log` requires that its input be positive) define `ia_conditions!`.
234+
235+
See also: [`left_inverse`](@ref), [`inverse`](@ref), [`is_periodic`](@ref),
236+
[`fundamental_period`](@ref), [`ia_conditions!`](@ref).
237+
259238
# References
260239
[^1]: [R. W. Hamming, Coding and Information Theory, ScienceDirect, 1980](https://www.sciencedirect.com/science/article/pii/S0747717189800070).
261240
"""

0 commit comments

Comments
 (0)