Skip to content

Commit e850c31

Browse files
Move mean rule to a Statistics ruleset folder
1 parent 9b50626 commit e850c31

File tree

6 files changed

+39
-34
lines changed

6 files changed

+39
-34
lines changed

src/ChainRules.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ include("rulesets/Base/array.jl")
2929
include("rulesets/Base/broadcast.jl")
3030
include("rulesets/Base/mapreduce.jl")
3131

32+
include("rulesets/Statistics/statistics.jl")
33+
3234
include("rulesets/LinearAlgebra/utils.jl")
3335
include("rulesets/LinearAlgebra/blas.jl")
3436
include("rulesets/LinearAlgebra/dense.jl")

src/rulesets/Base/mapreduce.jl

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -62,25 +62,3 @@ function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:)
6262
∂x = Rule(ȳ -> 2.* x)
6363
return y, (DNERule(), ∂x)
6464
end
65-
66-
#####
67-
##### `mean`
68-
#####
69-
70-
_denom(x, dims::Colon) = length(x)
71-
_denom(x, dims::Integer) = size(x, dims)
72-
_denom(x, dims) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1)
73-
74-
# TODO: We have `mean(f, x; dims)` as of 1.3.0-DEV.36
75-
76-
function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:)
77-
_, dx = rrule(sum, x; dims=dims)
78-
n = _denom(x, dims)
79-
return mean(x; dims=dims), Rule(ȳ -> dx(ȳ) / n)
80-
end
81-
82-
function rrule(::typeof(mean), f, x::AbstractArray{<:Real})
83-
_, (_, dx) = rrule(sum, f, x)
84-
n = _denom(x, :)
85-
return mean(f, x), (DNERule(), Rule(ȳ -> dx(ȳ) / n))
86-
end

src/rulesets/Statistics/statistics.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#####
2+
##### `mean`
3+
#####
4+
5+
_denom(x, dims::Colon) = length(x)
6+
_denom(x, dims::Integer) = size(x, dims)
7+
_denom(x, dims) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1)
8+
9+
# TODO: We have `mean(f, x; dims)` as of 1.3.0-DEV.36
10+
11+
function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:)
12+
_, dx = rrule(sum, x; dims=dims)
13+
n = _denom(x, dims)
14+
return mean(x; dims=dims), Rule(ȳ -> dx(ȳ) / n)
15+
end
16+
17+
function rrule(::typeof(mean), f, x::AbstractArray{<:Real})
18+
_, (_, dx) = rrule(sum, f, x)
19+
n = _denom(x, :)
20+
return mean(f, x), (DNERule(), Rule(ȳ -> dx(ȳ) / n))
21+
end

test/rulesets/Base/mapreduce.jl

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,5 @@
6363
x̄_fd = j′vp(central_fdm(5, 1), x->sum(x, dims=2), ȳ, X)
6464
@test x̄_ad x̄_fd atol=1e-9 rtol=1e-9
6565
end
66-
end
67-
@testset "mean" begin
68-
rng = MersenneTwister(999)
69-
n = 9
70-
rrule_test(mean, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
71-
X = randn(rng, n, n)
72-
y, dX = rrule(mean, X; dims=1)
73-
= randn(rng, size(y))
74-
X̄_ad = dX(ȳ)
75-
X̄_fd = j′vp(central_fdm(5, 1), x->mean(x, dims=1), ȳ, X)
76-
@test X̄_ad X̄_fd rtol=1e-9 atol=1e-9
77-
end
66+
end # sum
7867
end
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
@testset "mean" begin
2+
rng = MersenneTwister(999)
3+
n = 9
4+
rrule_test(mean, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
5+
X = randn(rng, n, n)
6+
y, dX = rrule(mean, X; dims=1)
7+
= randn(rng, size(y))
8+
X̄_ad = dX(ȳ)
9+
X̄_fd = j′vp(central_fdm(5, 1), x->mean(x, dims=1), ȳ, X)
10+
@test X̄_ad X̄_fd rtol=1e-9 atol=1e-9
11+
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ include("test_util.jl")
2727
include(joinpath("rulesets", "Base", "broadcast.jl"))
2828
end
2929

30+
@testset "Statistics" begin
31+
include(joinpath("rulesets", "Statistics", "statistics.jl"))
32+
end
33+
3034
@testset "LinearAlgebra" begin
3135
include(joinpath("rulesets", "LinearAlgebra", "dense.jl"))
3236
include(joinpath("rulesets", "LinearAlgebra", "structured.jl"))

0 commit comments

Comments
 (0)