Skip to content

Commit e3362cb

Browse files
authored
Make @opt_out rrule(...) automatically qualify rrule namespace as ChainRulesCore.rrule (#546)
1 parent 3394de6 commit e3362cb

File tree

2 files changed

+54
-15
lines changed

2 files changed

+54
-15
lines changed

src/rule_definition_tools.jl

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -490,34 +490,45 @@ Similar applies for [`frule`](@ref) and [`ChainRulesCore.no_frule`](@ref)
490490
For more information see the [documentation on opting out of rules](@ref opt_out).
491491
"""
492492
macro opt_out(expr)
493-
no_rule_target = _no_rule_target_rewrite!(deepcopy(expr))
493+
no_rule_target = _target_rewrite!(deepcopy(expr), true)
494+
rule_target = _target_rewrite!(deepcopy(expr), false)
494495

495496
return @strip_linenos quote
496497
$(esc(no_rule_target)) = nothing
497-
$(esc(expr)) = nothing
498+
$(esc(rule_target)) = nothing
498499
end
499500
end
500501

501-
"Rewrite method sig Expr for `rrule` to be for `no_rrule`, and `frule` to be `no_frule`."
502-
function _no_rule_target_rewrite!(expr::Expr)
502+
"""
503+
_target_rewrite!(expr::Expr, no_rule)
504+
505+
Rewrite method sig `expr` for `rrule` to be for `no_rrule` or `ChainRulesCore.rrule`
506+
(with the CRC namespace qualification), depending on the `no_rule` argument.
507+
Does the equivalent for `frule`.
508+
"""
509+
function _target_rewrite!(expr::Expr, no_rule)
503510
length(expr.args) === 0 && error("Malformed method expression. $expr")
504511
if expr.head === :call || expr.head === :where
505-
expr.args[1] = _no_rule_target_rewrite!(expr.args[1])
512+
expr.args[1] = _target_rewrite!(expr.args[1], no_rule)
506513
elseif expr.head == :(.) && expr.args[1] == :ChainRulesCore
507-
expr = _no_rule_target_rewrite!(expr.args[end])
514+
expr = _target_rewrite!(expr.args[end], no_rule)
508515
else
509516
error("Malformed method expression. $(expr)")
510517
end
511518
return expr
512519
end
513-
_no_rule_target_rewrite!(qt::QuoteNode) = _no_rule_target_rewrite!(qt.value)
514-
function _no_rule_target_rewrite!(call_target::Symbol)
515-
return if call_target == :rrule
516-
:(ChainRulesCore.no_rrule)
517-
elseif call_target == :frule
518-
:(ChainRulesCore.no_frule)
520+
_target_rewrite!(qt::QuoteNode, no_rule) = _target_rewrite!(qt.value, no_rule)
521+
function _target_rewrite!(call_target::Symbol, no_rule)
522+
return if call_target == :rrule && no_rule
523+
:($ChainRulesCore.no_rrule)
524+
elseif call_target == :rrule && !no_rule
525+
:($ChainRulesCore.rrule)
526+
elseif call_target == :frule && no_rule
527+
:($ChainRulesCore.no_frule)
528+
elseif call_target == :frule && !no_rule
529+
:($ChainRulesCore.frule)
519530
else
520-
error("Unexpected opt-out target. Exprected frule or rrule, got: $call_target")
531+
error("Unexpected opt-out target. Expected frule or rrule, got: $call_target")
521532
end
522533
end
523534

test/rule_definition_tools.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ end
285285
# workaround for https://github.com/domluna/JuliaFormatter.jl/issues/484
286286
module IsolatedModuleForTestingScoping
287287
# check that rules can be defined by macros without any additional imports
288-
using ChainRulesCore: @scalar_rule, @non_differentiable
288+
using ChainRulesCore: @scalar_rule, @non_differentiable, @opt_out
289289

290290
# ensure that functions, types etc. in module `ChainRulesCore` can't be resolved
291291
const ChainRulesCore = nothing
@@ -303,11 +303,20 @@ module IsolatedModuleForTestingScoping
303303
my_id(x) = x
304304
@scalar_rule(my_id(x), 1.0)
305305

306+
# @opt_out
307+
first_oa(x, y) = x
308+
@scalar_rule(first_oa(x, y), (1, 0))
309+
# Declared without using the ChainRulesCore namespace qualification
310+
# see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/545
311+
@opt_out rrule(::typeof(first_oa), x::T, y::T) where {T<:Float16}
312+
@opt_out frule(::Any, ::typeof(first_oa), x::T, y::T) where {T<:Float16}
313+
306314
module IsolatedSubmodule
307315
# check that rules defined in isolated module without imports can be called
308316
# without errors
309317
using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output
310-
using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id
318+
using ChainRulesCore: no_rrule, no_frule
319+
using ..IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id, first_oa
311320
using Test
312321

313322
@testset "@non_differentiable" begin
@@ -339,6 +348,25 @@ module IsolatedModuleForTestingScoping
339348

340349
@test derivatives_given_output(y, my_id, x) == ((1.0,),)
341350
end
351+
352+
@testset "@optout" begin
353+
# rrule
354+
@test rrule(first_oa, Float16(3.0), Float16(4.0)) === nothing
355+
@test !isempty(
356+
Iterators.filter(methods(no_rrule)) do m
357+
m.sig <: Tuple{Any,typeof(first_oa),T,T} where {T<:Float16}
358+
end,
359+
)
360+
361+
# frule
362+
@test frule((NoTangent(), 1, 0), first_oa, Float16(3.0), Float16(4.0)) ===
363+
nothing
364+
@test !isempty(
365+
Iterators.filter(methods(no_frule)) do m
366+
m.sig <: Tuple{Any,Any,typeof(first_oa),T,T} where {T<:Float16}
367+
end,
368+
)
369+
end
342370
end
343371
end
344372
#! format: on

0 commit comments

Comments
 (0)