Skip to content

Commit cc8b9ea

Browse files
mcabbott and mzgubic authored
Rule for mean(f,x) (#615)
* mean(f,x) rule * Revert changes to sum_f_x * simpler rule, stiffer test * Apply 2 suggestions Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com> Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>
1 parent 30431b4 commit cc8b9ea

File tree

3 files changed

+71
-11
lines changed

3 files changed

+71
-11
lines changed

src/rulesets/Statistics/statistics.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,31 @@ _denom(x, dims::Colon) = length(x)
66
_denom(x, dims::Integer) = size(x, dims)
77
_denom(x, dims) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1)
88

9-
# TODO: We have `mean(f, x; dims)` as of 1.3.0-DEV.36
10-
# https://github.com/JuliaDiff/ChainRules.jl/issues/85
11-
function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:)
12-
y_sum, sum_pullback = rrule(sum, x; dims=dims)
9+
function rrule(::typeof(mean), x::AbstractArray{<:Union{Real,Complex,AbstractArray}}; dims=:)
10+
y_sum, sum_pullback = rrule(sum, x; dims)
1311
n = _denom(x, dims)
1412
function mean_pullback(ȳ)
15-
_, ∂sum_x = sum_pullback(ȳ)
16-
∂x = unthunk(∂sum_x) / n
13+
_, ∂x = sum_pullback(unthunk(ȳ) / n)
1714
return (NoTangent(), ∂x)
1815
end
1916
return y_sum / n, mean_pullback
2017
end
2118

19+
function rrule(
20+
config::RuleConfig{>:HasReverseMode},
21+
::typeof(mean),
22+
f::F,
23+
x::AbstractArray{T};
24+
dims=:,
25+
) where {F, T<:Union{Real,Complex,AbstractArray}}
26+
y_sum, sum_pullback = rrule(config, sum, f, x; dims)
27+
n = _denom(x, dims)
28+
function mean_pullback_f(ȳ)
29+
return sum_pullback(unthunk(ȳ) / n)
30+
end
31+
return y_sum / n, mean_pullback_f
32+
end
33+
2234
#####
2335
##### variance
2436
#####

test/rulesets/Statistics/statistics.jl

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,32 @@
11
@testset "mean" begin
2-
n = 9
3-
@testset "Basic" begin
4-
test_rrule(mean, randn(n))
2+
@testset "mean(x)" begin
3+
test_rrule(mean, randn(9))
4+
test_rrule(mean, randn(ComplexF64,2,4))
5+
test_rrule(mean, transpose(rand(3)))
6+
test_rrule(mean, [rand(3) for _ in 1:4]; check_inferred=false)
57
end
68
@testset "with dims kwargs" begin
7-
test_rrule(mean, randn(n); fkwargs=(;dims=1))
8-
test_rrule(mean, randn(n,4); fkwargs=(;dims=2))
9+
test_rrule(mean, randn(9); fkwargs=(;dims=1))
10+
test_rrule(mean, randn(9,4); fkwargs=(;dims=2))
11+
test_rrule(mean, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=(;dims=2), check_inferred=false)
12+
end
13+
@testset "mean(f, x)" begin
14+
# This shares its implementation with sum(f, x). Similar tests should cover all cases:
15+
test_rrule(mean, abs, [-4.0, 2.0, 2.0])
16+
test_rrule(mean, log, rand(3, 4) .+ 1)
17+
test_rrule(mean, cbrt, randn(5))
18+
test_rrule(mean, Multiplier(2.0), [2.0, 4.0, 8.0]) # defined in test_helpers.jl
19+
test_rrule(mean, Divider(1 + rand()), randn(5))
20+
21+
test_rrule(mean, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false)
22+
23+
test_rrule(mean, log, rand(ComplexF64, 5))
24+
test_rrule(mean, sqrt, rand(ComplexF64, 5))
25+
test_rrule(mean, abs, rand(ComplexF64, 3, 4))
26+
27+
test_rrule(mean, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1))
28+
test_rrule(mean, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2))
29+
test_rrule(mean, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,)))
930
end
1031
end
1132

test/test_helpers.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,28 @@ function ChainRulesCore.rrule(m::Multiplier, y, z)
3535
return m(y, z), Multiplier_pullback_3
3636
end
3737

38+
"""
39+
Divider(x)
40+
41+
Stores a fixed `x` and divides by it, then squares the result.
42+
43+
Especially for testing the gradient of higher order functions with respect to `x`.
44+
```
45+
julia> map(Divider(2), [1 2 3 4 10])
46+
1×5 Matrix{Float64}:
47+
0.25 1.0 2.25 4.0 25.0
48+
```
49+
"""
50+
struct Divider{T<:Real}
51+
x::T
52+
end
53+
(d::Divider)(y::Real) = (y / d.x)^2
54+
55+
function ChainRulesCore.rrule(d::Divider, y::Real)
56+
Divider_pullback(dΩ) = (Tangent{typeof(d)}(; x = -2 * dΩ * y^2 / d.x^3), 2 * dΩ * y / d.x^2)
57+
return d(y), Divider_pullback
58+
end
59+
3860
"""
3961
Counter()
4062
@@ -88,6 +110,11 @@ end
88110
test_rrule(Multiplier(1.0 + 2im), 3.0 + 4im, 5.0 - 6im)
89111
test_rrule(Multiplier(rand(2,3)), rand(3,4), rand(4,5))
90112
end
113+
114+
@testset "Divider" begin
115+
test_rrule(Divider(2.3), 4.5)
116+
test_rrule(Divider(0.2), -3.4)
117+
end
91118

92119
@testset "Counter" begin
93120
c = Counter()

0 commit comments

Comments
 (0)