Skip to content

Commit b105a9f

Browse files
authored
Improve the error message for constructors (#254)
* show Type{Foo} rather than DataType in the rrule MethodError * add rule to the docs
1 parent 22d8446 commit b105a9f

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

docs/src/index.md

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ test_scalar: relu at -0.5 | 11 11
117117

118118
## Testing constructors and functors (callable objects)
119119

120-
Testing constructor and functors works as you would expect. For struct `Foo`
120+
Testing constructor and functors works as you would expect. For struct `Foo`,
121121
```julia
122122
struct Foo
123123
a::Float64
@@ -127,7 +127,27 @@ Base.length(::Foo) = 1
127127
Base.iterate(f::Foo) = iterate(f.a)
128128
Base.iterate(f::Foo, state) = iterate(f.a, state)
129129
```
130-
the `f/rrule`s can be tested by
130+
131+
after defining the constructor and functor `f/rule`s,
132+
133+
```julia
134+
function ChainRulesCore.rrule(::Type{Foo}, val) # constructor rrule
135+
y = Foo(val)
136+
Foo_pb(ΔFoo) = (NoTangent(), unthunk(ΔFoo).a)
137+
return y, Foo_pb
138+
end
139+
140+
function ChainRulesCore.rrule(foo::Foo, val) # functor rrule
141+
y = foo(val)
142+
function foo_pb(Δ)
143+
Δut = unthunk(Δ)
144+
return (Tangent{Foo}(;a=Δut), Δut)
145+
end
146+
return y, foo_pb
147+
end
148+
```
149+
150+
both `f/rrule`s can be tested by
131151
```julia
132152
test_rrule(Foo, rand()) # constructor
133153

src/testers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ function test_frule(
125125
end
126126

127127
res = call_on_copy(frule_f, config, tangents, primals...)
128-
res === nothing && throw(MethodError(frule_f, typeof(primals)))
128+
res === nothing && throw(MethodError(frule_f, Tuple{Core.Typeof.(primals)...}))
129129
@test_msg "The frule should return (y, ∂y), not $res." res isa Tuple{Any,Any}
130130
Ω_ad, dΩ_ad = res
131131
Ω = call_on_copy(primals...)
@@ -201,7 +201,7 @@ function test_rrule(
201201
_test_inferred(rrule_f, config, primals...; fkwargs...)
202202
end
203203
res = rrule_f(config, primals...; fkwargs...)
204-
res === nothing && throw(MethodError(rrule_f, typeof(primals)))
204+
res === nothing && throw(MethodError(rrule_f, Tuple{Core.Typeof.(primals)...}))
205205
y_ad, pullback = res
206206
y = call(primals...)
207207
test_approx(y_ad, y; isapprox_kwargs...) # make sure primal is correct

0 commit comments

Comments
 (0)