Skip to content

Commit 8c7f7ed

Browse files
authored
Add ignore_derivatives (Zygote's dropgrad and ignore) (#229)
1 parent 2ec2549 commit 8c7f7ed

File tree

9 files changed

+167
-8
lines changed

9 files changed

+167
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.6.0"
3+
version = "1.7.0"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

docs/Manifest.toml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
1414
path = ".."
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "1.0.0-DEV"
16+
version = "1.7.0"
1717

1818
[[Compat]]
1919
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
20-
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
20+
git-tree-sha1 = "1a90210acd935f222ea19657f143004d2c2a1117"
2121
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
22-
version = "3.31.0"
22+
version = "3.38.0"
2323

2424
[[Dates]]
2525
deps = ["Printf"]
@@ -73,9 +73,9 @@ version = "1.3.0"
7373

7474
[[JSON]]
7575
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
76-
git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4"
76+
git-tree-sha1 = "8076680b162ada2a031f707ac7b4953e30667a37"
7777
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
78-
version = "0.21.1"
78+
version = "0.21.2"
7979

8080
[[LibCURL]]
8181
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
@@ -122,9 +122,9 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
122122

123123
[[Parsers]]
124124
deps = ["Dates"]
125-
git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
125+
git-tree-sha1 = "9d8c00ef7a8d110787ff6f170579846f776133a9"
126126
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
127-
version = "1.1.0"
127+
version = "2.0.4"
128128

129129
[[Pkg]]
130130
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ makedocs(
5656
"Gradient Accumulation" => "gradient_accumulation.md",
5757
"Usage in AD" => "use_in_ad_system.md",
5858
"Converting ZygoteRules" => "converting_zygoterules.md",
59+
"Tips for making packages work with AD" => "tips_for_packages.md",
5960
"Design" => [
6061
"Changing the Primal" => "design/changing_the_primal.md",
6162
"Many Differential Types" => "design/many_differentials.md",

docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,12 @@ Private = false
4646
ProjectTo
4747
```
4848

49+
## Ignoring gradients
50+
```@docs
51+
ignore_derivatives
52+
@ignore_derivatives
53+
```
54+
4955
## Internal
5056
```@docs
5157
ChainRulesCore.AbstractTangent

docs/src/tips_for_packages.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Tips for making your package work with AD
2+
3+
## Ignoring gradients for certain expressions
4+
5+
Some code is not meant to be differentiated through, for example logging.
6+
In some cases, AD systems might work perfectly well with that code, but in others they might not.
7+
A convenience function `ignore_derivatives` is provided to get around this issue.
8+
It captures the functionality of both `Zygote.ignore` and `Zygote.dropgrad`.
9+
10+
For example, Zygote does not support mutation, so it will break if you try to store intermediate values as in the following example:
11+
```julia
12+
somes = []
13+
things = []
14+
15+
function loss(x, y)
16+
some = f(x, y)
17+
thing = g(x)
18+
19+
# log
20+
push!(somes, some)
21+
push!(things, thing)
22+
23+
return some + thing
24+
end
25+
```
26+
27+
It is possible to get around this by using the `ignore_derivatives` function
28+
```julia
29+
ignore_derivatives() do
30+
push!(somes, some)
31+
push!(things, thing)
32+
end
33+
```
34+
or using a macro for one-liners
35+
```julia
36+
@ignore_derivatives push!(things, thing)
37+
```
38+
39+
It is also possible to use this on individual objects, e.g.
40+
```julia
41+
ignore_derivatives(a) + b
42+
```
43+
will ignore the gradients for `a` only.
44+
45+
Passing an instance of a functor (callable struct) directly, as in `ignore_derivatives(functor)`, makes it behave like a normal struct: it is propagated without being called, and its gradients are dropped.
46+
If you want to call a functor in the primal computation, wrap it in a closure: `ignore_derivatives(() -> functor())`

src/ChainRulesCore.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ export frule_via_ad, rrule_via_ad
1313
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
1414
export ProjectTo, canonicalize, unthunk # differential operations
1515
export add!! # gradient accumulation operations
16+
export ignore_derivatives, @ignore_derivatives
1617
# differentials
1718
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
1819

@@ -32,6 +33,7 @@ include("projection.jl")
3233
include("config.jl")
3334
include("rules.jl")
3435
include("rule_definition_tools.jl")
36+
include("ignore_derivatives.jl")
3537

3638
include("deprecated.jl")
3739

src/ignore_derivatives.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
ignore_derivatives(f::Function)
3+
4+
Tells the AD system to ignore the gradients of the wrapped closure. The primal computation
5+
(forward pass) is executed normally.
6+
7+
```julia
8+
ignore_derivatives() do
9+
value = rand()
10+
push!(collection, value)
11+
end
12+
```
13+
14+
Using this incorrectly could lead to incorrect gradients.
15+
For example, the following function will have zero gradients with respect to its argument:
16+
```julia
17+
function wrong_grads(x)
18+
y = ones(3)
19+
ignore_derivatives() do
20+
push!(y, x)
21+
end
22+
return sum(y)
23+
end
24+
```
25+
"""
26+
ignore_derivatives(f::Function) = f()
27+
28+
"""
29+
ignore_derivatives(x)
30+
31+
Tells the AD system to ignore the gradients of the argument. Can be used to avoid
32+
unnecessary computation of gradients.
33+
34+
```julia
35+
ignore_derivatives(x) * w
36+
```
37+
"""
38+
ignore_derivatives(x) = x
39+
40+
@non_differentiable ignore_derivatives(f)
41+
42+
"""
43+
@ignore_derivatives (...)
44+
45+
Tells the AD system to ignore the expression. Equivalent to `ignore_derivatives() do (...) end`.
46+
"""
47+
macro ignore_derivatives(ex)
48+
return :(ChainRulesCore.ignore_derivatives() do
49+
$(esc(ex))
50+
end)
51+
end

test/ignore_derivatives.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
struct MyFunctor
2+
a::Float64
3+
end
4+
(mf::MyFunctor)(b) = mf.a + b
5+
6+
@testset "ignore_derivatives.jl" begin
7+
@testset "function" begin
8+
f() = return 4.0
9+
10+
y, ẏ = frule((1.0, ), ignore_derivatives, f)
11+
@test y == f()
12+
@test ẏ == NoTangent()
13+
14+
y, pb = rrule(ignore_derivatives, f)
15+
@test y == f()
16+
@test pb(1.0) == (NoTangent(), NoTangent())
17+
end
18+
19+
@testset "argument" begin
20+
arg = 2.1
21+
22+
y, ẏ = frule((1.0, ), ignore_derivatives, arg)
23+
@test y == arg
24+
@test ẏ == NoTangent()
25+
26+
y, pb = rrule(ignore_derivatives, arg)
27+
@test y == arg
28+
@test pb(1.0) == (NoTangent(), NoTangent())
29+
end
30+
31+
@testset "functor" begin
32+
mf = MyFunctor(1.0)
33+
34+
# as an argument
35+
y, ẏ = frule((1.0,), ignore_derivatives, mf)
36+
@test y == mf
37+
@test ẏ == NoTangent()
38+
39+
y, pb = rrule(ignore_derivatives, mf)
40+
@test y == mf
41+
@test pb(1.0) == (NoTangent(), NoTangent())
42+
43+
# when called
44+
y, ẏ = frule((1.0,), ignore_derivatives, ()->mf(3.0))
45+
@test y == mf(3.0)
46+
@test ẏ == NoTangent()
47+
48+
y, pb = rrule(ignore_derivatives, ()->mf(3.0))
49+
@test y == mf(3.0)
50+
@test pb(1.0) == (NoTangent(), NoTangent())
51+
end
52+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ using Test
2121
include("rules.jl")
2222
include("rule_definition_tools.jl")
2323
include("config.jl")
24+
include("ignore_derivatives.jl")
2425

2526
include("deprecated.jl")
2627
end

0 commit comments

Comments
 (0)