@@ -18,7 +18,6 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
18
18
isapprox_kwargs = (; rtol= rtol, atol= atol, kwargs... )
19
19
20
20
@testset " test_scalar: $f at $z " begin
21
- _ensure_not_running_on_functor (f, " test_scalar" )
22
21
# z = x + im * y
23
22
# Ω = u(x, y) + im * v(x, y)
24
23
Ω = f (z; fkwargs... )
@@ -30,8 +29,9 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
30
29
test_frule (f, z ⊢ Δx; rule_test_kwargs... )
31
30
if z isa Complex
32
31
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
33
- _, real_tangent = frule ((ZeroTangent (), real (Δx)), f, z; fkwargs... )
34
- _, embedded_tangent = frule ((ZeroTangent (), Δx), f, z; fkwargs... )
32
+ ḟ = rand_tangent (f)
33
+ _, real_tangent = frule ((ḟ, real (Δx)), f, z; fkwargs... )
34
+ _, embedded_tangent = frule ((ḟ, Δx), f, z; fkwargs... )
35
35
test_approx (real_tangent, embedded_tangent; isapprox_kwargs... )
36
36
end
37
37
end
70
70
test_frule(f, args..; kwargs...)
71
71
72
72
# Arguments
73
- - `f`: Function for which the `frule` should be tested.
73
+ - `f`: Function for which the `frule` should be tested. Can also provide `f ⊢ ḟ`.
74
74
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
75
75
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
76
76
- `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
@@ -99,25 +99,29 @@ function test_frule(
99
99
# To simplify some of the calls we make later lets group the kwargs for reuse
100
100
isapprox_kwargs = (; rtol= rtol, atol= atol, kwargs... )
101
101
102
+ # and define a helper closure
103
+ call_on_copy (f, xs... ) = deepcopy (f)(deepcopy (xs)... ; deepcopy (fkwargs)... )
104
+
102
105
@testset " test_frule: $f on $(_string_typeof (args)) " begin
103
- _ensure_not_running_on_functor (f, " test_frule" )
104
106
105
- xẋs = auto_primal_and_tangent .(args)
106
- xs = primal .(xẋs)
107
- ẋs = tangent .(xẋs)
108
- if check_inferred && _is_inferrable (f, deepcopy (xs)... ; deepcopy (fkwargs)... )
109
- _test_inferred (frule, (NoTangent (), deepcopy (ẋs)... ), f, deepcopy (xs)... ; deepcopy (fkwargs)... )
107
+ primals_and_tangents = auto_primal_and_tangent .((f, args... ))
108
+ primals = primal .(primals_and_tangents)
109
+ tangents = tangent .(primals_and_tangents)
110
+
111
+ if check_inferred && _is_inferrable (deepcopy (primals)... ; deepcopy (fkwargs)... )
112
+ _test_inferred (frule, deepcopy (tangents), deepcopy (primals)... ; deepcopy (fkwargs)... )
110
113
end
111
- res = frule ((NoTangent (), deepcopy (ẋs)... ), f, deepcopy (xs)... ; deepcopy (fkwargs)... )
112
- res === nothing && throw (MethodError (frule, typeof ((f, xs... ))))
114
+
115
+ res = frule (deepcopy (tangents), deepcopy (primals)... ; deepcopy (fkwargs)... )
116
+ res === nothing && throw (MethodError (frule, typeof (primals)))
113
117
@test_msg " The frule should return (y, ∂y), not $res ." res isa Tuple{Any,Any}
114
118
Ω_ad, dΩ_ad = res
115
- Ω = f ( deepcopy (xs) ... ; deepcopy (fkwargs) ... )
119
+ Ω = call_on_copy (primals ... )
116
120
test_approx (Ω_ad, Ω; isapprox_kwargs... )
117
121
118
122
# TODO : remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
119
- ẋs_is_ignored = isa .(ẋs , Union{Nothing,NoTangent})
120
- if any (ẋs .== nothing )
123
+ is_ignored = isa .(tangents , Union{Nothing,NoTangent})
124
+ if any (tangents .== nothing )
121
125
Base. depwarn (
122
126
" test_frule(f, k ⊢ nothing) is deprecated, use " *
123
127
" test_frule(f, k ⊢ NoTangent()) instead for non-differentiable ks" ,
@@ -126,7 +130,7 @@ function test_frule(
126
130
end
127
131
128
132
# Correctness testing via finite differencing.
129
- dΩ_fd = _make_jvp_call (fdm, (xs ... ) -> f ( deepcopy (xs) ... ; deepcopy (fkwargs) ... ) , Ω, xs, ẋs, ẋs_is_ignored )
133
+ dΩ_fd = _make_jvp_call (fdm, call_on_copy , Ω, primals, tangents, is_ignored )
130
134
test_approx (dΩ_ad, dΩ_fd; isapprox_kwargs... )
131
135
132
136
acc = output_tangent isa Auto ? rand_tangent (Ω) : output_tangent
@@ -138,14 +142,14 @@ end
138
142
test_rrule(f, args...; kwargs...)
139
143
140
144
# Arguments
141
- - `f`: Function to which rule should be applied.
142
- - `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ `
145
+ - `f`: Function to which rule should be applied. Can also provide `f ⊢ f̄`.
146
+ - `args` either the primal args `x`, or primals and their tangents: `x ⊢ x̄ `
143
147
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
144
148
- `x̄`: currently accumulated cotangent, will be generated automatically if not provided
145
149
Non-differentiable arguments, such as indices, should have `x̄` set as `NoTangent()`.
146
150
147
151
# Keyword Arguments
148
- - `output_tangent` the seed to propagate backward for testing (techncally a cotangent).
152
+ - `output_tangent` the seed to propagate backward for testing (technically a cotangent).
149
153
should be a differential for the output of `f`. Is set automatically if not provided.
150
154
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
151
155
- If `check_inferred=true`, then the inferrability of the `rrule` is checked
@@ -167,63 +171,66 @@ function test_rrule(
167
171
# To simplify some of the calls we make later lets group the kwargs for reuse
168
172
isapprox_kwargs = (; rtol= rtol, atol= atol, kwargs... )
169
173
174
+ # and define helper closure over fkwargs
175
+ call (f, xs... ) = f (xs... ; fkwargs... )
176
+
170
177
@testset " test_rrule: $f on $(_string_typeof (args)) " begin
171
- _ensure_not_running_on_functor (f, " test_rrule" )
172
178
173
179
# Check correctness of evaluation.
174
- xx̄s = auto_primal_and_tangent .(args)
175
- xs = primal .(xx̄s)
176
- accumulated_x̄ = tangent .(xx̄s)
177
- if check_inferred && _is_inferrable (f, xs... ; fkwargs... )
178
- _test_inferred (rrule, f, xs... ; fkwargs... )
180
+ primals_and_tangents = auto_primal_and_tangent .((f, args... ))
181
+ primals = primal .(primals_and_tangents)
182
+ accum_cotangents = tangent .(primals_and_tangents)
183
+
184
+ if check_inferred && _is_inferrable (primals... ; fkwargs... )
185
+ _test_inferred (rrule, primals... ; fkwargs... )
179
186
end
180
- res = rrule (f, xs ... ; fkwargs... )
181
- res === nothing && throw (MethodError (rrule, typeof ((f, xs ... ))))
187
+ res = rrule (primals ... ; fkwargs... )
188
+ res === nothing && throw (MethodError (rrule, typeof ((primals ... ))))
182
189
y_ad, pullback = res
183
- y = f (xs ... ; fkwargs ... )
190
+ y = call (primals ... )
184
191
test_approx (y_ad, y; isapprox_kwargs... ) # make sure primal is correct
185
192
186
193
ȳ = output_tangent isa Auto ? rand_tangent (y) : output_tangent
187
194
188
195
check_inferred && _test_inferred (pullback, ȳ)
189
- ∂s = pullback (ȳ)
190
- ∂s isa Tuple || error (" The pullback must return (∂self, ∂args...), not $∂s ." )
191
- ∂self = ∂s[1 ]
192
- x̄s_ad = ∂s[2 : end ]
193
- @test ∂self === NoTangent () # No internal fields
194
- msg = " The pullback should return 1 cotangent for each primal input."
195
- @test_msg msg length (x̄s_ad) == length (args)
196
+ ad_cotangents = pullback (ȳ)
197
+ ad_cotangents isa Tuple || error (" The pullback must return (∂self, ∂args...), not $∂s ." )
198
+ msg = " The pullback should return 1 cotangent for the primal and each primal input."
199
+ @test_msg msg length (ad_cotangents) == 1 + length (args)
196
200
197
201
# Correctness testing via finite differencing.
198
202
# TODO : remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
199
- x̄s_is_dne = isa .(accumulated_x̄ , Union{Nothing,NoTangent})
200
- if any (accumulated_x̄ .== nothing )
203
+ is_ignored = isa .(accum_cotangents , Union{Nothing, NoTangent})
204
+ if any (accum_cotangents .== nothing )
201
205
Base. depwarn (
202
206
" test_rrule(f, k ⊢ nothing) is deprecated, use " *
203
207
" test_rrule(f, k ⊢ NoTangent()) instead for non-differentiable ks" ,
204
208
:test_rrule ,
205
209
)
206
210
end
207
211
208
- x̄s_fd = _make_j′vp_call (fdm, (xs... ) -> f (xs... ; fkwargs... ), ȳ, xs, x̄s_is_dne)
209
- for (accumulated_x̄, x̄_ad, x̄_fd) in zip (accumulated_x̄, x̄s_ad, x̄s_fd)
210
- if accumulated_x̄ isa Union{Nothing,NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
211
- @assert x̄_fd === nothing # this is how `_make_j′vp_call` works
212
- x̄_ad isa ZeroTangent && error (
213
- " The pullback in the rrule for $f function should use NoTangent()" *
212
+ fd_cotangents = _make_j′vp_call (fdm, call, ȳ, primals, is_ignored)
213
+
214
+ for (accum_cotangent, ad_cotangent, fd_cotangent) in zip (
215
+ accum_cotangents, ad_cotangents, fd_cotangents
216
+ )
217
+ if accum_cotangent isa Union{Nothing,NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
218
+ @assert fd_cotangent === nothing # this is how `_make_j′vp_call` works
219
+ ad_cotangent isa ZeroTangent && error (
220
+ " The pullback in the rrule should use NoTangent()" *
214
221
" rather than ZeroTangent() for non-perturbable arguments." ,
215
222
)
216
- @test x̄_ad isa NoTangent # we said it wasn't differentiable.
223
+ @test ad_cotangent isa NoTangent # we said it wasn't differentiable.
217
224
else
218
- x̄_ad isa AbstractThunk && check_inferred && _test_inferred (unthunk, x̄_ad )
225
+ ad_cotangent isa AbstractThunk && check_inferred && _test_inferred (unthunk, ad_cotangent )
219
226
220
- # The main test of the actual deriviative being correct:
221
- test_approx (x̄_ad, x̄_fd ; isapprox_kwargs... )
222
- _test_add!!_behaviour (accumulated_x̄, x̄_ad ; isapprox_kwargs... )
227
+ # The main test of the actual derivative being correct:
228
+ test_approx (ad_cotangent, fd_cotangent ; isapprox_kwargs... )
229
+ _test_add!!_behaviour (accum_cotangent, ad_cotangent ; isapprox_kwargs... )
223
230
end
224
231
end
225
232
226
- check_thunking_is_appropriate (x̄s_ad )
233
+ check_thunking_is_appropriate (ad_cotangents )
227
234
end # top-level testset
228
235
end
229
236
@@ -236,16 +243,6 @@ function check_thunking_is_appropriate(x̄s)
236
243
end
237
244
end
238
245
239
- function _ensure_not_running_on_functor (f, name)
240
- # if x itself is a Type, then it is a constructor, thus not a functor.
241
- # This also catchs UnionAll constructors which have a `:var` and `:body` fields
242
- f isa Type && return nothing
243
-
244
- if fieldcount (typeof (f)) > 0
245
- throw (ArgumentError (" $name cannot be used on closures/functors (such as $f )" ))
246
- end
247
- end
248
-
249
246
"""
250
247
@maybe_inferred [Type] f(...)
251
248
0 commit comments