Commit fc7408c

Merge pull request #210 from JuliaDiff/mz/rrule_via_ad
Improve docs on testing AD gradients
2 parents 20db52a + 06db673 commit fc7408c

2 files changed: +34 -24 lines changed

docs/Manifest.toml

Lines changed: 15 additions & 15 deletions
@@ -16,21 +16,21 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

 [[ChainRulesCore]]
 deps = ["Compat", "LinearAlgebra", "SparseArrays"]
-git-tree-sha1 = "f53ca8d41e4753c41cdafa6ec5f7ce914b34be54"
+git-tree-sha1 = "bdc0937269321858ab2a4f288486cb258b9a0af7"
 uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-version = "0.10.13"
+version = "1.3.0"

 [[ChainRulesTestUtils]]
 deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
 path = ".."
 uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
-version = "1.0.0-DEV"
+version = "1.2.1"

 [[Compat]]
 deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
-git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
+git-tree-sha1 = "727e463cfebd0c7b999bbf3e9e7e16f254b94193"
 uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
-version = "3.31.0"
+version = "3.34.0"

 [[Dates]]
 deps = ["Printf"]
@@ -52,19 +52,19 @@ version = "0.8.5"

 [[Documenter]]
 deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
-git-tree-sha1 = "95265abf7d7bf06dfdb8d58525a23ea5fb0bdeee"
+git-tree-sha1 = "350dced36c11f794c6c4da5dc6493ec894e50c16"
 uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
-version = "0.27.4"
+version = "0.27.5"

 [[Downloads]]
 deps = ["ArgTools", "LibCURL", "NetworkOptions"]
 uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

 [[FiniteDifferences]]
 deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
-git-tree-sha1 = "18761c465ef2e87d9091c0fefb61f70d532d4cc0"
+git-tree-sha1 = "9a586f04a21e6945f4cbee0d0fb6aebd7b86aa8f"
 uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
-version = "0.12.16"
+version = "0.12.18"

 [[IOCapture]]
 deps = ["Logging", "Random"]
@@ -78,9 +78,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

 [[JSON]]
 deps = ["Dates", "Mmap", "Parsers", "Unicode"]
-git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4"
+git-tree-sha1 = "8076680b162ada2a031f707ac7b4953e30667a37"
 uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
-version = "0.21.1"
+version = "0.21.2"

 [[LibCURL]]
 deps = ["LibCURL_jll", "MozillaCACerts_jll"]
@@ -127,9 +127,9 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

 [[Parsers]]
 deps = ["Dates"]
-git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
+git-tree-sha1 = "438d35d2d95ae2c5e8780b330592b6de8494e779"
 uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
-version = "1.1.0"
+version = "2.0.3"

 [[Pkg]]
 deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
@@ -172,9 +172,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

 [[StaticArrays]]
 deps = ["LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "1b9a0f17ee0adde9e538227de093467348992397"
+git-tree-sha1 = "3240808c6d463ac46f1c1cd7638375cd22abbccb"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.2.7"
+version = "1.2.12"

 [[Statistics]]
 deps = ["LinearAlgebra", "SparseArrays"]

docs/src/index.md

Lines changed: 19 additions & 9 deletions
@@ -183,15 +183,31 @@ end

 that we do not know an `rrule` for, and we want to check whether the gradients provided by the AD system are correct.

-Firstly, we need to define an `rrule`-like function which wraps the gradients computed by AD.
+To test gradients computed by the AD system you need to provide a `rrule_f` function that acts like calling `rrule` but use AD rather than a defined rule.
+This has the exact same semantics as is required to overload `ChainRulesCore.rrule_via_ad`, thus almost all systems doing so should just overload that, and pass in that and the config, and then trigger `test_rrule(MyADConfig, f, xs; rrule_f = ChainRulesCore.rrule_via_ad)`.
+See more info on `rrule_via_ad` and the rule configs in the [ChainRules documentation](https://juliadiff.org/ChainRulesCore.jl/stable/config.html).
+For some AD systems (e.g. Zygote) `rrule_via_ad` already exists.
+If it does not exist, see [How to write `rrule_via_ad` function](#How-to-write-rrule_via_ad-function) section below.
+
+We use the `test_rrule` function to test the gradients using the config used by the AD system
+```julia
+config = MyAD.CustomRuleConfig()
+test_rrule(config, complicated, 2.3, 6.1; rrule_f=rrule_via_ad)
+```
+by providing the rule config and specifying the `rrule_via_ad` as the `rrule_f` keyword argument.
+
+
+### How to write `rrule_via_ad` function
+
+`rrule_via_ad` will use the AD system to compute gradients and will package them in the `rrule`-like API.

 Let's say the AD package uses some custom differential types and does not provide a gradient w.r.t. the function itself.
 In order to make the pullback compatible with the `rrule` API we need to add a `NoTangent()` to represent the differential w.r.t. the function itself.
 We also need to transform the `ChainRules` differential types to the custom types (`cr2custom`) before feeding the `Δ` to the AD-generated pullback, and back to `ChainRules` differential types when returning from the `rrule` (`custom2cr`).

 ```julia
-function ad_rrule(f::Function, args...)
-    y, ad_pullback = ADSystem.pullback(f, args...)
+function rrule_via_ad(config::MyAD.CustomRuleConfig, f::Function, args...)
+    y, ad_pullback = MyAD.pullback(f, args...)
     function rrulelike_pullback(Δ)
         diffs = custom2cr(ad_pullback(cr2custom(Δ)))
         return NoTangent(), diffs...
@@ -203,12 +219,6 @@ end
 custom2cr(differential) = ...
 cr2custom(differential) = ...
 ```
-Secondly, we use the `test_rrule` function to test the gradients using the config used by the AD system
-```julia
-config = MyAD.CustomRuleConfig()
-test_rrule(config, complicated, 2.3, 6.1; rrule_f=ad_rrule)
-```
-by specifying the `ad_rrule` as the `rrule_f` keyword argument.

 ## Custom finite differencing

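Note on the `rrule_via_ad` snippet in docs/src/index.md: it leaves `MyAD`, `custom2cr` and `cr2custom` elided, so it cannot be run as written. The following is a minimal sketch of the same wrapping pattern, not the commit's code. Everything named `Toy*` and `toy_pullback` is hypothetical and stands in for a real AD system; the "AD" here is just central finite differences on scalar arguments. The config supertype `RuleConfig{Union{HasReverseMode,NoForwardsMode}}` is an assumption about how such a config would typically be declared.

```julia
using ChainRulesCore

# Hypothetical rule config standing in for `MyAD.CustomRuleConfig`.
struct ToyConfig <: ChainRulesCore.RuleConfig{
    Union{ChainRulesCore.HasReverseMode,ChainRulesCore.NoForwardsMode}
} end

# Hypothetical custom differential type used by the toy AD system.
struct ToyDiff
    val::Float64
end

# Conversions between ChainRules-style differentials and the toy type.
cr2custom(Δ::Real) = ToyDiff(Δ)
custom2cr(diffs::Tuple) = map(d -> d.val, diffs)

# Toy stand-in for `MyAD.pullback`: central finite differences for a
# scalar-valued function of scalar arguments. Like the AD system described
# in the docs, it returns no gradient w.r.t. the function itself.
function toy_pullback(f, args::Real...)
    y = f(args...)
    h = 1e-6
    n = length(args)
    function back(Δ::ToyDiff)
        return ntuple(n) do i
            up = ntuple(j -> j == i ? args[j] + h : args[j], n)
            dn = ntuple(j -> j == i ? args[j] - h : args[j], n)
            ToyDiff(Δ.val * (f(up...) - f(dn...)) / (2h))
        end
    end
    return y, back
end

# The method is qualified with `ChainRulesCore.` so that it extends the
# existing generic function instead of defining a new, unrelated one.
function ChainRulesCore.rrule_via_ad(::ToyConfig, f::Function, args...)
    y, ad_pullback = toy_pullback(f, args...)
    function rrulelike_pullback(Δ)
        diffs = custom2cr(ad_pullback(cr2custom(Δ)))
        # Prepend the differential w.r.t. the function itself.
        return NoTangent(), diffs...
    end
    return y, rrulelike_pullback
end

# Example function, assumed here only for illustration.
complicated(x, y) = x * y + sin(x)

y, pb = rrule_via_ad(ToyConfig(), complicated, 2.3, 6.1)
∂f, ∂x, ∂y = pb(1.0)
```

With a wrapper of this shape, the call from the docs would presumably become `test_rrule(ToyConfig(), complicated, 2.3, 6.1; rrule_f=rrule_via_ad)`. Whether it passes with the default settings (for example the type-inference check) has not been verified for this toy.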