Skip to content

Commit 829f433

Browse files
authored
Merge pull request #443 from mcabbott/variance
Use lazy broadcasting for `Statistics.var`
2 parents 58ef36a + 077f26a commit 829f433

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

src/host/statistics.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ using Statistics
22

33
function Statistics.varm(A::AbstractGPUArray{<:Real}, M::AbstractArray{<:Real};
44
dims, corrected::Bool=true)
5+
T = float(eltype(A))
6+
λ = convert(T, inv(_mean_denom(A, dims) - corrected))
57
#B = (A .- M).^2
68
# NOTE: the above broadcast promotes to Float64 and uses power_by_squaring...
7-
B = broadcast(A, M) do a, m
9+
B = Broadcast.broadcasted(A, M) do a, m
810
x = (a - m)
9-
x*x
11+
λ * x * x
1012
end
11-
sum(B, dims=dims)/(prod(size(A)[[dims...]])::Int-corrected)
13+
sum(Broadcast.instantiate(B); dims)
1214
end
1315

1416
Statistics.stdm(A::AbstractGPUArray{<:Real},m::AbstractArray{<:Real}, dim::Int; corrected::Bool=true) =
@@ -23,8 +25,17 @@ Statistics._std(A::AbstractGPUArray, corrected::Bool, mean, ::Colon) =
2325
# Revert https://github.com/JuliaLang/Statistics.jl/pull/25
2426
Statistics._mean(A::AbstractGPUArray, ::Colon) = sum(A) / length(A)
2527
Statistics._mean(f, A::AbstractGPUArray, ::Colon) = sum(f, A) / length(A)
26-
Statistics._mean(A::AbstractGPUArray, dims) = mean!(Base.reducedim_init(t -> t/2, +, A, dims), A)
27-
Statistics._mean(f, A::AbstractGPUArray, dims) = sum(f, A, dims=dims) / mapreduce(i -> size(A, i), *, unique(dims); init=1)
28+
29+
function Statistics._mean(A::AbstractGPUArray, dims)
30+
T = float(eltype(A))
31+
λ = convert(T, inv(_mean_denom(A, dims)))
32+
sum(Base.Fix1(*,λ), A; dims)
33+
end
34+
function Statistics._mean(f, A::AbstractGPUArray, dims)
35+
T = float(eltype(A))
36+
λ = convert(T, inv(_mean_denom(A, dims)))
37+
sum(Base.Fix1(*,λ) f, A; dims)
38+
end
2839

2940
function Statistics.covzm(x::AbstractGPUMatrix, vardim::Int=1; corrected::Bool=true)
3041
C = Statistics.unscaled_covzm(x, vardim)
@@ -49,3 +60,7 @@ function Statistics.corzm(x::AbstractGPUMatrix, vardim::Int=1)
4960
c = Statistics.unscaled_covzm(x, vardim)
5061
return Statistics.cov2cor!(c, sqrt.(diag(c)))
5162
end
63+
64+
_mean_denom(x::AbstractArray, dims::Integer) = size(x, dims)
65+
_mean_denom(x::AbstractArray, dims::Colon) = length(x)
66+
_mean_denom(x::AbstractArray, dims) = prod(size(x,d) for d in unique(dims); init=1)

0 commit comments

Comments
 (0)