80
80
# Keyword Arguments
81
81
- `output_tangent` tangent to test accumulation of derivatives against
82
82
should be a differential for the output of `f`. Is set automatically if not provided.
83
+ - `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
84
+ transform the passed argument tangents into alternative tangents that should be tested.
85
+ Note that the alternative tangents are only tested for not erroring when passed to
86
+ frule. Testing for correctness using finite differencing can be done using a
87
+ separate `test_frule` call, e.g. for testing a `ZeroTangent()` for correctness:
88
+ `test_frule(f, x ⊢ ZeroTangent(); tangent_transforms=[])`.
83
89
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
84
90
- `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to
85
91
`frule`). Used for testing gradients from AD systems.
@@ -98,6 +104,7 @@ function test_frule(
98
104
f,
99
105
args... ;
100
106
output_tangent= Auto (),
107
+ tangent_transforms= TRANSFORMS_TO_ALT_TANGENTS,
101
108
fdm= _fdm,
102
109
frule_f= ChainRulesCore. frule,
103
110
check_inferred:: Bool = true ,
@@ -122,7 +129,7 @@ function test_frule(
122
129
_test_inferred (frule_f, deepcopy (config), deepcopy (tangents), deepcopy (primals)... ; deepcopy (fkwargs)... )
123
130
end
124
131
125
- res = frule_f ( deepcopy ( config), deepcopy ( tangents), deepcopy ( primals) ... ; deepcopy (fkwargs) ... )
132
+ res = call_on_copy (frule_f, config, tangents, primals... )
126
133
res === nothing && throw (MethodError (frule_f, typeof (primals)))
127
134
@test_msg " The frule should return (y, ∂y), not $res ." res isa Tuple{Any,Any}
128
135
Ω_ad, dΩ_ad = res
@@ -144,10 +151,26 @@ function test_frule(
144
151
test_approx (dΩ_ad, dΩ_fd; isapprox_kwargs... )
145
152
146
153
acc = output_tangent isa Auto ? rand_tangent (Ω) : output_tangent
147
- _test_add!!_behaviour (acc, dΩ_ad; rtol= rtol, atol= atol, kwargs... )
154
+ _test_add!!_behaviour (acc, dΩ_ad; isapprox_kwargs... )
155
+
156
+ # test that rules work for other tangents
157
+ _test_frule_alt_tangents (
158
+ call_on_copy, frule_f, config, tangent_transforms, tangents, primals, acc;
159
+ isapprox_kwargs...
160
+ )
148
161
end # top-level testset
149
162
end
150
163
164
+ function _test_frule_alt_tangents (
165
+ call, frule_f, config, tangent_transforms, tangents, primals, acc;
166
+ isapprox_kwargs...
167
+ )
168
+ @testset " ȧrgs = $(tsf .(tangents)) " for tsf in tangent_transforms
169
+ _, dΩ = call (frule_f, config, tsf .(tangents), primals... )
170
+ _test_add!!_behaviour (acc, dΩ; isapprox_kwargs... )
171
+ end
172
+ end
173
+
151
174
"""
152
175
test_rrule([config::RuleConfig,] f, args...; kwargs...)
153
176
162
185
# Keyword Arguments
163
186
- `output_tangent` the seed to propagate backward for testing (technically a cotangent).
164
187
should be a differential for the output of `f`. Is set automatically if not provided.
188
+ - `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
189
+ transform the passed `output_tangent` into alternative tangents that should be tested.
190
+ Note that the alternative tangents are only tested for not erroring when passed to
191
+ rrule. Testing for correctness using finite differencing can be done using a
192
+ separate `test_rrule` call, e.g. for testing a `ZeroTangent()` for correctness:
193
+ `test_rrule(f, args...; output_tangent=ZeroTangent(), tangent_transforms=[])`.
165
194
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
166
195
- `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`).
167
196
Used for testing gradients from AD systems.
@@ -180,6 +209,7 @@ function test_rrule(
180
209
f,
181
210
args... ;
182
211
output_tangent= Auto (),
212
+ tangent_transforms= TRANSFORMS_TO_ALT_TANGENTS,
183
213
fdm= _fdm,
184
214
rrule_f= ChainRulesCore. rrule,
185
215
check_inferred:: Bool = true ,
@@ -249,9 +279,24 @@ function test_rrule(
249
279
_test_add!!_behaviour (accum_cotangent, ad_cotangent; isapprox_kwargs... )
250
280
end
251
281
end
282
+
283
+ # test other tangents don't error when passed to the pullback
284
+ _test_rrule_alt_tangents (pullback, tangent_transforms, ȳ, accum_cotangents)
252
285
end # top-level testset
253
286
end
254
287
288
+ function _test_rrule_alt_tangents (
289
+ pullback, tangent_transforms, ȳ, accum_cotangents;
290
+ isapprox_kwargs...
291
+ )
292
+ @testset " ȳ = $(tsf (ȳ)) " for tsf in tangent_transforms
293
+ ad_cotangents = pullback (tsf (ȳ))
294
+ for (accum_cotangent, ad_cotangent) in zip (accum_cotangents, ad_cotangents)
295
+ _test_add!!_behaviour (accum_cotangent, ad_cotangent; isapprox_kwargs... )
296
+ end
297
+ end
298
+ end
299
+
255
300
"""
256
301
@maybe_inferred [Type] f(...)
257
302
0 commit comments