You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/src/rule_author/which_functions_need_rules.md
+19-9Lines changed: 19 additions & 9 deletions
Original file line number
Diff line number
Diff line change
@@ -27,32 +27,40 @@ Other patterns can be AD'ed through, but the backward pass performance can be gr
27
27
#### Functions which mutate arrays
28
28
For example,
29
29
```julia
30
-
functionaddone!(array)
31
-
array .+=1
32
-
returnsum(array)
30
+
functionaddone(a::AbstractArray)
31
+
b =similar(a)
32
+
b .= a .+1
33
+
returnsum(b)
33
34
end
34
35
```
35
36
complains that
36
37
```julia
37
38
julia>using Zygote
38
-
julia>gradient(addone!, a)
39
+
julia>gradient(addone, a)
39
40
ERROR: Mutating arrays is not supported
40
41
```
41
42
However, upon adding the `rrule` (restart the REPL after calling `gradient`)
42
43
```julia
43
-
function ChainRules.rrule(::typeof(addone!), a)
44
-
y =addone!(a)
45
-
functionaddone!_pullback(ȳ)
44
+
function ChainRules.rrule(::typeof(addone), a)
45
+
y =addone(a)
46
+
functionaddone_pullback(ȳ)
46
47
returnNoTangent(), ones(length(a))
47
48
end
48
-
return y, addone!_pullback
49
+
return y, addone_pullback
49
50
end
50
51
```
51
52
the gradient can be evaluated:
52
53
```julia
53
-
julia>gradient(addone!, a)
54
+
julia>gradient(addone, a)
54
55
([1.0, 1.0, 1.0],)
55
56
```
57
+
Notice that `addone(a)` mutates another array `b` internally, but **not** its input.
58
+
This is commonly done in less trivial functions, and is often what Zygote's `Mutating arrays is not supported` error is telling you,
59
+
even though you did not intend to mutate anything.
60
+
Functions which mutate their own input are much more problematic.
61
+
These are the ones named (by convention) with an exclamation mark, such as `fill!(a, x)` or `push!(a, x)`.
62
+
It is not possible to write rules which handle all uses of such a function correctly, on current Zygote.
63
+
56
64
57
65
!!! note "Why restarting REPL after calling `gradient`?"
58
66
When `gradient` is called in `Zygote` for a function with no `rrule` defined, a backward pass for the function call is generated and cached.
@@ -61,6 +69,8 @@ julia> gradient(addone!, a)
61
69
If an `rrule` is defined before the first call to `gradient` it should register the rule and use it, but that prevents comparing what happens before and after the `rrule` is defined.
62
70
To compare both versions with and without an `rrule` in the REPL simultaneously, define a function `f(x) = <body>` (no `rrule`), another function `f_cr(x) = f(x)`, and an `rrule` for `f_cr`.
63
71
72
+
Calling `Zygote.refresh()` will often have the same effect as restarting the REPL.
73
+
64
74
#### Exception handling
65
75
66
76
Zygote does not support differentiating through `try`/`catch` statements.
0 commit comments