@@ -33,7 +33,7 @@ function accumulate_pairwise!(op::Op, result::AbstractVector, v::AbstractVector)
33
33
end
34
34
35
35
function accumulate_pairwise (op, v:: AbstractVector{T} ) where T
36
- out = similar (v, promote_op (op, T, T ))
36
+ out = similar (v, _accumulate_promote_op (op, v ))
37
37
return accumulate_pairwise! (op, out, v)
38
38
end
39
39
@@ -111,8 +111,8 @@ julia> cumsum(a, dims=2)
111
111
widening happens and integer overflow results in `Int8[100, -128]`.
112
112
"""
113
113
function cumsum (A:: AbstractArray{T} ; dims:: Integer ) where T
114
- out = similar (A, promote_op (add_sum, T, T ))
115
- cumsum! (out, A, dims= dims)
114
+ out = similar (A, _accumulate_promote_op (add_sum, A ))
115
+ return cumsum! (out, A, dims= dims)
116
116
end
117
117
118
118
"""
@@ -280,14 +280,13 @@ function accumulate(op, A; dims::Union{Nothing,Integer}=nothing, kw...)
280
280
# This branch takes care of the cases not handled by `_accumulate!`.
281
281
return collect (Iterators. accumulate (op, A; kw... ))
282
282
end
283
+
283
284
nt = values (kw)
284
- if isempty (kw)
285
- out = similar (A, promote_op (op, eltype (A), eltype (A)))
286
- elseif keys (nt) === (:init ,)
287
- out = similar (A, promote_op (op, typeof (nt. init), eltype (A)))
288
- else
285
+ if ! (isempty (kw) || keys (nt) === (:init ,))
289
286
throw (ArgumentError (" accumulate does not support the keyword arguments $(setdiff (keys (nt), (:init ,))) " ))
290
287
end
288
+
289
+ out = similar (A, _accumulate_promote_op (op, A; kw... ))
291
290
accumulate! (op, out, A; dims= dims, kw... )
292
291
end
293
292
@@ -442,3 +441,42 @@ function _accumulate1!(op, B, v1, A::AbstractVector, dim::Integer)
442
441
end
443
442
return B
444
443
end
444
+
445
+ # Internal function used to identify the widest possible eltype required for accumulate results
446
+ function _accumulate_promote_op (op, v; init= nothing )
447
+ # Nested mock functions used to infer the widest necessary eltype
448
+ # NOTE: We are just passing this to promote_op for inference and should never be run.
449
+
450
+ # Initialization function used to identify initial type of `r`
451
+ # NOTE: reduce_first may have a different return type than calling `op`
452
+ function f (op, v, init)
453
+ val = first (something (iterate (v)))
454
+ return isnothing (init) ? Base. reduce_first (op, val) : op (init, val)
455
+ end
456
+
457
+ # Infer iteration type independent of the initialization type
458
+ # If `op` fails then this will return `Union{}` as `k` will be undefined.
459
+ # Returning `Union{}` is desirable as it won't break the `promote_type` call in the
460
+ # outer scope below
461
+ function g (op, v, r)
462
+ local k
463
+ for val in v
464
+ k = op (r, val)
465
+ end
466
+ return k
467
+ end
468
+
469
+ # Finally loop again with the two types promoted together
470
+ # If the `op` fails and reduce_first was used then then this will still just
471
+ # return the initial type, allowing the `op` to error during execution.
472
+ function h (op, v, r)
473
+ for val in v
474
+ r = op (r, val)
475
+ end
476
+ return r
477
+ end
478
+
479
+ R = Base. promote_op (f, typeof (op), typeof (v), typeof (init))
480
+ K = Base. promote_op (g, typeof (op), typeof (v), R)
481
+ return Base. promote_op (h, typeof (op), typeof (v), Base. promote_type (R, K))
482
+ end
0 commit comments