Skip to content

Commit fb34d39

Browse files
Fix and doctest example of using ChainRules directly (#120)
* Fix link to this package * Enable doctests * Require doctests to pass * Do not check for internal docstrings * Fix `frule` example * Doctest `rrule` example too * Make comments' naming convention consistent
1 parent c59ee78 commit fb34d39

File tree

3 files changed

+25
-21
lines changed

3 files changed

+25
-21
lines changed

docs/Manifest.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@ version = "0.3.2"
99

1010
[[ChainRulesCore]]
1111
deps = ["MuladdMacro"]
12-
git-tree-sha1 = "2d67fd76f99ffba4059e55be324b24bf38582a38"
13-
repo-rev = "ox/movedoctsD2"
14-
repo-url = ".."
12+
path = ".."
1513
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1614
version = "0.6.1"
1715

docs/make.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Documenter
66

77
makedocs(
88
modules=[ChainRules, ChainRulesCore],
9-
format=Documenter.HTML(prettyurls=false, assets = ["assets/chainrules.css"]),
9+
format=Documenter.HTML(prettyurls=false, assets=["assets/chainrules.css"]),
1010
sitename="ChainRules",
1111
authors="Jarrett Revels and other contributors",
1212
pages=[
@@ -15,12 +15,14 @@ makedocs(
1515
"Writing Good Rules" => "writing_good_rules.md",
1616
"API" => "api.md",
1717
],
18+
strict=true,
19+
checkdocs=:exports,
1820
)
1921

2022
const repo = "github.com/JuliaDiff/ChainRulesCore.jl.git"
2123
const PR = get(ENV, "TRAVIS_PULL_REQUEST", "false")
2224
if PR == "false"
23-
# Normal case, only deply docs if merging to master or release tagged
25+
# Normal case, only deploy docs if merging to master or release tagged
2426
deploydocs(repo=repo)
2527
else
2628
@info "Deploying review docs for PR #$PR"

docs/src/index.md

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ This was once how all neural network code worked.
288288

289289
Using ChainRules directly also helps get a feel for it.
290290

291-
```julia
291+
```jldoctest
292292
using ChainRules
293293
294294
function foo(x)
@@ -299,47 +299,51 @@ function foo(x)
299299
end
300300
301301
#### Find dfoo/dx via rrules
302-
303-
# First the forward pass, accumulating rules
302+
#### First the forward pass, accumulating rules
304303
x = 3;
305304
a, a_pullback = rrule(sin, x);
306305
b, b_pullback = rrule(*, 2, a);
307306
c, c_pullback = rrule(asin, b)
308307
309-
# Then the backward pass calculating gradients
308+
#### Then the backward pass calculating gradients
310309
c̄ = 1; # ∂c/∂c
311310
_, b̄ = c_pullback(extern(c̄)); # ∂c/∂b
312311
_, _, ā = b_pullback(extern(b̄)); # ∂c/∂a
313312
_, x̄ = a_pullback(extern(ā)); # ∂c/∂x = ∂f/∂x
314313
extern(x̄)
315-
# -2.0638950738662625
316-
314+
# output
315+
-2.0638950738662625
316+
```
317+
```jldoctest
317318
#### Find dfoo/dx via frules
318-
319319
x = 3;
320320
ẋ = 1; # ∂x/∂x
321321
nofields = Zero(); # ∂self/∂self
322322
323323
a, ȧ = frule(sin, x, nofields, ẋ); # ∂a/∂x
324-
b, ḃ = frule(*, 2, nofields, unthunk(ȧ)); # ∂b/∂x = ∂b/∂a⋅∂a/∂x
324+
b, ḃ = frule(*, 2, a, nofields, Zero(), unthunk(ȧ)); # ∂b/∂x = ∂b/∂a⋅∂a/∂x
325325
326-
c, ċ = frule(asin, b, unthunk(ḃ)); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x
326+
c, ċ = frule(asin, b, nofields, unthunk(ḃ)); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x
327327
unthunk(ċ)
328-
# -2.0638950738662625
329-
330-
#### Find dfoo/dx via finite-differences
331-
328+
# output
329+
-2.0638950738662625
330+
```
331+
```julia
332+
#### Find dfoo/dx via FiniteDifferences.jl
332333
using FiniteDifferences
333334
central_fdm(5, 1)(foo, x)
334-
# -2.0638950738670734
335+
# output
336+
-2.0638950738670734
335337

336338
#### Find dfoo/dx via ForwardDiff.jl
337339
using ForwardDiff
338340
ForwardDiff.derivative(foo, x)
339-
# -2.0638950738662625
341+
# output
342+
-2.0638950738662625
340343

341344
#### Find dfoo/dx via Zygote.jl
342345
using Zygote
343346
Zygote.gradient(foo, x)
344-
# (-2.0638950738662625,)
347+
# output
348+
(-2.0638950738662625,)
345349
```

0 commit comments

Comments
 (0)