Skip to content

Commit 72ca897

Browse files
committed
address comment
1 parent 16ad573 commit 72ca897

File tree

2 files changed

+10
-16
lines changed

2 files changed

+10
-16
lines changed

src/rule_definition_tools.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -418,11 +418,10 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke)
418418
end
419419
end
420420

421-
function _make_pullback_for_non_differentiable(::Val{N}) where {N}
422-
Vararg{Any,N} # throw early for invalid `N`, must be nonnegative `Int`
423-
function pullback_for_non_differentiable(::Any)
424-
return ntuple(Returns(NoTangent()), Val(N))
425-
end
421+
struct NonDiffPullback{N} <: Function end
422+
423+
function (pb::NonDiffPullback{N})(::Any) where {N}
424+
return ntuple(Returns(NoTangent()), Val(N))
426425
end
427426

428427
function tuple_length_expression(primal_sig_parts)
@@ -443,7 +442,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
443442
tup_len_expr = tuple_length_expression(primal_sig_parts)
444443
primal_name = first(primal_invoke.args)
445444
pullback_expr = @strip_linenos quote
446-
_make_pullback_for_non_differentiable(Val{$(tup_len_expr)}())
445+
NonDiffPullback{$(tup_len_expr)}()
447446
end
448447

449448
@gensym kwargs

test/rule_definition_tools.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,27 +42,22 @@ end
4242

4343
@testset "rule_definition_tools.jl" begin
4444
@testset "@non_differentiable" begin
45-
@testset "`_make_pullback_for_non_differentiable`" begin
46-
f = ChainRulesCore._make_pullback_for_non_differentiable
47-
@testset "throws on invalid input" begin
48-
@test_throws Exception f(Val(0.0))
49-
@test_throws Exception f(Val(-1))
50-
end
45+
@testset "`NonDiffPullback`" begin
46+
NDP = ChainRulesCore.NonDiffPullback
5147
@testset "identical objects" begin
5248
for i in 0:5
53-
v = Val(i)
54-
@test f(v) === f(v)
49+
@test NDP{i}() === NDP{i}()
5550
end
5651
end
5752
@testset "correctness" begin
5853
for i in 0:5
5954
expected = ntuple((_ -> NoTangent()), i)
60-
@test f(Val(i))(:arbitrary) === expected
55+
@test NDP{i}()(:arbitrary) === expected
6156
end
6257
end
6358
@testset "dispatch" begin
6459
for i in 0:5
65-
pullback = f(Val(i))
60+
pullback = NDP{i}()
6661
@test_throws MethodError pullback()
6762
@test_throws MethodError pullback(1, 2)
6863
end

0 commit comments

Comments
 (0)