Skip to content

Commit 5fb2923

Browse files
committed
Add rrule for mean
1 parent fc555b1 commit 5fb2923

File tree

5 files changed

+36
-1
lines changed

5 files changed

+36
-1
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
99
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
10+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1011

1112
[compat]
1213
Cassette = "^0.2"

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module ChainRules
33
using Cassette
44
using LinearAlgebra
55
using LinearAlgebra.BLAS
6+
using Statistics
67
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
78

89
if VERSION < v"1.3.0-DEV.142"

src/rules/mapreduce.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,25 @@ function rrule(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}; dims=:)
6262
∂x = Rule(ȳ -> 2.* x)
6363
return y, (DNERule(), ∂x)
6464
end
65+
66+
#####
67+
##### `mean`
68+
#####
69+
70+
_denom(x, dims::Colon) = length(x)
71+
_denom(x, dims::Integer) = size(x, dims)
72+
_denom(x, dims) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1)
73+
74+
# TODO: We have `mean(f, x; dims)` as of 1.3.0-DEV.36
75+
76+
function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:)
77+
_, dx = rrule(sum, x; dims=dims)
78+
n = _denom(x, dims)
79+
return mean(x; dims=dims), Rule(ȳ -> dx(ȳ) / n)
80+
end
81+
82+
function rrule(::typeof(mean), f, x::AbstractArray{<:Real})
83+
_, (_, dx) = rrule(sum, f, x)
84+
n = _denom(x, :)
85+
return mean(f, x), (DNERule(), Rule(ȳ -> dx(ȳ) / n))
86+
end

test/rules/mapreduce.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,15 @@
6464
@test x̄_ad x̄_fd atol=1e-9 rtol=1e-9
6565
end
6666
end
67+
@testset "mean" begin
68+
rng = MersenneTwister(999)
69+
n = 9
70+
rrule_test(mean, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
71+
X = randn(rng, n, n)
72+
y, dX = rrule(mean, X; dims=1)
73+
= randn(rng, size(y))
74+
X̄_ad = dX(ȳ)
75+
X̄_fd = j′vp(central_fdm(5, 1), x->mean(x, dims=1), ȳ, X)
76+
@test X̄_ad X̄_fd rtol=1e-9 atol=1e-9
77+
end
6778
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# TODO: more tests!
22

3-
using ChainRules, Test, FDM, LinearAlgebra, LinearAlgebra.BLAS, Random
3+
using ChainRules, Test, FDM, LinearAlgebra, LinearAlgebra.BLAS, Random, Statistics
44
using ChainRules: extern, accumulate, accumulate!, store!, @scalar_rule,
55
Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger,
66
Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted, mul_casted,

0 commit comments

Comments
 (0)