Skip to content

Commit 4023ec0

Browse files
mcabbottoxinabox
andauthored
Rules for var, std (#560)
* rules for var, std * Apply 2 suggestions Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent cb816d0 commit 4023ec0

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

src/rulesets/Statistics/statistics.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,39 @@ function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:)
1818
end
1919
return y_sum / n, mean_pullback
2020
end
21+
22+
#####
23+
##### variance
24+
#####
25+
26+
function rrule(
27+
::typeof(Statistics.var),
28+
x::AbstractArray{<:Number};
29+
corrected::Bool=true,
30+
dims=:,
31+
mean=mean(x, dims=dims)
32+
)
33+
y = Statistics.var(x; corrected=corrected, mean=mean, dims=dims)
34+
function variance_pullback(dy)
35+
pre = 2 // (_denom(x, dims) - corrected)
36+
dx = pre .* unthunk(dy) .* (x .- mean)
37+
return (NoTangent(), ProjectTo(x)(dx))
38+
end
39+
y, variance_pullback
40+
end
41+
42+
function rrule(
43+
::typeof(Statistics.std),
44+
x::AbstractArray{<:Number};
45+
corrected::Bool=true,
46+
dims=:,
47+
mean=mean(x, dims=dims)
48+
)
49+
y = Statistics.std(x; corrected=corrected, mean=mean, dims=dims)
50+
function std_pullback(dy)
51+
pre = 1 // (_denom(x, dims) - corrected)
52+
dx = pre .* unthunk(dy) .* (x .- mean) ./ y
53+
return (NoTangent(), ProjectTo(x)(dx))
54+
end
55+
y, std_pullback
56+
end

test/rulesets/Statistics/statistics.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,18 @@
88
test_rrule(mean, randn(n,4); fkwargs=(;dims=2))
99
end
1010
end
11+
12+
@testset "variation: $var" for var in (std, var)
13+
test_rrule(var, randn(3))
14+
test_rrule(var, randn(4, 5); fkwargs=(; corrected=false))
15+
test_rrule(var, randn(ComplexF64, 6))
16+
test_rrule(var, Diagonal(randn(6)))
17+
18+
test_rrule(var, randn(4, 5); fkwargs=(; dims=1))
19+
test_rrule(var, randn(ComplexF64, 4, 5); fkwargs=(; dims=2, corrected=false))
20+
test_rrule(var, UpperTriangular(randn(5, 5)); fkwargs=(; dims=1))
21+
22+
x = PermutedDimsArray(randn(3, 4, 5), (3, 2, 1))
23+
xm = mean(x; dims=(1, 3))
24+
test_rrule(var, x; fkwargs=(; dims=(1, 3), mean=xm))
25+
end

0 commit comments

Comments
 (0)