Skip to content

Commit 74ef7a4

Browse files
authored
Add rrule for map and expand the testing framework (#56)
This implements `rrule(map, f, xs...)` and expands `rrule_test` to allow non-differentiable arguments. For such cases, the user need only pass `nothing` as the argument's assumed sensitivity.
1 parent db1eb6b commit 74ef7a4

File tree

6 files changed

+87
-18
lines changed

6 files changed

+87
-18
lines changed

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ include("rules.jl")
1313
include("rules/base.jl")
1414
include("rules/array.jl")
1515
include("rules/broadcast.jl")
16+
include("rules/mapreduce.jl")
1617
include("rules/linalg/utils.jl")
1718
include("rules/linalg/blas.jl")
1819
include("rules/linalg/dense.jl")

src/rules/linalg/dense.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,6 @@ using LinearAlgebra: AbstractTriangular
44
# these we can use simpler definitions for `/` and `\`.
55
const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}
66

7-
#####
8-
##### `sum`
9-
#####
10-
11-
frule(::typeof(sum), x) = (sum(x), Rule(sum))
12-
13-
rrule(::typeof(sum), x) = (sum(x), Rule(cast))
14-
157
#####
168
##### `dot`
179
#####

src/rules/mapreduce.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#####
2+
##### `map`
3+
#####
4+
5+
function rrule(::typeof(map), f, xs...)
6+
y = map(f, xs...)
7+
∂xs = ntuple(length(xs)) do i
8+
Rule() do
9+
map(ȳ, xs...) do ȳi, xis...
10+
r = rrule(f, xis...)
11+
if r === nothing
12+
throw(ArgumentError("can't differentiate `map` with `$f`; no `rrule` " *
13+
"is defined for `$f$xis`"))
14+
end
15+
_, ∂xis = r
16+
extern(∂xis[i](ȳi))
17+
end
18+
end
19+
end
20+
return y, (DNERule(), ∂xs...)
21+
end
22+
23+
#####
24+
##### `sum`
25+
#####
26+
27+
frule(::typeof(sum), x) = (sum(x), Rule(sum))
28+
29+
rrule(::typeof(sum), x) = (sum(x), Rule(cast))

test/rules/mapreduce.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
@testset "Maps and Reductions" begin
2+
@testset "map" begin
3+
rng = MersenneTwister(42)
4+
n = 10
5+
x = randn(rng, n)
6+
vx = randn(rng, n)
7+
= randn(rng, n)
8+
rrule_test(map, ȳ, (sin, nothing), (x, vx))
9+
rrule_test(map, ȳ, (+, nothing), (x, vx), (randn(rng, n), randn(rng, n)))
10+
end
11+
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ChainRules, Test, FDM, LinearAlgebra, Random
44
using ChainRules: extern, accumulate, accumulate!, store!, @scalar_rule,
55
Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger,
66
Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted, mul_casted,
7-
DNE, Thunk, Casted
7+
DNE, Thunk, Casted, DNERule
88
using Base.Broadcast: broadcastable
99

1010
include("test_util.jl")
@@ -15,6 +15,7 @@ include("test_util.jl")
1515
@testset "rules" begin
1616
include(joinpath("rules", "base.jl"))
1717
include(joinpath("rules", "array.jl"))
18+
include(joinpath("rules", "mapreduce.jl"))
1819
@testset "linalg" begin
1920
include(joinpath("rules", "linalg", "dense.jl"))
2021
include(joinpath("rules", "linalg", "structured.jl"))

test/test_util.jl

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,23 +50,58 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm
5050
test_accumulation(Zero(), dx, ȳ, x̄_ad)
5151
end
5252

53+
function _make_fdm_call(fdm, f, ȳ, xs, ignores)
54+
sig = Expr(:tuple)
55+
call = Expr(:call, f)
56+
newxs = Any[]
57+
arginds = Int[]
58+
i = 1
59+
for (x, ignore) in zip(xs, ignores)
60+
if ignore
61+
push!(call.args, x)
62+
else
63+
push!(call.args, Symbol(:x, i))
64+
push!(sig.args, Symbol(:x, i))
65+
push!(newxs, x)
66+
push!(arginds, i)
67+
end
68+
i += 1
69+
end
70+
fdexpr = :(j′vp($fdm, $sig -> $call, $ȳ, $(newxs...)))
71+
fd = eval(fdexpr)
72+
fd isa Tuple || (fd = (fd,))
73+
args = Any[nothing for _ in 1:length(xs)]
74+
for (dx, ind) in zip(fd, arginds)
75+
args[ind] = dx
76+
end
77+
return (args...,)
78+
end
79+
5380
function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
5481
# Check correctness of evaluation.
5582
xs, x̄s = collect(zip(xx̄s...))
56-
Ω, Δx_rules = ChainRules.rrule(f, xs...)
57-
@test f(xs...) == Ω
83+
y, rules = rrule(f, xs...)
84+
@test f(xs...) == y
5885

5986
# Correctness testing via finite differencing.
60-
Δxs_ad = map(Δx_rule->Δx_rule(ȳ), Δx_rules)
61-
Δxs_fd = j′vp(fdm, f, ȳ, xs...)
62-
for (Δx_ad, Δx_fd) in zip(Δxs_ad, Δxs_fd)
63-
@test isapprox(Δx_ad, Δx_fd; rtol=rtol, atol=atol, kwargs...)
87+
x̄s_ad = map(rules) do rule
88+
rule isa DNERule ? DNE() : rule(ȳ)
89+
end
90+
x̄s_fd = _make_fdm_call(fdm, f, ȳ, xs, x̄s .== nothing)
91+
for (x̄_ad, x̄_fd) in zip(x̄s_ad, x̄s_fd)
92+
if x̄_fd === nothing
93+
# The way we've structured the above, this tests that the rule is a DNERule
94+
@test x̄_ad isa DNE
95+
else
96+
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
97+
end
6498
end
6599

66100
# Assuming the above to be correct, check that other ChainRules mechanisms are correct.
67-
for (x̄, Δx_rule, Δx_ad) in zip(x̄s, Δx_rules, Δxs_ad)
68-
test_accumulation(x̄, Δx_rule, ȳ, Δx_ad)
69-
test_accumulation(Zero(), Δx_rule, ȳ, Δx_ad)
101+
for (x̄, rule, x̄_ad) in zip(x̄s, rules, x̄s_ad)
102+
=== nothing && continue
103+
test_accumulation(x̄, rule, ȳ, x̄_ad)
104+
test_accumulation(Zero(), rule, ȳ, x̄_ad)
70105
end
71106
end
72107

0 commit comments

Comments
 (0)