You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/src/rule_author/example.md
+8-8Lines changed: 8 additions & 8 deletions
Original file line number
Diff line number
Diff line change
@@ -58,36 +58,36 @@ Read more about writing rules for constructors and callable objects [here](@ref
58
58
The `rrule` returns the primal result `y`, and the pullback function.
59
59
It is a _very_ good idea to name your pullback function, so that they are helpful when appearing in the stacktrace.
60
60
```julia
61
-
y =foo_mul(foo, b)
61
+
y =foo_mul(foo, b)
62
62
```
63
63
Computes the primal result.
64
64
It is possible to change the primal computation so that work can be shared between the primal and the pullback.
65
65
See e.g. [the rule for `sort`](https://github.com/JuliaDiff/ChainRules.jl/blob/a75193768775975fac5578c89d1e5f50d7f358c2/src/rulesets/Base/sort.jl#L19-L35), where the sorting is done only once.
66
66
```julia
67
-
functionfoo_mul_pullback(ȳ)
68
-
...
69
-
return f̄, f̄oo, b̄
70
-
end
67
+
functionfoo_mul_pullback(ȳ)
68
+
...
69
+
return f̄, f̄oo, b̄
70
+
end
71
71
```
72
72
The pullback function takes in the tangent of the primal output (`ȳ`) and returns the tangents of the primal inputs.
73
73
Note that it returns a tangent for the primal function in addition to the tangents of primal arguments.
74
74
75
75
Finally, computing the tangents of primal inputs:
76
76
```julia
77
-
f̄ =NoTangent()
77
+
f̄ =NoTangent()
78
78
```
79
79
The function `foo_mul` has no fields (i.e. it is not a closure) and can not be perturbed.
80
80
Therefore its tangent (`f̄`) is a `NoTangent`.
81
81
```julia
82
-
f̄oo =Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
82
+
f̄oo =Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
83
83
```
84
84
The struct `foo::Foo` gets a `Tangent{Foo}` structural tangent, which stores the tangents of fields of `foo`.
85
85
86
86
The tangent of the field `A` is `ȳ * b'`,
87
87
88
88
The tangent of the field `c` is `ZeroTangent()`, because `c` can be perturbed but has no effect on the primal output.
89
89
```julia
90
-
b̄ =@thunk(foo.A'* ȳ)
90
+
b̄ =@thunk(foo.A'* ȳ)
91
91
```
92
92
The tangent of `b` is `foo.A' * ȳ`, but we have wrapped it into a `Thunk`, a tangent type that represents delayed computation.
93
93
The idea is that in case the tangent is not used anywhere, the computation never happens.
0 commit comments