80
80
# Keyword Arguments
81
81
- `output_tangent` tangent to test accumulation of derivatives against
82
82
should be a differential for the output of `f`. Is set automatically if not provided.
83
- - `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
84
- transform the passed argument tangents into alternative tangents that should be tested.
85
- Note that the alternative tangents are only tested for not erroring when passed to
86
- frule. Testing for correctness using finite differencing can be done using a
87
- separate `test_frule` call, e.g. for testing a `ZeroTangent()` for correctness:
88
- `test_frule(f, x ⊢ ZeroTangent(); tangent_transforms=[])`.
89
83
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
90
84
- `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to
91
85
`frule`). Used for testing gradients from AD systems.
@@ -104,7 +98,6 @@ function test_frule(
104
98
f,
105
99
args... ;
106
100
output_tangent= Auto (),
107
- tangent_transforms= TRANSFORMS_TO_ALT_TANGENTS,
108
101
fdm= _fdm,
109
102
frule_f= ChainRulesCore. frule,
110
103
check_inferred:: Bool = true ,
@@ -143,25 +136,9 @@ function test_frule(
143
136
144
137
acc = output_tangent isa Auto ? rand_tangent (Ω) : output_tangent
145
138
_test_add!!_behaviour (acc, dΩ_ad; isapprox_kwargs... )
146
-
147
- # test that rules work for other tangents
148
- _test_frule_alt_tangents (
149
- call_on_copy, frule_f, config, tangent_transforms, tangents, primals, acc;
150
- isapprox_kwargs...
151
- )
152
139
end # top-level testset
153
140
end
154
141
155
- function _test_frule_alt_tangents (
156
- call, frule_f, config, tangent_transforms, tangents, primals, acc;
157
- isapprox_kwargs...
158
- )
159
- @testset " ȧrgs = $(_string_typeof (tsf .(tangents))) " for tsf in tangent_transforms
160
- _, dΩ = call (frule_f, config, tsf .(tangents), primals... )
161
- _test_add!!_behaviour (acc, dΩ; isapprox_kwargs... )
162
- end
163
- end
164
-
165
142
"""
166
143
test_rrule([config::RuleConfig,] f, args...; kwargs...)
167
144
176
153
# Keyword Arguments
177
154
- `output_tangent` the seed to propagate backward for testing (technically a cotangent).
178
155
should be a differential for the output of `f`. Is set automatically if not provided.
179
- - `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
180
- transform the passed `output_tangent` into alternative tangents that should be tested.
181
- Note that the alternative tangents are only tested for not erroring when passed to
182
- rrule. Testing for correctness using finite differencing can be done using a
183
- separate `test_rrule` call, e.g. for testing a `ZeroTangent()` for correctness:
184
- `test_rrule(f, args...; output_tangent=ZeroTangent(), tangent_transforms=[])`.
156
+ - `check_thunked_output_tangent=true`: also checks that passing a thunked version of the
157
+ output tangent to the pullback returns the same result.
185
158
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
186
159
- `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`).
187
160
Used for testing gradients from AD systems.
@@ -200,7 +173,7 @@ function test_rrule(
200
173
f,
201
174
args... ;
202
175
output_tangent= Auto (),
203
- tangent_transforms = TRANSFORMS_TO_ALT_TANGENTS ,
176
+ check_thunked_output_tangent = true ,
204
177
fdm= _fdm,
205
178
rrule_f= ChainRulesCore. rrule,
206
179
check_inferred:: Bool = true ,
@@ -267,21 +240,10 @@ function test_rrule(
267
240
end
268
241
end
269
242
270
- # test other tangents don't error when passed to the pullback
271
- _test_rrule_alt_tangents (pullback, tangent_transforms, ȳ, accum_cotangents)
272
- end # top-level testset
273
- end
274
-
275
- function _test_rrule_alt_tangents (
276
- pullback, tangent_transforms, ȳ, accum_cotangents;
277
- isapprox_kwargs...
278
- )
279
- @testset " ȳ = $(_string_typeof (tsf (ȳ))) " for tsf in tangent_transforms
280
- ad_cotangents = pullback (tsf (ȳ))
281
- for (accum_cotangent, ad_cotangent) in zip (accum_cotangents, ad_cotangents)
282
- _test_add!!_behaviour (accum_cotangent, ad_cotangent; isapprox_kwargs... )
243
+ if check_thunked_output_tangent
244
+ test_approx (ad_cotangents, pullback (@thunk (ȳ)), " pulling back a thunk" )
283
245
end
284
- end
246
+ end # top-level testset
285
247
end
286
248
287
249
"""
0 commit comments