Skip to content

Commit d16bc03

Browse files
Merge pull request #1345 from n0rbed/tests
Tests for solve_interms_ofvar and bug fixes
2 parents 705b61b + 62d8d0e commit d16bc03

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

src/solver/ia_rules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true)
5050
coeffs, constant = polynomial_coeffs(eq, [s])
5151
eqs = wrap.(collect(values(coeffs)))
5252

53-
solve_multivar(eqs, vars, dropmultiplicity=dropmultiplicity, warns=warns)
53+
symbolic_solve(eqs, vars, dropmultiplicity=dropmultiplicity, warns=warns)
5454
end
5555

5656
# an attempt at using ia_solve recursively.

src/solver/main.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ function solve_univar(expression, x; dropmultiplicity=true)
286286
factors_subbed = map(factor -> ssubs(factor, subs), factors)
287287
arr_roots = []
288288

289-
if degree < 5 && length(factors) == 1
289+
if degree < 5 && isequal(factors_subbed[1], wrap(expression))
290290
arr_roots = get_roots(expression, x)
291291

292292
# multiplicities (repeated roots)
@@ -296,10 +296,8 @@ function solve_univar(expression, x; dropmultiplicity=true)
296296
append!(arr_roots, og_arr_roots)
297297
end
298298
end
299-
end
300-
301-
if length(factors) != 1
302-
for i in eachindex(factors_subbed)
299+
elseif length(factors) > 1 || (length(factors) == 1 && !isequal(factors_subbed[1], wrap(expression)))
300+
for i in eachindex(factors_subbed)
303301
if !any(isequal(x, var) for var in get_variables(factors[i]))
304302
continue
305303
end

test/solver.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,24 @@ function check_approx(arr1, arr2)
5454
return true
5555
end
5656

57-
@variables x y z a b c d e
57+
@variables x y z a b c d e s
58+
59+
@testset "Solving in terms of a constant var" begin
60+
eq = ((s^2 + 1)/(s^2 + 2*s + 1)) - ((s^2 + a)/(b*c*s^2 + (b+c)*s + d))
61+
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b,c,d])
62+
known_roots = sort_arr([Dict(a=>1, b=>1, c=>1, d=>1)], [a,b,c,d])
63+
@test check_approx(calcd_roots, known_roots)
64+
65+
eq = (a+b)*s^2 - 2s^2 + 2*b*s - 3*s
66+
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b])
67+
known_roots = sort_arr([Dict(a=>1/2, b=>3/2)], [a,b])
68+
@test check_approx(calcd_roots, known_roots)
69+
70+
eq = (a*x^2+b)*s^2 - 2s^2 + 2*b*s - 3*s + 2(x^2)*(s^3) + 10*s^3
71+
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b])
72+
known_roots = sort_arr([Dict(a=>-1/10, b=>3/2, x=>-im*sqrt(5)), Dict(a=>-1/10, b=>3/2, x=>im*sqrt(5))], [a,b,x])
73+
@test check_approx(calcd_roots, known_roots)
74+
end
5875

5976
@testset "Invalid input" begin
6077
@test_throws AssertionError symbolic_solve(x, x^2)

0 commit comments

Comments
 (0)