@@ -8,8 +8,8 @@ If you want to learn about `frule`s, you should still read and understand this e
 
 We define a struct `Foo`
 ```julia
-struct Foo
-    A::Matrix
+struct Foo{T}
+    A::Matrix{T}
     c::Float64
 end
 ```
@@ -25,17 +25,27 @@ Note that field `c` is ignored in the calculation.
 
 The `rrule` method for our primal computation should extend the `ChainRulesCore.rrule` function.
 ```julia
-function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
+function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo{T}, b::AbstractArray) where T
     y = foo_mul(foo, b)
     function foo_mul_pullback(ȳ)
         f̄ = NoTangent()
-        f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
+        f̄oo = Tangent{Foo{T}}(; A=ȳ * b', c=ZeroTangent())
         b̄ = @thunk(foo.A' * ȳ)
         return f̄, f̄oo, b̄
     end
     return y, foo_mul_pullback
 end
 ```
+
+We can check this rule against a finite-differences approach using [`ChainRulesTestUtils`](https://github.com/JuliaDiff/ChainRulesTestUtils.jl):
+```julia
+julia> using ChainRulesTestUtils
+julia> test_rrule(foo_mul, Foo(rand(3, 3), 3.0), rand(3, 3))
+Test Summary:                                       | Pass  Total
+test_rrule: foo_mul on Foo{Float64},Matrix{Float64} |   10     10
+Test.DefaultTestSet("test_rrule: foo_mul on Foo{Float64},Matrix{Float64}", Any[], 10, false, false)
+```
+
 Now let's examine the rule in more detail:
 ```julia
 function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
@@ -84,5 +94,5 @@ The idea is that in case the tangent is not used anywhere, the computation never
 Use [`InplaceableThunk`](@ref) if you are interested in [accumulating gradients inplace](@ref grad_acc).
 Note that in practice one would also `@thunk` the `f̄oo.A` tangent, but it was omitted in this example for clarity.
 
-As a final note, Since `b` is an `AbstractArray`, its tangent `b̄` should be projected to the right subspace.
+As a final note, since `b` is an `AbstractArray`, its tangent `b̄` should be projected to the right subspace.
 See the [`ProjectTo` the primal subspace](@ref projectto) section for more information and an example that motivates the projection operation.
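
The last hunk's note says `b̄` should be projected onto the subspace of `b`, but the rule in this diff does not do that. The following is a minimal sketch of how the projection could be wired in, not necessarily how the linked section does it. The body of `foo_mul` is an assumption (the diff only says that field `c` is ignored), and the name `project_b` is illustrative.

```julia
using ChainRulesCore
using LinearAlgebra

struct Foo{T}
    A::Matrix{T}
    c::Float64
end

# Assumed primal: multiplies by `A` and ignores `c`, as the docs describe.
foo_mul(foo::Foo, b::AbstractArray) = foo.A * b

function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo{T}, b::AbstractArray) where T
    y = foo_mul(foo, b)
    # Build the projector once, from the primal input `b`.
    project_b = ProjectTo(b)
    function foo_mul_pullback(ȳ)
        f̄ = NoTangent()
        f̄oo = Tangent{Foo{T}}(; A=ȳ * b', c=ZeroTangent())
        # Project the dense product back onto the structure of `b`.
        b̄ = @thunk(project_b(foo.A' * ȳ))
        return f̄, f̄oo, b̄
    end
    return y, foo_mul_pullback
end

# Usage: with a structured `b`, the unthunked tangent keeps that structure.
foo = Foo([1.0 2.0; 3.0 4.0], 3.0)
b = Diagonal([1.0, 1.0])
y, pullback = rrule(foo_mul, foo, b)
_, f̄oo, b̄ = pullback(ones(2, 2))
b̄_value = unthunk(b̄)
```

Without the projection, `foo.A' * ȳ` is a dense `Matrix` even when `b` is a `Diagonal`. With it, `b̄_value` should come back as a `Diagonal`, assuming `ProjectTo` behaves as documented for that type.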