You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/src/index.md
+19-9Lines changed: 19 additions & 9 deletions
Original file line number
Diff line number
Diff line change
@@ -183,15 +183,31 @@ end
183
183
184
184
that we do not know an `rrule` for, and we want to check whether the gradients provided by the AD system are correct.
185
185
186
-
Firstly, we need to define an `rrule`-like function which wraps the gradients computed by AD.
186
+
To test gradients computed by the AD system you need to provide a `rrule_f` function that acts like calling `rrule` but use AD rather than a defined rule.
187
+
This has the exact same semantics as is required to overload `ChainRulesCore.rrule_via_ad`, thus almost all systems doing so should just overload that, and pass in that and the config, and then trigger `test_rrule(MyADConfig, f, xs; rrule_f = ChainRulesCore.rrule_via_ad)`.
188
+
See more info on `rrule_via_ad` and the rule configs in the [ChainRules documentation](https://juliadiff.org/ChainRulesCore.jl/stable/config.html).
189
+
For some AD systems (e.g. Zygote) `rrule_via_ad` already exists.
190
+
If it does not exist, see [How to write `rrule_via_ad` function](#How-to-write-rrule_via_ad-function) section below.
191
+
192
+
We use the `test_rrule` function to test the gradients using the config used by the AD system
by providing the rule config and specifying the `rrule_via_ad` as the `rrule_f` keyword argument.
198
+
199
+
200
+
### How to write `rrule_via_ad` function
201
+
202
+
`rrule_via_ad` will use the AD system to compute gradients and will package them in the `rrule`-like API.
187
203
188
204
Let's say the AD package uses some custom differential types and does not provide a gradient w.r.t. the function itself.
189
205
In order to make the pullback compatible with the `rrule` API we need to add a `NoTangent()` to represent the differential w.r.t. the function itself.
190
206
We also need to transform the `ChainRules` differential types to the custom types (`cr2custom`) before feeding the `Δ` to the AD-generated pullback, and back to `ChainRules` differential types when returning from the `rrule` (`custom2cr`).
0 commit comments