80
80
# Keyword Arguments
81
81
- `output_tangent` tangent to test accumulation of derivatives against
82
82
should be a differential for the output of `f`. Is set automatically if not provided.
83
- - `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
84
- transform the passed argument tangents into alternative tangents that should be tested.
85
- Note that the alternative tangents are only tested for not erroring when passed to
86
- frule. Testing for correctness using finite differencing can be done using a
87
- separate `test_frule` call, e.g. for testing a `ZeroTangent()` for correctness:
88
- `test_frule(f, x ⊢ ZeroTangent(); tangent_transforms=[])`.
89
83
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
90
84
- `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to
91
85
`frule`). Used for testing gradients from AD systems.
@@ -104,7 +98,6 @@ function test_frule(
104
98
f,
105
99
args... ;
106
100
output_tangent= Auto (),
107
- tangent_transforms= TRANSFORMS_TO_ALT_TANGENTS,
108
101
fdm= _fdm,
109
102
frule_f= ChainRulesCore. frule,
110
103
check_inferred:: Bool = true ,
@@ -136,41 +129,16 @@ function test_frule(
136
129
Ω = call_on_copy (primals... )
137
130
test_approx (Ω_ad, Ω; isapprox_kwargs... )
138
131
139
- # TODO : remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
140
- is_ignored = isa .(tangents, Union{Nothing,NoTangent})
141
- if any (tangents .== nothing )
142
- Base. depwarn (
143
- " test_frule(f, k ⊢ nothing) is deprecated, use " *
144
- " test_frule(f, k ⊢ NoTangent()) instead for non-differentiable ks" ,
145
- :test_frule ,
146
- )
147
- end
148
-
149
132
# Correctness testing via finite differencing.
133
+ is_ignored = isa .(tangents, NoTangent)
150
134
dΩ_fd = _make_jvp_call (fdm, call_on_copy, Ω, primals, tangents, is_ignored)
151
135
test_approx (dΩ_ad, dΩ_fd; isapprox_kwargs... )
152
136
153
137
acc = output_tangent isa Auto ? rand_tangent (Ω) : output_tangent
154
138
_test_add!!_behaviour (acc, dΩ_ad; isapprox_kwargs... )
155
-
156
- # test that rules work for other tangents
157
- _test_frule_alt_tangents (
158
- call_on_copy, frule_f, config, tangent_transforms, tangents, primals, acc;
159
- isapprox_kwargs...
160
- )
161
139
end # top-level testset
162
140
end
163
141
164
- function _test_frule_alt_tangents (
165
- call, frule_f, config, tangent_transforms, tangents, primals, acc;
166
- isapprox_kwargs...
167
- )
168
- @testset " ȧrgs = $(_string_typeof (tsf .(tangents))) " for tsf in tangent_transforms
169
- _, dΩ = call (frule_f, config, tsf .(tangents), primals... )
170
- _test_add!!_behaviour (acc, dΩ; isapprox_kwargs... )
171
- end
172
- end
173
-
174
142
"""
175
143
test_rrule([config::RuleConfig,] f, args...; kwargs...)
176
144
185
153
# Keyword Arguments
186
154
- `output_tangent` the seed to propagate backward for testing (technically a cotangent).
187
155
should be a differential for the output of `f`. Is set automatically if not provided.
188
- - `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
189
- transform the passed `output_tangent` into alternative tangents that should be tested.
190
- Note that the alternative tangents are only tested for not erroring when passed to
191
- rrule. Testing for correctness using finite differencing can be done using a
192
- separate `test_rrule` call, e.g. for testing a `ZeroTangent()` for correctness:
193
- `test_rrule(f, args...; output_tangent=ZeroTangent(), tangent_transforms=[])`.
156
+ - `check_thunked_output_tangent=true`: also checks that passing a thunked version of the
157
+ output tangent to the pullback returns the same result.
194
158
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
195
159
- `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`).
196
160
Used for testing gradients from AD systems.
@@ -209,7 +173,7 @@ function test_rrule(
209
173
f,
210
174
args... ;
211
175
output_tangent= Auto (),
212
- tangent_transforms = TRANSFORMS_TO_ALT_TANGENTS ,
176
+ check_thunked_output_tangent = true ,
213
177
fdm= _fdm,
214
178
rrule_f= ChainRulesCore. rrule,
215
179
check_inferred:: Bool = true ,
@@ -254,22 +218,13 @@ function test_rrule(
254
218
)
255
219
256
220
# Correctness testing via finite differencing.
257
- # TODO : remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
258
- is_ignored = isa .(accum_cotangents, Union{Nothing, NoTangent})
259
- if any (accum_cotangents .== nothing )
260
- Base. depwarn (
261
- " test_rrule(f, k ⊢ nothing) is deprecated, use " *
262
- " test_rrule(f, k ⊢ NoTangent()) instead for non-differentiable ks" ,
263
- :test_rrule ,
264
- )
265
- end
266
-
221
+ is_ignored = isa .(accum_cotangents, NoTangent)
267
222
fd_cotangents = _make_j′vp_call (fdm, call, ȳ, primals, is_ignored)
268
223
269
224
for (accum_cotangent, ad_cotangent, fd_cotangent) in zip (
270
225
accum_cotangents, ad_cotangents, fd_cotangents
271
226
)
272
- if accum_cotangent isa Union{Nothing, NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
227
+ if accum_cotangent isa NoTangent # then we marked this argument as not differentiable
273
228
@assert fd_cotangent === nothing # this is how `_make_j′vp_call` works
274
229
ad_cotangent isa ZeroTangent && error (
275
230
" The pullback in the rrule should use NoTangent()" *
@@ -285,21 +240,10 @@ function test_rrule(
285
240
end
286
241
end
287
242
288
- # test other tangents don't error when passed to the pullback
289
- _test_rrule_alt_tangents (pullback, tangent_transforms, ȳ, accum_cotangents)
290
- end # top-level testset
291
- end
292
-
293
- function _test_rrule_alt_tangents (
294
- pullback, tangent_transforms, ȳ, accum_cotangents;
295
- isapprox_kwargs...
296
- )
297
- @testset " ȳ = $(_string_typeof (tsf (ȳ))) " for tsf in tangent_transforms
298
- ad_cotangents = pullback (tsf (ȳ))
299
- for (accum_cotangent, ad_cotangent) in zip (accum_cotangents, ad_cotangents)
300
- _test_add!!_behaviour (accum_cotangent, ad_cotangent; isapprox_kwargs... )
243
+ if check_thunked_output_tangent
244
+ test_approx (ad_cotangents, pullback (@thunk (ȳ)), " pulling back a thunk:" )
301
245
end
302
- end
246
+ end # top-level testset
303
247
end
304
248
305
249
"""
0 commit comments