Skip to content

work on extensions #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SymPyPythonCall"
uuid = "bc8888f7-b21e-4b7c-a06a-5d9c9496438c"
authors = ["jverzani <jverzani@gmail.com> and contributors"]
version = "0.1.0"
version = "0.1.1"

[deps]
CommonEq = "3709ef60-1bee-4518-9f2f-acd86f176c50"
Expand All @@ -14,6 +14,14 @@ PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[weakdeps]
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[extensions]
SymPyPythonCallTermInterfaceExt = "TermInterface"
SymPyPythonCallSymbolicsExt = "Symbolics"

[compat]
julia = "1.6.1"
CommonEq = "0.2"
Expand All @@ -26,6 +34,8 @@ SpecialFunctions = "0.8, 0.9, 0.10, 1.0, 2"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[targets]
test = ["Test"]
93 changes: 93 additions & 0 deletions ext/SymPyPythonCallSymbolicsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
module SymPyPythonCallSymbolicsExt

# from https://github.com/JuliaSymbolics/Symbolics.jl/pull/957/
# by @jClugstor
using SymPyPythonCall
const PythonCall = SymPyPythonCall.PythonCall
import SymPyPythonCall.PythonCall: Py, pyisinstance, pyconvert

import Symbolics
import Symbolics: @variables
const sp = SymPyPythonCall.sympy

# rule functions
function pyconvert_rule_sympy_symbol(::Type{Symbolics.Num}, x::Py)
if !pyisinstance(x,sp.Symbol)
return PythonCall.pyconvert_unconverted()
end
name = PythonCall.pyconvert(Symbol,x.name)
return PythonCall.pyconvert_return(Symbolics.variable(name))
end

function pyconvert_rule_sympy_pow(::Type{Symbolics.Num}, x::Py)
if !pyisinstance(x,sp.Pow)
return PythonCall.pyconvert_unconverted()
end
expbase = pyconvert(Symbolics.Num,x.base)
exp = pyconvert(Symbolics.Num,x.exp)
return PythonCall.pyconvert_return(expbase^exp)
end

function pyconvert_rule_sympy_mul(::Type{Symbolics.Num}, x::Py)
if !pyisinstance(x,sp.Mul)
return PythonCall.pyconvert_unconverted()
end
mult = reduce(*,PythonCall.pyconvert.(Symbolics.Num,x.args))
return PythonCall.pyconvert_return(mult)
end

function pyconvert_rule_sympy_add(::Type{Symbolics.Num}, x::Py)
if !pyisinstance(x,sp.Add)
return PythonCall.pyconvert_unconverted()
end
sum = reduce(+, PythonCall.pyconvert.(Symbolics.Num,x.args))
return PythonCall.pyconvert_return(sum)
end

function pyconvert_rule_sympy_equality(::Type{Symbolics.Equation}, x::Py)
if !pyisinstance(x,sp.Equality)
return PythonCall.pyconvert_unconverted()
end
rhs = pyconvert(Symbolics.Num,x.rhs)
lhs = pyconvert(Symbolics.Num,x.lhs)
return PythonCall.pyconvert_return(rhs ~ lhs)
end

function pyconvert_rule_sympy_derivative(::Type{Symbolics.Num}, x::Py)
if !pyisinstance(x,sp.Derivative)
return PythonCall.pyconvert_unconverted()
end
variables = pyconvert.(Symbolics.Num,x.variables)
derivatives = prod(var -> Differential(var), variables)
expr = pyconvert(Symbolics.Num, x.expr)
return PythonCall.pyconvert_return(derivatives(expr))
end

function pyconvert_rule_sympy_function(::Type{Symbolics.Num}, x::Py)
if !pyisinstance(x,sp.Function)
return PythonCall.pyconvert_unconverted()
end
name = pyconvert(Symbol,x.name)
args = pyconvert.(Symbolics.Num,x.args)
func = @variables $name(..)
return PythonCall.pyconvert_return(first(func)(args...))
end

# added rules
PythonCall.pyconvert_add_rule("sympy.core.power:Pow", Symbolics.Num, pyconvert_rule_sympy_pow)

PythonCall.pyconvert_add_rule("sympy.core.symbol:Symbol", Symbolics.Num, pyconvert_rule_sympy_symbol)

PythonCall.pyconvert_add_rule("sympy.core.mul:Mul", Symbolics.Num, pyconvert_rule_sympy_mul)

