Skip to content

Commit 0c7884d

Browse files
authored
Make work on functors (#170)
1 parent ac787e9 commit 0c7884d

File tree

5 files changed

+140
-68
lines changed

5 files changed

+140
-68
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.7.6"
3+
version = "0.7.7"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -13,5 +13,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1313
[compat]
1414
ChainRulesCore = "0.10"
1515
Compat = "3"
16-
FiniteDifferences = "0.12"
16+
FiniteDifferences = "0.12.12"
1717
julia = "1"

docs/Manifest.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1111

1212
[[ChainRulesCore]]
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
14-
git-tree-sha1 = "5d64be50ea9b43a89b476be773e125cef03c7cd5"
14+
git-tree-sha1 = "04dd5ce9f9d7b9b14559b00a7eb5be7528f56b82"
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.10.1"
16+
version = "0.10.2"
1717

1818
[[ChainRulesTestUtils]]
1919
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
2020
path = ".."
2121
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
22-
version = "0.7.0"
22+
version = "0.7.5"
2323

2424
[[Compat]]
2525
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
@@ -57,9 +57,9 @@ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
5757

5858
[[FiniteDifferences]]
5959
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
60-
git-tree-sha1 = "f8c8e287c1d68abc2719ad58fb39de9f6c0d71b1"
60+
git-tree-sha1 = "5d448db3b862fb331d20144c2e59c54db69720e0"
6161
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
62-
version = "0.12.10"
62+
version = "0.12.12"
6363

6464
[[IOCapture]]
6565
deps = ["Logging"]

docs/src/index.md

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,29 @@ Test Summary: | Pass Total
105105
test_scalar: relu at -0.5 | 9 9
106106
```
107107

108+
## Testing constructors and functors (callable objects)
109+
110+
Testing constructor and functors works as you would expect. For struct `Foo`
111+
```julia
112+
struct Foo
113+
a::Float64
114+
end
115+
(f::Foo)(x) = return f.a + x
116+
Base.length(::Foo) = 1
117+
Base.iterate(f::Foo) = iterate(f.a)
118+
Base.iterate(f::Foo, state) = iterate(f.a, state)
119+
```
120+
the `f/rrule`s can be tested by
121+
```julia
122+
test_rrule(Foo, rand()) # constructor
123+
124+
foo = Foo(rand())
125+
test_rrule(foo, rand()) # functor
126+
127+
# it is also possible to provide tangents for `foo` explicitly
128+
test_frule(foo Tangent{Foo}(;a=rand()), rand())
129+
```
130+
108131
## Specifying Tangents
109132
[`test_frule`](@ref) and [`test_rrule`](@ref) allow you to specify the tangents used for testing.
110133
This is done by passing in `x ⊢ Δx`, where `x` is the primal and `Δx` is the tangent, in the place of the primal inputs.
@@ -152,7 +175,7 @@ which should have passed the test.
152175

153176
By default, all functions for testing rules check whether the output type (as well as that of the pullback for `rrule`s) can be completely inferred, such that everything is type stable:
154177

155-
```jldoctest ex
178+
```julia
156179
julia> function ChainRulesCore.rrule(::typeof(abs), x)
157180
abs_pullback(Δ) = (NoTangent(), x >= 0 ? Δ : big(-1.0) * Δ)
158181
return abs(x), abs_pullback
@@ -167,7 +190,7 @@ test_rrule: abs on Float64: Error During Test at /home/runner/work/ChainRulesTes
167190

168191
This can be disabled on a per-rule basis using the `check_inferred` keyword argument:
169192

170-
```jldoctest ex
193+
```julia
171194
julia> test_rrule(abs, 1.; check_inferred=false)
172195
Test Summary: | Pass Total
173196
test_rrule: abs on Float64 | 5 5

src/testers.jl

Lines changed: 56 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
1818
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
1919

2020
@testset "test_scalar: $f at $z" begin
21-
_ensure_not_running_on_functor(f, "test_scalar")
2221
# z = x + im * y
2322
# Ω = u(x, y) + im * v(x, y)
2423
Ω = f(z; fkwargs...)
@@ -30,8 +29,9 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
3029
test_frule(f, z Δx; rule_test_kwargs...)
3130
if z isa Complex
3231
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
33-
_, real_tangent = frule((ZeroTangent(), real(Δx)), f, z; fkwargs...)
34-
_, embedded_tangent = frule((ZeroTangent(), Δx), f, z; fkwargs...)
32+
= rand_tangent(f)
33+
_, real_tangent = frule((ḟ, real(Δx)), f, z; fkwargs...)
34+
_, embedded_tangent = frule((ḟ, Δx), f, z; fkwargs...)
3535
test_approx(real_tangent, embedded_tangent; isapprox_kwargs...)
3636
end
3737
end
@@ -70,7 +70,7 @@ end
7070
test_frule(f, args..; kwargs...)
7171
7272
# Arguments
73-
- `f`: Function for which the `frule` should be tested.
73+
- `f`: Function for which the `frule` should be tested. Can also provide `f ⊢ ḟ`.
7474
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
7575
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
7676
- `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
@@ -99,25 +99,29 @@ function test_frule(
9999
# To simplify some of the calls we make later lets group the kwargs for reuse
100100
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
101101

102+
# and define a helper closure
103+
call_on_copy(f, xs...) = deepcopy(f)(deepcopy(xs)...; deepcopy(fkwargs)...)
104+
102105
@testset "test_frule: $f on $(_string_typeof(args))" begin
103-
_ensure_not_running_on_functor(f, "test_frule")
104106

105-
xẋs = auto_primal_and_tangent.(args)
106-
xs = primal.(xẋs)
107-
ẋs = tangent.(xẋs)
108-
if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...)
109-
_test_inferred(frule, (NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
107+
primals_and_tangents = auto_primal_and_tangent.((f, args...))
108+
primals = primal.(primals_and_tangents)
109+
tangents = tangent.(primals_and_tangents)
110+
111+
if check_inferred && _is_inferrable(deepcopy(primals)...; deepcopy(fkwargs)...)
112+
_test_inferred(frule, deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
110113
end
111-
res = frule((NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
112-
res === nothing && throw(MethodError(frule, typeof((f, xs...))))
114+
115+
res = frule(deepcopy(tangents), deepcopy(primals)...; deepcopy(fkwargs)...)
116+
res === nothing && throw(MethodError(frule, typeof(primals)))
113117
@test_msg "The frule should return (y, ∂y), not $res." res isa Tuple{Any,Any}
114118
Ω_ad, dΩ_ad = res
115-
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
119+
Ω = call_on_copy(primals...)
116120
test_approx(Ω_ad, Ω; isapprox_kwargs...)
117121

118122
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
119-
ẋs_is_ignored = isa.(ẋs, Union{Nothing,NoTangent})
120-
if any(ẋs .== nothing)
123+
is_ignored = isa.(tangents, Union{Nothing,NoTangent})
124+
if any(tangents .== nothing)
121125
Base.depwarn(
122126
"test_frule(f, k ⊢ nothing) is deprecated, use " *
123127
"test_frule(f, k ⊢ NoTangent()) instead for non-differentiable ks",
@@ -126,7 +130,7 @@ function test_frule(
126130
end
127131

128132
# Correctness testing via finite differencing.
129-
dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), Ω, xs, ẋs, ẋs_is_ignored)
133+
dΩ_fd = _make_jvp_call(fdm, call_on_copy, Ω, primals, tangents, is_ignored)
130134
test_approx(dΩ_ad, dΩ_fd; isapprox_kwargs...)
131135

132136
acc = output_tangent isa Auto ? rand_tangent(Ω) : output_tangent
@@ -138,14 +142,14 @@ end
138142
test_rrule(f, args...; kwargs...)
139143
140144
# Arguments
141-
- `f`: Function to which rule should be applied.
142-
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ `
145+
- `f`: Function to which rule should be applied. Can also provide `f ⊢ f̄`.
146+
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ `
143147
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
144148
- `x̄`: currently accumulated cotangent, will be generated automatically if not provided
145149
Non-differentiable arguments, such as indices, should have `x̄` set as `NoTangent()`.
146150
147151
# Keyword Arguments
148-
- `output_tangent` the seed to propagate backward for testing (techncally a cotangent).
152+
- `output_tangent` the seed to propagate backward for testing (technically a cotangent).
149153
should be a differential for the output of `f`. Is set automatically if not provided.
150154
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
151155
- If `check_inferred=true`, then the inferrability of the `rrule` is checked
@@ -167,63 +171,66 @@ function test_rrule(
167171
# To simplify some of the calls we make later lets group the kwargs for reuse
168172
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
169173

174+
# and define helper closure over fkwargs
175+
call(f, xs...) = f(xs...; fkwargs...)
176+
170177
@testset "test_rrule: $f on $(_string_typeof(args))" begin
171-
_ensure_not_running_on_functor(f, "test_rrule")
172178

173179
# Check correctness of evaluation.
174-
xx̄s = auto_primal_and_tangent.(args)
175-
xs = primal.(xx̄s)
176-
accumulated_x̄ = tangent.(xx̄s)
177-
if check_inferred && _is_inferrable(f, xs...; fkwargs...)
178-
_test_inferred(rrule, f, xs...; fkwargs...)
180+
primals_and_tangents = auto_primal_and_tangent.((f, args...))
181+
primals = primal.(primals_and_tangents)
182+
accum_cotangents = tangent.(primals_and_tangents)
183+
184+
if check_inferred && _is_inferrable(primals...; fkwargs...)
185+
_test_inferred(rrule, primals...; fkwargs...)
179186
end
180-
res = rrule(f, xs...; fkwargs...)
181-
res === nothing && throw(MethodError(rrule, typeof((f, xs...))))
187+
res = rrule(primals...; fkwargs...)
188+
res === nothing && throw(MethodError(rrule, typeof((primals...))))
182189
y_ad, pullback = res
183-
y = f(xs...; fkwargs...)
190+
y = call(primals...)
184191
test_approx(y_ad, y; isapprox_kwargs...) # make sure primal is correct
185192

186193
= output_tangent isa Auto ? rand_tangent(y) : output_tangent
187194

188195
check_inferred && _test_inferred(pullback, ȳ)
189-
∂s = pullback(ȳ)
190-
∂s isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.")
191-
∂self = ∂s[1]
192-
x̄s_ad = ∂s[2:end]
193-
@test ∂self === NoTangent() # No internal fields
194-
msg = "The pullback should return 1 cotangent for each primal input."
195-
@test_msg msg length(x̄s_ad) == length(args)
196+
ad_cotangents = pullback(ȳ)
197+
ad_cotangents isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.")
198+
msg = "The pullback should return 1 cotangent for the primal and each primal input."
199+
@test_msg msg length(ad_cotangents) == 1 + length(args)
196200

197201
# Correctness testing via finite differencing.
198202
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
199-
x̄s_is_dne = isa.(accumulated_x̄, Union{Nothing,NoTangent})
200-
if any(accumulated_x̄ .== nothing)
203+
is_ignored = isa.(accum_cotangents, Union{Nothing, NoTangent})
204+
if any(accum_cotangents .== nothing)
201205
Base.depwarn(
202206
"test_rrule(f, k ⊢ nothing) is deprecated, use " *
203207
"test_rrule(f, k ⊢ NoTangent()) instead for non-differentiable ks",
204208
:test_rrule,
205209
)
206210
end
207211

208-
x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne)
209-
for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd)
210-
if accumulated_x̄ isa Union{Nothing,NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
211-
@assert x̄_fd === nothing # this is how `_make_j′vp_call` works
212-
x̄_ad isa ZeroTangent && error(
213-
"The pullback in the rrule for $f function should use NoTangent()" *
212+
fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored)
213+
214+
for (accum_cotangent, ad_cotangent, fd_cotangent) in zip(
215+
accum_cotangents, ad_cotangents, fd_cotangents
216+
)
217+
if accum_cotangent isa Union{Nothing,NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
218+
@assert fd_cotangent === nothing # this is how `_make_j′vp_call` works
219+
ad_cotangent isa ZeroTangent && error(
220+
"The pullback in the rrule should use NoTangent()" *
214221
" rather than ZeroTangent() for non-perturbable arguments.",
215222
)
216-
@test x̄_ad isa NoTangent # we said it wasn't differentiable.
223+
@test ad_cotangent isa NoTangent # we said it wasn't differentiable.
217224
else
218-
x̄_ad isa AbstractThunk && check_inferred && _test_inferred(unthunk, x̄_ad)
225+
ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent)
219226

220-
# The main test of the actual deriviative being correct:
221-
test_approx(x̄_ad, x̄_fd; isapprox_kwargs...)
222-
_test_add!!_behaviour(accumulated_x̄, x̄_ad; isapprox_kwargs...)
227+
# The main test of the actual derivative being correct:
228+
test_approx(ad_cotangent, fd_cotangent; isapprox_kwargs...)
229+
_test_add!!_behaviour(accum_cotangent, ad_cotangent; isapprox_kwargs...)
223230
end
224231
end
225232

226-
check_thunking_is_appropriate(x̄s_ad)
233+
check_thunking_is_appropriate(ad_cotangents)
227234
end # top-level testset
228235
end
229236

@@ -236,16 +243,6 @@ function check_thunking_is_appropriate(x̄s)
236243
end
237244
end
238245

239-
function _ensure_not_running_on_functor(f, name)
240-
# if x itself is a Type, then it is a constructor, thus not a functor.
241-
# This also catchs UnionAll constructors which have a `:var` and `:body` fields
242-
f isa Type && return nothing
243-
244-
if fieldcount(typeof(f)) > 0
245-
throw(ArgumentError("$name cannot be used on closures/functors (such as $f)"))
246-
end
247-
end
248-
249246
"""
250247
@maybe_inferred [Type] f(...)
251248

test/testers.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,38 @@ function finplace!(x; y=[1])
2121
return x
2222
end
2323

24+
struct Foo
25+
a::Float64
26+
end
27+
(f::Foo)(x) = return f.a + x
28+
Base.length(::Foo) = 1
29+
Base.iterate(f::Foo) = iterate(f.a)
30+
Base.iterate(f::Foo, state) = iterate(f.a, state)
31+
32+
# constructor
33+
function ChainRulesCore.rrule(::Type{Foo}, a)
34+
foo = Foo(a)
35+
function Foo_pullback(Δfoo)
36+
return NoTangent(), Δfoo.a
37+
end
38+
return foo, Foo_pullback
39+
end
40+
function ChainRulesCore.frule((_, Δa), ::Type{Foo}, a)
41+
return Foo(a), Foo(Δa)
42+
end
43+
44+
# functor
45+
function ChainRulesCore.rrule(f::Foo, x)
46+
y = f(x)
47+
function Foo_pullback(Δy)
48+
return Tangent{Foo}(;a=Δy), Δy
49+
end
50+
return y, Foo_pullback
51+
end
52+
function ChainRulesCore.frule((Δf, Δx), f::Foo, x)
53+
return f(x), Δf.a + Δx
54+
end
55+
2456
@testset "testers.jl" begin
2557
@testset "test_scalar" begin
2658
@testset "Ensure correct rules succeed" begin
@@ -513,6 +545,26 @@ end
513545
end
514546
end
515547

548+
@testset "structs" begin
549+
@testset "constructor" begin
550+
test_frule(Foo, rand())
551+
test_rrule(Foo, rand())
552+
end
553+
554+
foo = Foo(rand())
555+
tfoo = Tangent{Foo}(;a=rand())
556+
@testset "functor" begin
557+
test_frule(foo, rand())
558+
test_rrule(foo, rand())
559+
test_scalar(foo, rand())
560+
561+
test_frule(foo Foo(rand()), rand())
562+
test_frule(foo tfoo, rand())
563+
test_rrule(foo Foo(rand()), rand())
564+
test_rrule(foo tfoo, rand())
565+
end
566+
end
567+
516568
@testset "Tuple primal that is not equal to differential backing" begin
517569
# https://github.com/JuliaMath/SpecialFunctions.jl/issues/288
518570
forwards_trouble(x) = (1, 2.0 * x)

0 commit comments

Comments
 (0)