@@ -416,15 +416,33 @@ function _accumulate_pairwise_small!(op, dest::AbstractArray{T}, itr, accv, w, i
416
416
end
417
417
end
418
418
419
+ """
420
+ Base.ConvertOp{T}(op)(x,y)
421
+
422
+ An operator which converts `x` and `y` to type `T` before performing the `op`.
423
+
424
+ The main purpose is for use in [`cumsum!`](@ref) and [`cumprod!`](@ref), where `T` is determined by the output array.
425
+ """
426
+ struct ConvertOp{T,O} <: Function
427
+ op:: O
428
+ end
429
+ ConvertOp {T} (op:: O ) where {T,O} = ConvertOp {T,O} (op)
430
+ (c:: ConvertOp{T} )(x,y) where {T} = c. op (convert (T,x),convert (T,y))
431
+ reduce_first (c:: ConvertOp{T} ,x) where {T} = reduce_first (c. op, convert (T,x))
432
+
433
+
419
434
420
435
421
- function cumsum! (out, v:: AbstractVector{T} ) where T
436
+ function cumsum! (out:: AbstractVector , v:: AbstractVector{T} ) where T
422
437
# we dispatch on the possibility of numerical accuracy issues
423
438
cumsum! (out, v, ArithmeticStyle (T))
424
439
end
425
- cumsum! (out, v:: AbstractVector , :: ArithmeticRounds ) = accumulate_pairwise! (+ , out, v)
426
- cumsum! (out, v:: AbstractVector , :: ArithmeticUnknown ) = accumulate_pairwise! (+ , out, v)
427
- cumsum! (out, v:: AbstractVector , :: ArithmeticStyle ) = accumulate! (+ , out, v)
440
+ cumsum! (out:: AbstractVector{T} , v:: AbstractVector , :: ArithmeticRounds ) where {T} =
441
+ accumulate_pairwise! (ConvertOp {T} (+ ), out, v)
442
+ cumsum! (out:: AbstractVector{T} , v:: AbstractVector , :: ArithmeticUnknown ) where {T} =
443
+ accumulate_pairwise! (ConvertOp {T} (+ ), out, v)
444
+ cumsum! (out:: AbstractVector{T} , v:: AbstractVector , :: ArithmeticStyle ) where {T} =
445
+ accumulate! (ConvertOp {T} (+ ), out, v)
428
446
429
447
"""
430
448
cumsum(A, dim::Integer)
@@ -488,14 +506,14 @@ cumsum(v::AbstractVector, ::ArithmeticStyle) = accumulate(add_sum, v)
488
506
489
507
Cumulative sum of `A` along the dimension `dim`, storing the result in `B`. See also [`cumsum`](@ref).
490
508
"""
491
- cumsum! (dest, A, dim:: Integer ) = accumulate! (+ , dest, A, dim)
509
+ cumsum! (dest:: AbstractArray{T} , A, dim:: Integer ) where {T} = accumulate! (ConvertOp {T} ( + ) , dest, A, dim)
492
510
493
511
"""
494
512
cumsum!(y::AbstractVector, x::AbstractVector)
495
513
496
514
Cumulative sum of a vector `x`, storing the result in `y`. See also [`cumsum`](@ref).
497
515
"""
498
- cumsum! (dest, itr) = accumulate! (+ , dest, src)
516
+ cumsum! (dest:: AbstractArray{T} , itr) where {T} = accumulate! (ConvertOp {T} ( + ) , dest, src)
499
517
500
518
"""
501
519
cumprod(A, dim::Integer)
@@ -555,12 +573,12 @@ cumprod(x::AbstractVector) = accumulate(mul_prod, x)
555
573
Cumulative product of `A` along the dimension `dim`, storing the result in `B`.
556
574
See also [`cumprod`](@ref).
557
575
"""
558
- cumprod! (dest, A, dim:: Integer ) = accumulate! (* , dest, A, dim)
576
+ cumprod! (dest:: AbstractArray{T} , A, dim:: Integer ) where {T} = accumulate! (ConvertOp {T} ( * ) , dest, A, dim)
559
577
560
578
"""
561
579
cumprod!(y::AbstractVector, x::AbstractVector)
562
580
563
581
Cumulative product of a vector `x`, storing the result in `y`.
564
582
See also [`cumprod`](@ref).
565
583
"""
566
- cumprod! (dest, itr) = accumulate! (* , dest, itr)
584
+ cumprod! (dest:: AbstractArray{T} , itr) where {T} = accumulate! (ConvertOp {T} ( * ) , dest, itr)
0 commit comments