From b1f2e5d8aaba6c924c8a439a25c2a2808876dc2b Mon Sep 17 00:00:00 2001 From: Shashank Kirtania Date: Sun, 2 Mar 2025 10:59:38 +0530 Subject: [PATCH] added workaround to handle tupples in build_function.jl --- src/build_function.jl | 72 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/src/build_function.jl b/src/build_function.jl index 0a6c152dc..892c0bbd3 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -183,6 +183,72 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp, SymbolicUt end end +function _build_function(target::JuliaTarget, op::Tuple, args...; + conv = toexpr, + expression = Val{true}, + expression_module = @__MODULE__(), + checkbounds = false, + states = LazyState(), + linenumbers = true, + wrap_code = nothing, + cse = false, + nanmath = true, + kwargs...) + #work around for issue 1438 + array_op = collect(op) + array_fun = _build_function(target, array_op, args...; + conv=conv, + expression=Val{true}, + expression_module=expression_module, + checkbounds=checkbounds, + states=states, + linenumbers=linenumbers, + wrap_code=wrap_code, + cse=cse, + nanmath=nanmath, + kwargs...) + if array_fun isa Tuple + oop_expr, iip_expr = array_fun + + @assert Meta.isexpr(oop_expr, :function) + oop_body = oop_expr.args[2] + oop_expr.args[2] = quote + res = $oop_body + tuple(res...) + end + + @assert Meta.isexpr(iip_expr, :function) + iip_body = iip_expr.args[2] + iip_expr.args[2] = quote + res = $iip_body + for i in 1:length(res) + $outsym[i] = res[i] + end + tuple(res...) + end + + if expression == Val{true} + return (oop_expr, iip_expr) + else + return (_build_and_inject_function(expression_module, oop_expr), + _build_and_inject_function(expression_module, iip_expr)) + end + else + @assert Meta.isexpr(array_fun, :function) + body = array_fun.args[2] + array_fun.args[2] = quote + res = $body + tuple(res...) + end + + if expression == Val{true} + array_fun + else + _build_and_inject_function(expression_module, array_fun) + end + end +end + function _build_and_inject_function(mod::Module, ex) if ex.head == :function && ex.args[1].head == :tuple ex.args[1] = Expr(:call, :($mod.$(gensym())), ex.args[1].args...) @@ -217,10 +283,12 @@ end outputidxs=nothing, skipzeros = false, force_SA = false, + similarto = nothing, wrap_code = (nothing, nothing), fillzeros = skipzeros && !(rhss isa SparseMatrixCSC), states = LazyState(), iip_config = (true, true), + nanmath = true, parallel=nothing, cse = false, kwargs...) Build function target: `JuliaTarget` @@ -724,6 +792,7 @@ function _build_function(target::CTarget, eqs::Array{<:Equation}, args...; (i, eq) ∈ enumerate(eqs)],";\n "),";") argstrs = join(vcat("double* $(lhsname)",[typeof(args[i])<:AbstractArray ? "double* $(rhsnames[i])" : "double $(rhsnames[i])" for i in 1:length(args)]),", ") + ex = """ void $fname($(argstrs...)) { $differential_equation @@ -788,6 +857,7 @@ function _build_function(target::CTarget, ex::AbstractArray, args...; return _build_function(target, hcat([row for row ∈ eachrow(ex)]...), args...; columnmajor = true, conv = conv, + expression = expression, fname = fname, lhsname = lhsname, rhsnames = rhsnames, @@ -804,7 +874,7 @@ function _build_function(target::CTarget, ex::AbstractArray, args...; rhs = numbered_expr(value(ex[row, col]), varnumbercache, args...; lhsname = lhsname, rhsnames = rhsnames, - offset = -1) |> coperators |> string # Filter through coperators to produce valid C code in more cases + offset = 0) |> coperators |> string # Filter through coperators to produce valid C code in more cases push!(equations, string(lhs, " = ", rhs, ";")) end end