Skip to content

Commit 077f26a

Browse files
committed
allocate once for mean(x; dims)
1 parent e457a00 commit 077f26a

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/host/statistics.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,17 @@ Statistics._std(A::AbstractGPUArray, corrected::Bool, mean, ::Colon) =
2525
# Revert https://github.com/JuliaLang/Statistics.jl/pull/25
2626
Statistics._mean(A::AbstractGPUArray, ::Colon) = sum(A) / length(A)
2727
Statistics._mean(f, A::AbstractGPUArray, ::Colon) = sum(f, A) / length(A)
28-
Statistics._mean(A::AbstractGPUArray, dims) = mean!(Base.reducedim_init(t -> t/2, +, A, dims), A)
29-
Statistics._mean(f, A::AbstractGPUArray, dims) = sum(f, A, dims=dims) / mapreduce(i -> size(A, i), *, unique(dims); init=1)
28+
29+
function Statistics._mean(A::AbstractGPUArray, dims)
30+
T = float(eltype(A))
31+
λ = convert(T, inv(_mean_denom(A, dims)))
32+
sum(Base.Fix1(*,λ), A; dims)
33+
end
34+
function Statistics._mean(f, A::AbstractGPUArray, dims)
35+
T = float(eltype(A))
36+
λ = convert(T, inv(_mean_denom(A, dims)))
37+
sum(Base.Fix1(*,λ) f, A; dims)
38+
end
3039

3140
function Statistics.covzm(x::AbstractGPUMatrix, vardim::Int=1; corrected::Bool=true)
3241
C = Statistics.unscaled_covzm(x, vardim)

0 commit comments

Comments
 (0)