Skip to content

Commit c5f5d47

Browse files
committed
Use destination type to determine output of cumsum! and cumprod!
1 parent 512fbcd commit c5f5d47

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

base/accumulate.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -416,15 +416,33 @@ function _accumulate_pairwise_small!(op, dest::AbstractArray{T}, itr, accv, w, i
416416
end
417417
end
418418

419+
"""
420+
Base.ConvertOp{T}(op)(x,y)
421+
422+
An operator which converts `x` and `y` to type `T` before performing the `op`.
423+
424+
The main purpose is for use in [`cumsum!`](@ref) and [`cumprod!`](@ref), where `T` is determined by the output array.
425+
"""
426+
struct ConvertOp{T,O} <: Function
427+
op::O
428+
end
429+
ConvertOp{T}(op::O) where {T,O} = ConvertOp{T,O}(op)
430+
(c::ConvertOp{T})(x,y) where {T} = c.op(convert(T,x),convert(T,y))
431+
reduce_first(c::ConvertOp{T},x) where {T} = reduce_first(c.op, convert(T,x))
432+
433+
419434

420435

421-
function cumsum!(out, v::AbstractVector{T}) where T
436+
function cumsum!(out::AbstractVector, v::AbstractVector{T}) where T
422437
# we dispatch on the possibility of numerical accuracy issues
423438
cumsum!(out, v, ArithmeticStyle(T))
424439
end
425-
cumsum!(out, v::AbstractVector, ::ArithmeticRounds) = accumulate_pairwise!(+, out, v)
426-
cumsum!(out, v::AbstractVector, ::ArithmeticUnknown) = accumulate_pairwise!(+, out, v)
427-
cumsum!(out, v::AbstractVector, ::ArithmeticStyle) = accumulate!(+, out, v)
440+
cumsum!(out::AbstractVector{T}, v::AbstractVector, ::ArithmeticRounds) where {T} =
441+
accumulate_pairwise!(ConvertOp{T}(+), out, v)
442+
cumsum!(out::AbstractVector{T}, v::AbstractVector, ::ArithmeticUnknown) where {T} =
443+
accumulate_pairwise!(ConvertOp{T}(+), out, v)
444+
cumsum!(out::AbstractVector{T}, v::AbstractVector, ::ArithmeticStyle) where {T} =
445+
accumulate!(ConvertOp{T}(+), out, v)
428446

429447
"""
430448
cumsum(A, dim::Integer)
@@ -488,14 +506,14 @@ cumsum(v::AbstractVector, ::ArithmeticStyle) = accumulate(add_sum, v)
488506
489507
Cumulative sum of `A` along the dimension `dim`, storing the result in `B`. See also [`cumsum`](@ref).
490508
"""
491-
cumsum!(dest, A, dim::Integer) = accumulate!(+, dest, A, dim)
509+
cumsum!(dest::AbstractArray{T}, A, dim::Integer) where {T} = accumulate!(ConvertOp{T}(+), dest, A, dim)
492510

493511
"""
494512
cumsum!(y::AbstractVector, x::AbstractVector)
495513
496514
Cumulative sum of a vector `x`, storing the result in `y`. See also [`cumsum`](@ref).
497515
"""
498-
cumsum!(dest, itr) = accumulate!(+, dest, src)
516+
cumsum!(dest::AbstractArray{T}, itr) where {T} = accumulate!(ConvertOp{T}(+), dest, src)
499517

500518
"""
501519
cumprod(A, dim::Integer)
@@ -555,12 +573,12 @@ cumprod(x::AbstractVector) = accumulate(mul_prod, x)
555573
Cumulative product of `A` along the dimension `dim`, storing the result in `B`.
556574
See also [`cumprod`](@ref).
557575
"""
558-
cumprod!(dest, A, dim::Integer) = accumulate!(*, dest, A, dim)
576+
cumprod!(dest::AbstractArray{T}, A, dim::Integer) where {T} = accumulate!(ConvertOp{T}(*), dest, A, dim)
559577

560578
"""
561579
cumprod!(y::AbstractVector, x::AbstractVector)
562580
563581
Cumulative product of a vector `x`, storing the result in `y`.
564582
See also [`cumprod`](@ref).
565583
"""
566-
cumprod!(dest, itr) = accumulate!(*, dest, itr)
584+
cumprod!(dest::AbstractArray{T}, itr) where {T} = accumulate!(ConvertOp{T}(*), dest, itr)

base/reduce.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ add_sum(x::SmallSigned) = Int(x)
2828
add_sum(x::SmallUnsigned) = UInt(x)
2929
add_sum(X::AbstractArray) = broadcast(add_sum, X)
3030

31+
3132
"""
3233
Base.mul_prod(x,y)
3334

0 commit comments

Comments
 (0)