diff --git a/src/array-lib.jl b/src/array-lib.jl index 51c514ea4..2f606e7d8 100644 --- a/src/array-lib.jl +++ b/src/array-lib.jl @@ -290,21 +290,45 @@ end @wrapped Base.map(f, x, y, z::AbstractArray, w...) = _map(f, x, y, z, w...) function _map(f, x, xs...) + return ArrayOp( + SymbolicUtils._promote_symtype(_map, (x,xs)), + (idx...,), + expr, + +, + Term{Any}(map, [f, x, xs...]) + ) +end + +function SymbolicUtils._promote_symtype(::typeof(_map), args) + f, x, xs... = args + N = ndims(x) idx = makesubscripts(N) - + expr = f(map(a->a[idx...], [x, xs...])...) Atype = propagate_atype(map, f, x, xs...) - ArrayOp(Atype{symtype(expr), N}, - (idx...,), - expr, - +, - Term{Any}(map, [f, x, xs...])) + + return Atype{symtype(expr), N} end @inline _mapreduce(f, g, x, dims, kw) = mapreduce(f, g, x; dims=dims, kw...) +function SymbolicUtils._promote_symtype(::typeof(_mapreduce), args) + @assert length(args) == 5 + f, op, x, dims, kw = args + + N = ndims(x) + idx = makesubscripts(N) + expr = f(x[idx...]) + T = symtype(op(expr, expr)) + if dims === (:) + return T + end + Atype = propagate_atype(_mapreduce, f, op, x, dims, (kw...,)) + return Atype{T, N} +end + function scalarize_op(::typeof(_mapreduce), t) f,g,x,dims,kw = arguments(t) # we wrap and unwrap to make things work smoothly. @@ -313,20 +337,19 @@ function scalarize_op(::typeof(_mapreduce), t) end @wrapped function Base.mapreduce(f, g, x::AbstractArray; dims=:, kw...) - idx = makesubscripts(ndims(x)) - out_idx = [dims == (:) || i in dims ? 1 : idx[i] for i = 1:ndims(x)] - expr = f(x[idx...]) - T = symtype(g(expr, expr)) + Stype = SymbolicUtils._promote_symtype(_mapreduce, (f,g,x,dims,kw)) if dims === (:) - return Term{T}(_mapreduce, [f, g, x, dims, (kw...,)]) + return Term{Stype}(_mapreduce, [f, g, x, dims, (kw...,)]) end - - Atype = propagate_atype(_mapreduce, f, g, x, dims, (kw...,)) - ArrayOp(Atype{T, ndims(x)}, - (out_idx...,), - expr, - g, - Term{Any}(_mapreduce, [f, g, x, dims, (kw...,)])) + idx = makesubscripts(ndims(x)) + out_idx = [dims == (:) || i in dims ? 1 : idx[i] for i = 1:ndims(x)] + return ArrayOp( + Stype, + (out_idx...,), + expr, + g, + Term{Any}(_mapreduce, [f, g, x, dims, (kw...,)]) + ) end for (ff, opts) in [sum => (identity, +, false), diff --git a/test/arrays.jl b/test/arrays.jl index 5eef9b9d6..40e797669 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -1,6 +1,6 @@ using Symbolics using SymbolicUtils, Test -using Symbolics: symtype, shape, wrap, unwrap, Unknown, Arr, arrterm, jacobian, @variables, value, get_variables, @arrayop, getname, metadata, scalarize +using Symbolics: symtype, shape, wrap, unwrap, Unknown, Arr, arrterm, jacobian, @variables, value, get_variables, @arrayop, getname, metadata, scalarize, simplify using Base: Slice using SymbolicUtils: Sym, term, operation @@ -97,6 +97,14 @@ getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) # #417 @test isequal(Symbolics.scalarize(x', (1,1)), x[1]) + # #814 + @test isa(simplify(sum(b .* 1) ), Num) + @test isa(simplify(prod(x .+ 2.0) ), Num) + @test isa(simplify(mapreduce(x -> (x+1)/(0.1 + abs(x)), ^, u)), Num) # exponent(s) must not be Int until #455 is fixed + @test Symbolics.symtype(simplify(sum(b .* 1))) <: Real + @test Symbolics.symtype(simplify(prod(x .+ 2.0))) <: Real + @test Symbolics.symtype(simplify(mapreduce(x -> (x+1)/(0.1 + abs(x)), ^, u))) <: Real + # #483 # examples by @gronniger @variables A[1:2, 1:2]