Skip to content

Commit 00b7982

Browse files
committed
work on extensions
1 parent 7dec2c0 commit 00b7982

File tree

4 files changed

+188
-1
lines changed

4 files changed

+188
-1
lines changed

Project.toml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SymPyPythonCall"
22
uuid = "bc8888f7-b21e-4b7c-a06a-5d9c9496438c"
33
authors = ["jverzani <jverzani@gmail.com> and contributors"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
CommonEq = "3709ef60-1bee-4518-9f2f-acd86f176c50"
@@ -14,6 +14,14 @@ PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
1414
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1515
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1616

17+
[weakdeps]
18+
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
19+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
20+
21+
[extensions]
22+
SymPyPythonCallTermInterfaceExt = "TermInterface"
23+
SymPyPythonCallSymbolicsExt = "Symbolics"
24+
1725
[compat]
1826
julia = "1.6.1"
1927
CommonEq = "0.2"
@@ -26,6 +34,8 @@ SpecialFunctions = "0.8, 0.9, 0.10, 1.0, 2"
2634

2735
[extras]
2836
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
37+
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
38+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
2939

3040
[targets]
3141
test = ["Test"]

ext/SymPyPythonCallSymbolicsExt.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
module SymPyPythonCallSymbolicsExt
2+
3+
# from https://github.com/JuliaSymbolics/Symbolics.jl/pull/957/
4+
# by @jClugstor
5+
using SymPyPythonCall
6+
const PythonCall = SymPyPythonCall.PythonCall
7+
import SymPyPythonCall.PythonCall: Py, pyisinstance, pyconvert
8+
9+
import Symbolics
10+
import Symbolics: @variables
11+
const sp = SymPyPythonCall.sympy
12+
13+
# rule functions
14+
function pyconvert_rule_sympy_symbol(::Type{Symbolics.Num}, x::Py)
15+
if !pyisinstance(x,sp.Symbol)
16+
return PythonCall.pyconvert_unconverted()
17+
end
18+
name = PythonCall.pyconvert(Symbol,x.name)
19+
return PythonCall.pyconvert_return(Symbolics.variable(name))
20+
end
21+
22+
function pyconvert_rule_sympy_pow(::Type{Symbolics.Num}, x::Py)
23+
if !pyisinstance(x,sp.Pow)
24+
return PythonCall.pyconvert_unconverted()
25+
end
26+
expbase = pyconvert(Symbolics.Num,x.base)
27+
exp = pyconvert(Symbolics.Num,x.exp)
28+
return PythonCall.pyconvert_return(expbase^exp)
29+
end
30+
31+
function pyconvert_rule_sympy_mul(::Type{Symbolics.Num}, x::Py)
32+
if !pyisinstance(x,sp.Mul)
33+
return PythonCall.pyconvert_unconverted()
34+
end
35+
mult = reduce(*,PythonCall.pyconvert.(Symbolics.Num,x.args))
36+
return PythonCall.pyconvert_return(mult)
37+
end
38+
39+
function pyconvert_rule_sympy_add(::Type{Symbolics.Num}, x::Py)
40+
if !pyisinstance(x,sp.Add)
41+
return PythonCall.pyconvert_unconverted()
42+
end
43+
sum = reduce(+, PythonCall.pyconvert.(Symbolics.Num,x.args))
44+
return PythonCall.pyconvert_return(sum)
45+
end
46+
47+
function pyconvert_rule_sympy_equality(::Type{Symbolics.Equation}, x::Py)
48+
if !pyisinstance(x,sp.Equality)
49+
return PythonCall.pyconvert_unconverted()
50+
end
51+
rhs = pyconvert(Symbolics.Num,x.rhs)
52+
lhs = pyconvert(Symbolics.Num,x.lhs)
53+
return PythonCall.pyconvert_return(rhs ~ lhs)
54+
end
55+
56+
function pyconvert_rule_sympy_derivative(::Type{Symbolics.Num}, x::Py)
57+
if !pyisinstance(x,sp.Derivative)
58+
return PythonCall.pyconvert_unconverted()
59+
end
60+
variables = pyconvert.(Symbolics.Num,x.variables)
61+
derivatives = prod(var -> Differential(var), variables)
62+
expr = pyconvert(Symbolics.Num, x.expr)
63+
return PythonCall.pyconvert_return(derivatives(expr))
64+
end
65+
66+
function pyconvert_rule_sympy_function(::Type{Symbolics.Num}, x::Py)
67+
if !pyisinstance(x,sp.Function)
68+
return PythonCall.pyconvert_unconverted()
69+
end
70+
name = pyconvert(Symbol,x.name)
71+
args = pyconvert.(Symbolics.Num,x.args)
72+
func = @variables $name(..)
73+
return PythonCall.pyconvert_return(first(func)(args...))
74+
end
75+
76+
# added rules
77+
PythonCall.pyconvert_add_rule("sympy.core.power:Pow", Symbolics.Num, pyconvert_rule_sympy_pow)
78+
79+
PythonCall.pyconvert_add_rule("sympy.core.symbol:Symbol", Symbolics.Num, pyconvert_rule_sympy_symbol)
80+
81+
PythonCall.pyconvert_add_rule("sympy.core.mul:Mul", Symbolics.Num, pyconvert_rule_sympy_mul)
82+
83+
PythonCall.pyconvert_add_rule("sympy.core.add:Add", Symbolics.Num, pyconvert_rule_sympy_add)
84+
85+
PythonCall.pyconvert_add_rule("sympy.core.relational:Equality", Symbolics.Equation, pyconvert_rule_sympy_equality)
86+
87+
PythonCall.pyconvert_add_rule("sympy.core.function:Derivative",Symbolics.Num, pyconvert_rule_sympy_derivative)
88+
89+
PythonCall.pyconvert_add_rule("sympy.core.function:Function",Symbolics.Num, pyconvert_rule_sympy_function)
90+
91+
92+
93+
end
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
module SymPyPythonCallTermInterfaceExt
2+
3+
import SymPyPythonCall
4+
import TermInterface
5+
6+
#==
7+
Check if x represents an expression tree. If returns true, it will be assumed that operation(::T) and arguments(::T) methods are defined. Definining these three should allow use of SymbolicUtils.simplify on custom types. Optionally symtype(x) can be defined to return the expected type of the symbolic expression.
8+
==#
9+
function TermInterface.istree(x::SymPyPythonCall.SymbolicObject)
10+
!(convert(Bool, x.is_Atom))
11+
end
12+
13+
#==
14+
f x is a term as defined by istree(x), exprhead(x) must return a symbol, corresponding to the head of the Expr most similar to the term x. If x represents a function call, for example, the exprhead is :call. If x represents an indexing operation, such as arr[i], then exprhead is :ref. Note that exprhead is different from operation and both functions should be defined correctly in order to let other packages provide code generation and pattern matching features.
15+
==#
16+
function TermInterface.exprhead(x::SymPyPythonCall.SymbolicObject)
17+
:call # this is not right
18+
end
19+
20+
#==
21+
Returns the head (a function object) performed by an expression tree. Called only if istree(::T) is true. Part of the API required for simplify to work. Other required methods are arguments and istree
22+
==#
23+
function TermInterface.operation(x::SymPyPythonCall.SymbolicObject)
24+
@assert TermInterface.istree(x)
25+
nm = Symbol(SymPyPythonCall.Introspection.funcname(x))
26+
27+
λ = get(SymPyPythonCall.Introspection.funcname2function, nm, nothing)
28+
if isnothing(λ)
29+
return getfield(Main, nm)
30+
else
31+
return λ
32+
end
33+
end
34+
35+
36+
#==
37+
Returns the arguments (a Vector) for an expression tree. Called only if istree(x) is true. Part of the API required for simplify to work. Other required methods are operation and istree
38+
==#
39+
function TermInterface.arguments(x::SymPyPythonCall.SymbolicObject)
40+
collect(SymPyPythonCall.Introspection.args(x))
41+
end
42+
43+
#==
44+
Construct a new term with the operation f and arguments args, the term should be similar to t in type. if t is a SymbolicUtils.Term object a new Term is created with the same symtype as t. If not, the result is computed as f(args...). Defining this method for your term type will reduce any performance loss in performing f(args...) (esp. the splatting, and redundant type computation). T is the symtype of the output term. You can use SymbolicUtils.promote_symtype to infer this type. The exprhead keyword argument is useful when creating Exprs.
45+
==#
46+
function TermInterface.similarterm(t::SymPyPythonCall.SymbolicObject, f, args, symtype=nothing;
47+
metadata=nothing, exprhead=TermInterface.exprhead(t))
48+
f(args...) # default
49+
end
50+
51+
52+
end

src/introspection.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,36 @@ classname(x::T) where {T <: Union{Sym, Py}} = (cls = class(x); isnothing(cls) ?
5151
# Dict(u=>v for (u,v) in inspect.getmembers(x))
5252
#end
5353

54+
## Map to get function object from type information
55+
const funcname2function = (
56+
Add = +,
57+
Sub = -,
58+
Mul = *,
59+
Div = /,
60+
Pow = ^,
61+
re = real,
62+
im = imag,
63+
Abs = abs,
64+
Min = min,
65+
Max = max,
66+
Poly = identity,
67+
Piecewise = error, # replace
68+
Order = (as...) -> 0,
69+
And = (as...) -> all(as),
70+
Or = (as...) -> any(as),
71+
Less = <,
72+
LessThan = <=,
73+
StrictLessThan = <,
74+
Equal = ==,
75+
Equality = ==,
76+
Unequality = !==,
77+
StrictGreaterThan = >,
78+
GreaterThan = >=,
79+
Greater = >,
80+
conjugate = conj,
81+
atan2 = atan,
82+
TupleArg = tuple,
83+
Heaviside = (a...) -> (a[1] < 0 ? 0 : (a[1] > 0 ? 1 : (length(a) > 1 ? a[2] : NaN))),
84+
)
85+
5486
end

0 commit comments

Comments
 (0)