Skip to content

Commit d08d164

Browse files
authored
Merge pull request #631 from JuliaSymbolics/fix-array-function-type-inference
Fix `maketerm` handling of `BasicSymbolic{Array}`
2 parents 77951c9 + 8b52a80 commit d08d164

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

src/types.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -563,17 +563,18 @@ function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata)
563563
# Where the result would have a symtype of Bool.
564564
# Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609
565565
# TODO this should be optimized.
566-
new_st = if pst === Bool
567-
pst
568-
elseif pst === Any || (st === Number && pst <: st)
566+
new_st = if st <: AbstractArray
569567
st
570-
else
571-
pst
572-
end
568+
elseif pst === Bool
569+
pst
570+
elseif pst === Any || (st === Number && pst <: st)
571+
st
572+
else
573+
pst
574+
end
573575
basicsymbolic(head, args, new_st, metadata)
574576
end
575577

576-
577578
function basicsymbolic(f, args, stype, metadata)
578579
if f isa Symbol
579580
error("$f must not be a Symbol")

test/basics.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,11 @@ end
249249
@syms x::Int y::Int
250250
new_expr = SymbolicUtils.maketerm(typeof(ref_expr), (+), [x, y], nothing)
251251
@test symtype(new_expr) == Int64
252+
253+
# Check that the Array type does not get changed to AbstractArray
254+
new_expr = SymbolicUtils.maketerm(
255+
SymbolicUtils.BasicSymbolic{Vector{Float64}}, sin, [1.0, 2.0], nothing)
256+
@test symtype(new_expr) == Vector{Float64}
252257
end
253258

254259
toterm(t) = Term{symtype(t)}(operation(t), arguments(t))

0 commit comments

Comments
 (0)