Skip to content

Handle tuples in build_function.jl #1462

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 71 additions & 1 deletion src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,72 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp, SymbolicUt
end
end

function _build_function(target::JuliaTarget, op::Tuple, args...;
conv = toexpr,
expression = Val{true},
expression_module = @__MODULE__(),
checkbounds = false,
states = LazyState(),
linenumbers = true,
wrap_code = nothing,
cse = false,
nanmath = true,
kwargs...)
#work around for issue 1438
array_op = collect(op)
array_fun = _build_function(target, array_op, args...;
conv=conv,
expression=Val{true},
expression_module=expression_module,
checkbounds=checkbounds,
states=states,
linenumbers=linenumbers,
wrap_code=wrap_code,
cse=cse,
nanmath=nanmath,
kwargs...)
if array_fun isa Tuple
oop_expr, iip_expr = array_fun

@assert Meta.isexpr(oop_expr, :function)
oop_body = oop_expr.args[2]
oop_expr.args[2] = quote
res = $oop_body
tuple(res...)
end

@assert Meta.isexpr(iip_expr, :function)
iip_body = iip_expr.args[2]
iip_expr.args[2] = quote
res = $iip_body
for i in 1:length(res)
$outsym[i] = res[i]
end
tuple(res...)
end

if expression == Val{true}
return (oop_expr, iip_expr)
else
return (_build_and_inject_function(expression_module, oop_expr),
_build_and_inject_function(expression_module, iip_expr))
end
else
@assert Meta.isexpr(array_fun, :function)
body = array_fun.args[2]
array_fun.args[2] = quote
res = $body
tuple(res...)
end

if expression == Val{true}
array_fun
else
_build_and_inject_function(expression_module, array_fun)
end
end
end

function _build_and_inject_function(mod::Module, ex)
if ex.head == :function && ex.args[1].head == :tuple
ex.args[1] = Expr(:call, :($mod.$(gensym())), ex.args[1].args...)
Expand Down Expand Up @@ -217,10 +283,12 @@ end
outputidxs=nothing,
skipzeros = false,
force_SA = false,
similarto = nothing,
wrap_code = (nothing, nothing),
fillzeros = skipzeros && !(rhss isa SparseMatrixCSC),
states = LazyState(),
iip_config = (true, true),
nanmath = true,
parallel=nothing, cse = false, kwargs...)

Build function target: `JuliaTarget`
Expand Down Expand Up @@ -724,6 +792,7 @@ function _build_function(target::CTarget, eqs::Array{<:Equation}, args...;
(i, eq) ∈ enumerate(eqs)],";\n "),";")

argstrs = join(vcat("double* $(lhsname)",[typeof(args[i])<:AbstractArray ? "double* $(rhsnames[i])" : "double $(rhsnames[i])" for i in 1:length(args)]),", ")

ex = """
void $fname($(argstrs...)) {
$differential_equation
Expand Down Expand Up @@ -788,6 +857,7 @@ function _build_function(target::CTarget, ex::AbstractArray, args...;
return _build_function(target, hcat([row for row ∈ eachrow(ex)]...), args...;
columnmajor = true,
conv = conv,
expression = expression,
fname = fname,
lhsname = lhsname,
rhsnames = rhsnames,
Expand All @@ -804,7 +874,7 @@ function _build_function(target::CTarget, ex::AbstractArray, args...;
rhs = numbered_expr(value(ex[row, col]), varnumbercache, args...;
lhsname = lhsname,
rhsnames = rhsnames,
offset = -1) |> coperators |> string # Filter through coperators to produce valid C code in more cases
offset = 0) |> coperators |> string # Filter through coperators to produce valid C code in more cases
push!(equations, string(lhs, " = ", rhs, ";"))
end
end
Expand Down
Loading