Skip to content

Commit 051c10e

Browse files
authored
fix example (#594)
1 parent 9c8fcd2 commit 051c10e

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

docs/src/rule_author/example.md

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ If you want to learn about `frule`s, you should still read and understand this e
88

99
We define a struct `Foo`
1010
```julia
11-
struct Foo
12-
A::Matrix
11+
struct Foo{T}
12+
A::Matrix{T}
1313
c::Float64
1414
end
1515
```
@@ -25,17 +25,27 @@ Note that field `c` is ignored in the calculation.
2525

2626
The `rrule` method for our primal computation should extend the `ChainRulesCore.rrule` function.
2727
```julia
28-
function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
28+
function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo{T}, b::AbstractArray) where T
2929
y = foo_mul(foo, b)
3030
function foo_mul_pullback(ȳ)
3131
= NoTangent()
32-
f̄oo = Tangent{Foo}(; A=* b', c=ZeroTangent())
32+
f̄oo = Tangent{Foo{T}}(; A=* b', c=ZeroTangent())
3333
= @thunk(foo.A' * ȳ)
3434
return f̄, f̄oo, b̄
3535
end
3636
return y, foo_mul_pullback
3737
end
3838
```
39+
40+
We can check this rule against a finite-differences approach using [`ChainRulesTestUtils`](https://github.com/JuliaDiff/ChainRulesTestUtils.jl):
41+
```julia
42+
julia> using ChainRulesTestUtils
43+
julia> test_rrule(foo_mul, Foo(rand(3, 3), 3.0), rand(3, 3))
44+
Test Summary: | Pass Total
45+
test_rrule: foo_mul on Foo{Float64},Matrix{Float64} | 10 10
46+
Test.DefaultTestSet("test_rrule: foo_mul on Foo{Float64},Matrix{Float64}", Any[], 10, false, false)
47+
```
48+
3949
Now let's examine the rule in more detail:
4050
```julia
4151
function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
@@ -84,5 +94,5 @@ The idea is that in case the tangent is not used anywhere, the computation never
8494
Use [`InplaceableThunk`](@ref) if you are interested in [accumulating gradients inplace](@ref grad_acc).
8595
Note that in practice one would also `@thunk` the `f̄oo.A` tangent, but it was omitted in this example for clarity.
8696

87-
As a final note, Since `b` is an `AbstractArray`, its tangent `` should be projected to the right subspace.
97+
As a final note, since `b` is an `AbstractArray`, its tangent `` should be projected to the right subspace.
8898
See the [`ProjectTo` the primal subspace](@ref projectto) section for more information and an example that motivates the projection operation.

0 commit comments

Comments
 (0)