|
8 | 8 | using ..SymPy
|
9 | 9 | end
|
10 | 10 |
|
11 |
| -using Symbolics: value, symbolics_to_sympy, sympy_to_symbolics |
| 11 | +using Symbolics: value, symbolics_to_sympy, sympy_to_symbolics, Differential, Num |
12 | 12 | using SymbolicUtils: iscall, operation, arguments, symtype, FnType, Symbolic, Term
|
13 | 13 | using LinearAlgebra
|
14 | 14 |
|
15 | 15 | function Symbolics.symbolics_to_sympy(expr)
|
16 | 16 | expr = value(expr)
|
17 | 17 | expr isa Symbolic || return expr
|
18 | 18 | if iscall(expr)
|
19 |
| - sop = symbolics_to_sympy(operation(expr)) |
20 |
| - sargs = map(symbolics_to_sympy, arguments(expr)) |
21 |
| - sop === (^) && length(sargs) == 2 && sargs[2] isa Number ? Base.literal_pow(^, sargs[1], Val(sargs[2])) : sop(sargs...) |
| 19 | + op = operation(expr) |
| 20 | + args = arguments(expr) |
| 21 | + |
| 22 | + if op isa Differential |
| 23 | + @assert length(args) == 1 "Differential operator must have exactly one argument." |
| 24 | + return SymPy.sympy.Derivative(symbolics_to_sympy(args[1]), symbolics_to_sympy(op.x)) |
| 25 | + end |
| 26 | + |
| 27 | + sop = symbolics_to_sympy(op) |
| 28 | + sargs = map(symbolics_to_sympy, args) |
| 29 | + return sop === (^) && length(sargs) == 2 && sargs[2] isa Number ? Base.literal_pow(^, sargs[1], Val(sargs[2])) : sop(sargs...) |
22 | 30 | else
|
23 | 31 | name = string(nameof(expr))
|
24 |
| - symtype(expr) <: FnType ? SymPy.SymFunction(name) : SymPy.Sym(name) |
| 32 | + return symtype(expr) <: FnType ? SymPy.SymFunction(name) : SymPy.Sym(name) |
25 | 33 | end
|
26 | 34 | end
|
27 | 35 |
|
@@ -89,7 +97,19 @@ function Symbolics.sympy_ode_solve(expr, func, var)
|
89 | 97 | func_sympy = symbolics_to_sympy(func)
|
90 | 98 | var_sympy = symbolics_to_sympy(var)
|
91 | 99 | sol_sympy = SymPy.dsolve(expr_sympy, func_sympy)
|
92 |
| - sympy_to_symbolics(sol_sympy, [func, var]) |
| 100 | + sol_expr = sol_sympy.rhs |
| 101 | + parsing_vars = Vector{SymbolicUtils.BasicSymbolic}() |
| 102 | + vars_in_expr = Symbolics.get_variables(value(expr)) |
| 103 | + func_val = value(func) |
| 104 | + for v in vars_in_expr |
| 105 | + if !isequal(v, func_val) |
| 106 | + push!(parsing_vars, v) |
| 107 | + end |
| 108 | + end |
| 109 | + push!(parsing_vars, value(var)) |
| 110 | + push!(parsing_vars, operation(func_val)) |
| 111 | + unwrapped_vars = unique(parsing_vars) |
| 112 | + sympy_to_symbolics(sol_expr, unwrapped_vars) |
93 | 113 | end
|
94 | 114 |
|
95 | 115 | end
|
0 commit comments