Skip to content

Commit e457a00

Browse files
committed
use lazy broadcasting for var
1 parent bf72dc8 commit e457a00

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/host/statistics.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ using Statistics
22

33
function Statistics.varm(A::AbstractGPUArray{<:Real}, M::AbstractArray{<:Real};
44
dims, corrected::Bool=true)
5+
T = float(eltype(A))
6+
λ = convert(T, inv(_mean_denom(A, dims) - corrected))
57
#B = (A .- M).^2
68
# NOTE: the above broadcast promotes to Float64 and uses power_by_squaring...
7-
B = broadcast(A, M) do a, m
9+
B = Broadcast.broadcasted(A, M) do a, m
810
x = (a - m)
9-
x*x
11+
λ * x * x
1012
end
11-
sum(B, dims=dims)/(prod(size(A)[[dims...]])::Int-corrected)
13+
sum(Broadcast.instantiate(B); dims)
1214
end
1315

1416
Statistics.stdm(A::AbstractGPUArray{<:Real},m::AbstractArray{<:Real}, dim::Int; corrected::Bool=true) =
@@ -49,3 +51,7 @@ function Statistics.corzm(x::AbstractGPUMatrix, vardim::Int=1)
4951
c = Statistics.unscaled_covzm(x, vardim)
5052
return Statistics.cov2cor!(c, sqrt.(diag(c)))
5153
end
54+
55+
_mean_denom(x::AbstractArray, dims::Integer) = size(x, dims)
56+
_mean_denom(x::AbstractArray, dims::Colon) = length(x)
57+
_mean_denom(x::AbstractArray, dims) = prod(size(x,d) for d in unique(dims); init=1)

0 commit comments

Comments
 (0)