File tree 2 files changed +24
-4
lines changed
2 files changed +24
-4
lines changed Original file line number Diff line number Diff line change @@ -117,7 +117,7 @@ test_scalar: relu at -0.5 | 11 11
117
117
118
118
## Testing constructors and functors (callable objects)
119
119
120
- Testing constructor and functors works as you would expect. For struct ` Foo `
120
+ Testing constructor and functors works as you would expect. For struct ` Foo ` ,
121
121
``` julia
122
122
struct Foo
123
123
a:: Float64
@@ -127,7 +127,27 @@ Base.length(::Foo) = 1
127
127
Base. iterate (f:: Foo ) = iterate (f. a)
128
128
Base. iterate (f:: Foo , state) = iterate (f. a, state)
129
129
```
130
- the ` f/rrule ` s can be tested by
130
+
131
+ after defining the constructor and functor ` f/rule ` s,
132
+
133
+ ``` julia
134
+ function ChainRulesCore. rrule (:: Type{Foo} , val) # constructor rrule
135
+ y = Foo (val)
136
+ Foo_pb (ΔFoo) = (NoTangent (), unthunk (ΔFoo). a)
137
+ return y, Foo_pb
138
+ end
139
+
140
+ function ChainRulesCore. rrule (foo:: Foo , val) # functor rrule
141
+ y = foo (val)
142
+ function foo_pb (Δ)
143
+ Δut = unthunk (Δ)
144
+ return (Tangent {Foo} (;a= Δut), Δut)
145
+ end
146
+ return y, foo_pb
147
+ end
148
+ ```
149
+
150
+ both ` f/rrule ` s can be tested by
131
151
``` julia
132
152
test_rrule (Foo, rand ()) # constructor
133
153
Original file line number Diff line number Diff line change @@ -125,7 +125,7 @@ function test_frule(
125
125
end
126
126
127
127
res = call_on_copy (frule_f, config, tangents, primals... )
128
- res === nothing && throw (MethodError (frule_f, typeof (primals)))
128
+ res === nothing && throw (MethodError (frule_f, Tuple{Core . Typeof . (primals)... } ))
129
129
@test_msg " The frule should return (y, ∂y), not $res ." res isa Tuple{Any,Any}
130
130
Ω_ad, dΩ_ad = res
131
131
Ω = call_on_copy (primals... )
@@ -201,7 +201,7 @@ function test_rrule(
201
201
_test_inferred (rrule_f, config, primals... ; fkwargs... )
202
202
end
203
203
res = rrule_f (config, primals... ; fkwargs... )
204
- res === nothing && throw (MethodError (rrule_f, typeof (primals)))
204
+ res === nothing && throw (MethodError (rrule_f, Tuple{Core . Typeof . (primals)... } ))
205
205
y_ad, pullback = res
206
206
y = call (primals... )
207
207
test_approx (y_ad, y; isapprox_kwargs... ) # make sure primal is correct
You can’t perform that action at this time.
0 commit comments