PythonCall.pyconvert_add_rule("sympy.core.add:Add", Symbolics.Num, pyconvert_rule_sympy_add)

PythonCall.pyconvert_add_rule("sympy.core.relational:Equality", Symbolics.Equation, pyconvert_rule_sympy_equality)

PythonCall.pyconvert_add_rule("sympy.core.function:Derivative",Symbolics.Num, pyconvert_rule_sympy_derivative)

PythonCall.pyconvert_add_rule("sympy.core.function:Function",Symbolics.Num, pyconvert_rule_sympy_function)



end
52 changes: 52 additions & 0 deletions ext/SymPyPythonCallTermInterfaceExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
module SymPyPythonCallTermInterfaceExt

import SymPyPythonCall
import TermInterface

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TermInterface was effectively deprecated. Doing this directly on the terms in SymbolicUtils would be more helpful.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for that detail!


#==
Check if x represents an expression tree. If returns true, it will be assumed that operation(::T) and arguments(::T) methods are defined. Definining these three should allow use of SymbolicUtils.simplify on custom types. Optionally symtype(x) can be defined to return the expected type of the symbolic expression.
==#
function TermInterface.istree(x::SymPyPythonCall.SymbolicObject)
!(convert(Bool, x.is_Atom))
end

#==
f x is a term as defined by istree(x), exprhead(x) must return a symbol, corresponding to the head of the Expr most similar to the term x. If x represents a function call, for example, the exprhead is :call. If x represents an indexing operation, such as arr[i], then exprhead is :ref. Note that exprhead is different from operation and both functions should be defined correctly in order to let other packages provide code generation and pattern matching features.
==#
function TermInterface.exprhead(x::SymPyPythonCall.SymbolicObject)
:call # this is not right
end

#==
Returns the head (a function object) performed by an expression tree. Called only if istree(::T) is true. Part of the API required for simplify to work. Other required methods are arguments and istree
==#
function TermInterface.operation(x::SymPyPythonCall.SymbolicObject)
@assert TermInterface.istree(x)
nm = Symbol(SymPyPythonCall.Introspection.funcname(x))

λ = get(SymPyPythonCall.Introspection.funcname2function, nm, nothing)
if isnothing(λ)
return getfield(Main, nm)
else
return λ
end
end


#==
Returns the arguments (a Vector) for an expression tree. Called only if istree(x) is true. Part of the API required for simplify to work. Other required methods are operation and istree
==#
function TermInterface.arguments(x::SymPyPythonCall.SymbolicObject)
collect(SymPyPythonCall.Introspection.args(x))
end

#==
Construct a new term with the operation f and arguments args, the term should be similar to t in type. if t is a SymbolicUtils.Term object a new Term is created with the same symtype as t. If not, the result is computed as f(args...). Defining this method for your term type will reduce any performance loss in performing f(args...) (esp. the splatting, and redundant type computation). T is the symtype of the output term. You can use SymbolicUtils.promote_symtype to infer this type. The exprhead keyword argument is useful when creating Exprs.
==#
function TermInterface.similarterm(t::SymPyPythonCall.SymbolicObject, f, args, symtype=nothing;
metadata=nothing, exprhead=TermInterface.exprhead(t))
f(args...) # default
end


end
32 changes: 32 additions & 0 deletions src/introspection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,36 @@ classname(x::T) where {T <: Union{Sym, Py}} = (cls = class(x); isnothing(cls) ?
# Dict(u=>v for (u,v) in inspect.getmembers(x))
#end

## Map to get function object from type information
const funcname2function = (
Add = +,
Sub = -,
Mul = *,
Div = /,
Pow = ^,
re = real,
im = imag,
Abs = abs,
Min = min,
Max = max,
Poly = identity,
Piecewise = error, # replace
Order = (as...) -> 0,
And = (as...) -> all(as),
Or = (as...) -> any(as),
Less = <,
LessThan = <=,
StrictLessThan = <,
Equal = ==,
Equality = ==,
Unequality = !==,
StrictGreaterThan = >,
GreaterThan = >=,
Greater = >,
conjugate = conj,
atan2 = atan,
TupleArg = tuple,
Heaviside = (a...) -> (a[1] < 0 ? 0 : (a[1] > 0 ? 1 : (length(a) > 1 ? a[2] : NaN))),
)

end