Skip to content

Commit 48e3519

Browse files
Merge pull request #1578 from AJ0070/master
Fixed ODE Solver(Test case: 8)
2 parents 45a197e + 3babf99 commit 48e3519

File tree

2 files changed

+35
-10
lines changed

2 files changed

+35
-10
lines changed

ext/SymbolicsSymPyExt.jl

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,28 @@ else
88
using ..SymPy
99
end
1010

11-
using Symbolics: value, symbolics_to_sympy, sympy_to_symbolics
11+
using Symbolics: value, symbolics_to_sympy, sympy_to_symbolics, Differential, Num
1212
using SymbolicUtils: iscall, operation, arguments, symtype, FnType, Symbolic, Term
1313
using LinearAlgebra
1414

1515
function Symbolics.symbolics_to_sympy(expr)
1616
expr = value(expr)
1717
expr isa Symbolic || return expr
1818
if iscall(expr)
19-
sop = symbolics_to_sympy(operation(expr))
20-
sargs = map(symbolics_to_sympy, arguments(expr))
21-
sop === (^) && length(sargs) == 2 && sargs[2] isa Number ? Base.literal_pow(^, sargs[1], Val(sargs[2])) : sop(sargs...)
19+
op = operation(expr)
20+
args = arguments(expr)
21+
22+
if op isa Differential
23+
@assert length(args) == 1 "Differential operator must have exactly one argument."
24+
return SymPy.sympy.Derivative(symbolics_to_sympy(args[1]), symbolics_to_sympy(op.x))
25+
end
26+
27+
sop = symbolics_to_sympy(op)
28+
sargs = map(symbolics_to_sympy, args)
29+
return sop === (^) && length(sargs) == 2 && sargs[2] isa Number ? Base.literal_pow(^, sargs[1], Val(sargs[2])) : sop(sargs...)
2230
else
2331
name = string(nameof(expr))
24-
symtype(expr) <: FnType ? SymPy.SymFunction(name) : SymPy.Sym(name)
32+
return symtype(expr) <: FnType ? SymPy.SymFunction(name) : SymPy.Sym(name)
2533
end
2634
end
2735

@@ -89,7 +97,19 @@ function Symbolics.sympy_ode_solve(expr, func, var)
8997
func_sympy = symbolics_to_sympy(func)
9098
var_sympy = symbolics_to_sympy(var)
9199
sol_sympy = SymPy.dsolve(expr_sympy, func_sympy)
92-
sympy_to_symbolics(sol_sympy, [func, var])
100+
sol_expr = sol_sympy.rhs
101+
parsing_vars = Vector{SymbolicUtils.BasicSymbolic}()
102+
vars_in_expr = Symbolics.get_variables(value(expr))
103+
func_val = value(func)
104+
for v in vars_in_expr
105+
if !isequal(v, func_val)
106+
push!(parsing_vars, v)
107+
end
108+
end
109+
push!(parsing_vars, value(var))
110+
push!(parsing_vars, operation(func_val))
111+
unwrapped_vars = unique(parsing_vars)
112+
sympy_to_symbolics(sol_expr, unwrapped_vars)
93113
end
94114

95115
end

test/sympy.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,12 @@ result = sympy_simplify(expr)
6565
@test isequal(Symbolics.simplify(result), 3x^2)
6666

6767
# Test 8: ODE solver
68-
# D = Differential(x)
69-
# ode = D(f) - 2*f
70-
# sol_ode = sympy_ode_solve(ode, f, x)
71-
# @test isequal(sol_ode, Symbolics.parse("C1*exp(2*x)", Dict("f"=>f, "x"=>x)))
68+
D = Symbolics.Differential(x)
69+
@variables C1
70+
ode = D(f) - 2*f
71+
sol_ode = sympy_ode_solve(ode, f, x)
72+
sol_vars = Symbolics.get_variables(sol_ode)
73+
const_sym = only(filter(v -> startswith(string(Symbolics.nameof(v)), "C"), sol_vars))
74+
expected_sol = C1 * exp(2 * x)
75+
canonical_sol_ode = Symbolics.substitute(sol_ode, Dict(const_sym => C1))
76+
@test isequal(canonical_sol_ode, expected_sol)

0 commit comments

Comments
 (0)