Skip to content

Commit c12f953

Browse files
authored
Merge pull request #58 from JuliaDiff/aa/checked
Add `_checked_rrule`, which errors if `rrule` returns `nothing`
2 parents 74ef7a4 + 3c534de commit c12f953

File tree

3 files changed

+29
-6
lines changed

3 files changed

+29
-6
lines changed

src/rules.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,24 @@ See also: [`frule`](@ref), [`AbstractRule`](@ref), [`@scalar_rule`](@ref)
377377
"""
378378
rrule(::Any, ::Vararg{Any}) = nothing
379379

380+
@noinline function _throw_checked_rrule_error(f, args...; kwargs...)
381+
io = IOBuffer()
382+
print(io, "can't differentiate `", f, '(')
383+
join(io, map(arg->string("::", typeof(arg)), args), ", ")
384+
if !isempty(kwargs)
385+
print(io, ";")
386+
join(io, map(((k, v),)->string(k, "=", v), kwargs), ", ")
387+
end
388+
print(io, ")`; no matching `rrule` is defined")
389+
throw(ArgumentError(String(take!(io))))
390+
end
391+
392+
function _checked_rrule(f, args...; kwargs...)
393+
r = rrule(f, args...; kwargs...)
394+
r isa Nothing && _throw_checked_rrule_error(f, args...; kwargs...)
395+
return r
396+
end
397+
380398
#####
381399
##### macros
382400
#####

src/rules/mapreduce.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,7 @@ function rrule(::typeof(map), f, xs...)
77
∂xs = ntuple(length(xs)) do i
88
Rule() do
99
map(ȳ, xs...) do ȳi, xis...
10-
r = rrule(f, xis...)
11-
if r === nothing
12-
throw(ArgumentError("can't differentiate `map` with `$f`; no `rrule` " *
13-
"is defined for `$f$xis`"))
14-
end
15-
_, ∂xis = r
10+
_, ∂xis = _checked_rrule(f, xis...)
1611
extern(∂xis[i](ȳi))
1712
end
1813
end

test/rules.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
cool(x) = x + 1
2+
cool(x, y) = x + y + 1
23

34
@testset "rules" begin
45
@testset "frule and rrule" begin
@@ -45,5 +46,14 @@ cool(x) = x + 1
4546
@test ChainRules._update!(X, Y) == (A=[2 0; 0 2], B=[4 4; 4 4])
4647
@test X.A != Y.A
4748
@test X.B != Y.B
49+
50+
try
51+
# We defined a 2-arg method for `cool` but no `rrule`
52+
ChainRules._checked_rrule(cool, 1.0, 2.0)
53+
catch e
54+
@test e isa ArgumentError
55+
@test e.msg == "can't differentiate `cool(::Float64, ::Float64)`; no " *
56+
"matching `rrule` is defined"
57+
end
4858
end
4959
end

0 commit comments

Comments
 (0)