Skip to content

Commit a2b99e6

Browse files
Fix doctests; Organise and improve some docs (#131)
* Remove docs dependency on ChainRules * Organise API docs page * Fix doctests * Improve docs on unthunking * Build docs on latest stable Julia * Add comment re rules in doctest setup Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au> Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent 03474ab commit a2b99e6

File tree

9 files changed

+106
-84
lines changed

9 files changed

+106
-84
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
- julia: 1.4
1616
include:
1717
- stage: "Documentation"
18-
julia: 1.0
18+
julia: 1
1919
os: linux
2020
script:
2121
- julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'

docs/Manifest.toml

Lines changed: 9 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,20 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
13
[[Base64]]
24
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
35

4-
[[ChainRules]]
5-
deps = ["ChainRulesCore", "LinearAlgebra", "Reexport", "Requires", "Statistics"]
6-
git-tree-sha1 = "906cb2ae273ddbc559490117faa7abd36c98f51a"
7-
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
8-
version = "0.3.2"
9-
106
[[ChainRulesCore]]
117
deps = ["MuladdMacro"]
128
path = ".."
139
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
14-
version = "0.6.1"
10+
version = "0.7.0"
1511

1612
[[Dates]]
1713
deps = ["Printf"]
1814
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
1915

2016
[[Distributed]]
21-
deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"]
17+
deps = ["Random", "Serialization", "Sockets"]
2218
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
2319

2420
[[DocStringExtensions]]
@@ -34,7 +30,7 @@ uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3430
version = "0.23.4"
3531

3632
[[InteractiveUtils]]
37-
deps = ["LinearAlgebra", "Markdown"]
33+
deps = ["Markdown"]
3834
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
3935

4036
[[JSON]]
@@ -49,10 +45,6 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
4945
[[Libdl]]
5046
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
5147

52-
[[LinearAlgebra]]
53-
deps = ["Libdl"]
54-
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
55-
5648
[[Logging]]
5749
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
5850

@@ -70,12 +62,12 @@ version = "0.2.2"
7062

7163
[[Parsers]]
7264
deps = ["Dates", "Test"]
73-
git-tree-sha1 = "0139ba59ce9bc680e2925aec5b7db79065d60556"
65+
git-tree-sha1 = "d112c19ccca00924d5d3a38b11ae2b4b268dda39"
7466
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
75-
version = "0.3.10"
67+
version = "0.3.11"
7668

7769
[[Pkg]]
78-
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
70+
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"]
7971
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8072

8173
[[Printf]]
@@ -90,18 +82,6 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
9082
deps = ["Serialization"]
9183
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9284

93-
[[Reexport]]
94-
deps = ["Pkg"]
95-
git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
96-
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
97-
version = "0.2.0"
98-
99-
[[Requires]]
100-
deps = ["UUIDs"]
101-
git-tree-sha1 = "999513b7dea8ac17359ed50ae8ea089e4464e35e"
102-
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
103-
version = "1.0.0"
104-
10585
[[SHA]]
10686
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
10787

@@ -111,20 +91,12 @@ uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
11191
[[Sockets]]
11292
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
11393

114-
[[SparseArrays]]
115-
deps = ["LinearAlgebra", "Random"]
116-
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
117-
118-
[[Statistics]]
119-
deps = ["LinearAlgebra", "SparseArrays"]
120-
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
121-
12294
[[Test]]
12395
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
12496
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
12597

12698
[[UUIDs]]
127-
deps = ["Random"]
99+
deps = ["Random", "SHA"]
128100
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
129101

130102
[[Unicode]]

docs/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
[deps]
2-
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
32
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
43
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
54

docs/make.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,26 @@
1-
using ChainRules
21
using ChainRulesCore
32
using Documenter
43

54
@show ENV
65

6+
DocMeta.setdocmeta!(
7+
ChainRulesCore,
8+
:DocTestSetup,
9+
quote
10+
using Random
11+
Random.seed!(0) # frule doctest shows output
12+
13+
using ChainRulesCore
14+
# These rules are all actually defined in ChainRules.jl, but we redefine them here to
15+
# avoid the dependency.
16+
@scalar_rule(sin(x), cos(x)) # frule and rrule doctest
17+
@scalar_rule(sincos(x), @setup((sinx, cosx) = Ω), cosx, -sinx) # frule doctest
18+
@scalar_rule(hypot(x::Real, y::Real), (x / Ω, y / Ω)) # rrule doctest
19+
end
20+
)
21+
722
makedocs(
8-
modules=[ChainRules, ChainRulesCore],
23+
modules=[ChainRulesCore],
924
format=Documenter.HTML(prettyurls=false, assets=["assets/chainrules.css"]),
1025
sitename="ChainRules",
1126
authors="Jarrett Revels and other contributors",

docs/src/api.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,33 @@
11
# API Documentation
22

3+
## Rules
34
```@autodocs
45
Modules = [ChainRulesCore]
6+
Pages = ["rules.jl"]
7+
Private = false
8+
```
9+
10+
## Rule Definition Tools
11+
```@autodocs
12+
Modules = [ChainRulesCore]
13+
Pages = ["rule_definition_tools.jl"]
14+
Private = false
15+
```
16+
17+
## Differentials
18+
```@autodocs
19+
Modules = [ChainRulesCore]
20+
Pages = [
21+
"differentials/abstract_zero.jl",
22+
"differentials/one.jl",
23+
"differentials/composite.jl",
24+
"differentials/thunks.jl",
25+
"differentials/abstract_differential.jl",
26+
]
27+
Private = false
28+
```
29+
30+
## Internal
31+
```@docs
32+
ChainRulesCore.AbstractDifferential
533
```

docs/src/index.md

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
```@meta
2-
DocTestSetup = :(using ChainRulesCore, ChainRules)
3-
```
4-
51
# ChainRules
62

73
[ChainRules](https://github.com/JuliaDiff/ChainRules.jl) provides a variety of common utilities that can be used by downstream [automatic differentiation (AD)](https://en.wikipedia.org/wiki/Automatic_differentiation) tools to define and execute forward-, reverse-, and mixed-mode primitives.
@@ -288,62 +284,70 @@ This was once how all neural network code worked.
288284

289285
Using ChainRules directly also helps get a feel for it.
290286

291-
```jldoctest
292-
using ChainRules
287+
```jldoctest index; output=false
288+
using ChainRulesCore
293289
294290
function foo(x)
295291
a = sin(x)
296-
b = 2a
292+
b = 0.2 + a
297293
c = asin(b)
298294
return c
299295
end
300296
297+
# Define rules (alternatively get them for free via `using ChainRules`)
298+
@scalar_rule(sin(x), cos(x))
299+
@scalar_rule(+(x, y), (One(), One()))
300+
@scalar_rule(asin(x), inv(sqrt(1 - x^2)))
301+
# output
302+
303+
```
304+
```jldoctest index
301305
#### Find dfoo/dx via rrules
302306
#### First the forward pass, accumulating rules
303307
x = 3;
304308
a, a_pullback = rrule(sin, x);
305-
b, b_pullback = rrule(*, 2, a);
309+
b, b_pullback = rrule(+, 0.2, a);
306310
c, c_pullback = rrule(asin, b)
307311
308312
#### Then the backward pass calculating gradients
309313
c̄ = 1; # ∂c/∂c
310-
_, b̄ = c_pullback(extern(c̄)); # ∂c/∂b
311-
_, _, ā = b_pullback(extern(b̄)); # ∂c/∂a
312-
_, x̄ = a_pullback(extern(ā)); # ∂c/∂x = ∂f/∂x
313-
extern(x̄)
314+
_, b̄ = c_pullback(unthunk(c̄)); # ∂c/∂b
315+
_, _, ā = b_pullback(unthunk(b̄)); # ∂c/∂a
316+
_, x̄ = a_pullback(unthunk(ā)); # ∂c/∂x = ∂f/∂x
317+
unthunk(x̄)
314318
# output
315-
-2.0638950738662625
319+
-1.0531613736418153
316320
```
317-
```jldoctest
321+
```jldoctest index
318322
#### Find dfoo/dx via frules
319323
x = 3;
320324
ẋ = 1; # ∂x/∂x
321325
nofields = Zero(); # ∂self/∂self
322326
323327
a, ȧ = frule((nofields, ẋ), sin, x); # ∂a/∂x
324-
b, ḃ = frule((nofields, Zero(), unthunk(ȧ)), *, 2, a); # ∂b/∂x = ∂b/∂a⋅∂a/∂x
328+
b, ḃ = frule((nofields, Zero(), unthunk(ȧ)), +, 0.2, a); # ∂b/∂x = ∂b/∂a⋅∂a/∂x
325329
326330
c, ċ = frule((nofields, unthunk(ḃ)), asin, b); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x
327331
unthunk(ċ)
328332
# output
329-
-2.0638950738662625
333+
-1.0531613736418153
330334
```
331335
```julia
332336
#### Find dfoo/dx via FiniteDifferences.jl
333337
using FiniteDifferences
334338
central_fdm(5, 1)(foo, x)
335339
# output
336-
-2.0638950738670734
340+
-1.0531613736418257
337341

338342
#### Find dfoo/dx via ForwardDiff.jl
339343
using ForwardDiff
340344
ForwardDiff.derivative(foo, x)
341345
# output
342-
-2.0638950738662625
346+
-1.0531613736418153
343347

344348
#### Find dfoo/dx via Zygote.jl
345349
using Zygote
346350
Zygote.gradient(foo, x)
347351
# output
348-
(-2.0638950738662625,)
352+
(-1.0531613736418153,)
349353
```

src/differentials/abstract_differential.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,17 @@ For two reasons:
5050
Where it is defined the operation of `extern` for a primal type `P` should be
5151
`extern(x) = zero(P) + x`.
5252
53-
Because of its limitations, `extern` should only really be used for testing.
54-
It can be useful, if you know what you are getting out, as it recursively removes thunks,
55-
and otherwise makes outputs more consistent with finite differencing.
56-
The more useful action in general is to call `+`, or in the case of thunks: [`unthunk`](@ref).
53+
!!! note
54+
Because of its limitations, `extern` should only really be used for testing.
55+
It can be useful, if you know what you are getting out, as it recursively removes
56+
thunks, and otherwise makes outputs more consistent with finite differencing.
5757
58-
Note that `extern` may return an alias (not necessarily a copy) to data
59-
wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.
58+
The more useful action in general is to call `+`, or in the case of a [`Thunk`](@ref)
59+
to call [`unthunk`](@ref).
60+
61+
!!! warning
62+
`extern` may return an alias (not necessarily a copy) to data
63+
wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.
6064
"""
6165
@inline extern(x) = x
6266

src/differentials/thunks.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,20 @@ It wraps a zero argument closure that when invoked returns a differential.
2424
`@thunk(v)` is a macro that expands into `Thunk(()->v)`.
2525
2626
Calling a thunk, calls the wrapped closure.
27-
`extern`ing thunks applies recursively, it also externs the differial that the closure returns.
28-
If you do not want that, then simply call the thunk
27+
If you are unsure if you have a `Thunk`, call [`unthunk`](@ref) which is a no-op when the
28+
argument is not a `Thunk`.
29+
If you need to unthunk recursively, call [`extern`](@ref), which also externs the differial
30+
that the closure returns.
2931
30-
```
32+
```jldoctest
3133
julia> t = @thunk(@thunk(3))
32-
Thunk(var"##7#9"())
34+
Thunk(var"#4#6"())
3335
3436
julia> extern(t)
3537
3
3638
3739
julia> t()
38-
Thunk(var"##8#10"())
40+
Thunk(var"#5#7"())
3941
4042
julia> t()()
4143
3
@@ -83,7 +85,7 @@ end
8385
On `AbstractThunk`s this removes 1 layer of thunking.
8486
On any other type, it is the identity operation.
8587
86-
In contrast to `extern` this is nonrecursive.
88+
In contrast to [`extern`](@ref) this is nonrecursive.
8789
"""
8890
@inline unthunk(x) = x
8991

src/rules.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ Examples:
1717
1818
unary input, unary output scalar function:
1919
20-
```jldoctest
21-
julia> dself = Zero()
22-
Zero()
20+
```jldoctest frule
21+
julia> dself = NO_FIELDS;
2322
24-
julia> x = rand();
23+
julia> x = rand()
24+
0.8236475079774124
2525
26-
julia> sinx, Δsinx = frule(sin, x, dself, 1)
27-
(0.35696518021277485, 0.9341176907197836)
26+
julia> sinx, Δsinx = frule((dself, 1), sin, x)
27+
(0.7336293678134624, 0.6795498147167869)
2828
2929
julia> sinx == sin(x)
3030
true
@@ -35,10 +35,8 @@ true
3535
3636
unary input, binary output scalar function:
3737
38-
```jldoctest
39-
julia> x = rand();
40-
41-
julia> sincosx, Δsincosx = frule(sincos, x, dself, 1);
38+
```jldoctest frule
39+
julia> sincosx, Δsincosx = frule((dself, 1), sincos, x);
4240
4341
julia> sincosx == sincos(x)
4442
true
@@ -69,7 +67,7 @@ Examples:
6967
7068
unary input, unary output scalar function:
7169
72-
```
70+
```jldoctest
7371
julia> x = rand();
7472
7573
julia> sinx, sin_pullback = rrule(sin, x);
@@ -83,7 +81,7 @@ true
8381
8482
binary input, unary output scalar function:
8583
86-
```
84+
```jldoctest
8785
julia> x, y = rand(2);
8886
8987
julia> hypotxy, hypot_pullback = rrule(hypot, x, y);

0 commit comments

Comments
 (0)