285
285
# workaround for https://github.com/domluna/JuliaFormatter.jl/issues/484
286
286
module IsolatedModuleForTestingScoping
287
287
# check that rules can be defined by macros without any additional imports
288
- using ChainRulesCore: @scalar_rule , @non_differentiable
288
+ using ChainRulesCore: @scalar_rule , @non_differentiable , @opt_out
289
289
290
290
# ensure that functions, types etc. in module `ChainRulesCore` can't be resolved
291
291
const ChainRulesCore = nothing
@@ -303,11 +303,20 @@ module IsolatedModuleForTestingScoping
303
303
my_id (x) = x
304
304
@scalar_rule (my_id (x), 1.0 )
305
305
306
+ # @opt_out
307
+ first_oa (x, y) = x
308
+ @scalar_rule (first_oa (x, y), (1 , 0 ))
309
+ # Declared without using the ChainRulesCore namespace qualification
310
+ # see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/545
311
+ @opt_out rrule (:: typeof (first_oa), x:: T , y:: T ) where {T<: Float16 }
312
+ @opt_out frule (:: Any , :: typeof (first_oa), x:: T , y:: T ) where {T<: Float16 }
313
+
306
314
module IsolatedSubmodule
307
315
# check that rules defined in isolated module without imports can be called
308
316
# without errors
309
317
using ChainRulesCore: frule, rrule, ZeroTangent, NoTangent, derivatives_given_output
310
- using .. IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id
318
+ using ChainRulesCore: no_rrule, no_frule
319
+ using .. IsolatedModuleForTestingScoping: fixed, fixed_kwargs, my_id, first_oa
311
320
using Test
312
321
313
322
@testset " @non_differentiable" begin
@@ -339,6 +348,25 @@ module IsolatedModuleForTestingScoping
339
348
340
349
@test derivatives_given_output (y, my_id, x) == ((1.0 ,),)
341
350
end
351
+
352
+ @testset " @optout" begin
353
+ # rrule
354
+ @test rrule (first_oa, Float16 (3.0 ), Float16 (4.0 )) === nothing
355
+ @test ! isempty (
356
+ Iterators. filter (methods (no_rrule)) do m
357
+ m. sig <: Tuple{Any,typeof(first_oa),T,T} where {T<: Float16 }
358
+ end ,
359
+ )
360
+
361
+ # frule
362
+ @test frule ((NoTangent (), 1 , 0 ), first_oa, Float16 (3.0 ), Float16 (4.0 )) ===
363
+ nothing
364
+ @test ! isempty (
365
+ Iterators. filter (methods (no_frule)) do m
366
+ m. sig <: Tuple{Any,Any,typeof(first_oa),T,T} where {T<: Float16 }
367
+ end ,
368
+ )
369
+ end
342
370
end
343
371
end
344
372
# ! format: on
0 commit comments