Skip to content

Commit 67036a8

Browse files
committed
Constrain generated signature for nondiff frule to need tuple first arg so no ambig with ruleconfig first arg
1 parent 6befc08 commit 67036a8

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/rule_definition_tools.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,13 +403,13 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke)
403403
function (::Core.kwftype(typeof(ChainRulesCore.frule)))(
404404
@nospecialize($kwargs::Any),
405405
frule::typeof(ChainRulesCore.frule),
406-
::$RuleConfig,
406+
::Tuple,
407407
$(map(esc, primal_sig_parts)...),
408408
)
409409
return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent())
410410
end
411411
function ChainRulesCore.frule(
412-
::$RuleConfig, $(map(esc, primal_sig_parts)...)
412+
::Tuple, $(map(esc, primal_sig_parts)...)
413413
)
414414
$(__source__)
415415
# Julia functions always only have 1 output, so return a single NoTangent()

test/rule_definition_tools.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,14 +220,14 @@ end
220220

221221
foo_ndc1(x) = string(x)
222222
@non_differentiable foo_ndc1(x)
223-
@test frule(AllConfig(), foo_ndc1, 2.0) == (string(2.0), NoTangent())
223+
@test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc1, 2.0) == (string(2.0), NoTangent())
224224
r1, pb1 = rrule(AllConfig(), foo_ndc1, 2.0)
225225
@test r1 == string(2.0)
226226
@test pb1(NoTangent()) == (NoTangent(), NoTangent())
227227

228228
foo_ndc2(x; y=0) = string(x + y)
229229
@non_differentiable foo_ndc2(x)
230-
@test frule(AllConfig(), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent())
230+
@test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent())
231231
r2, pb2 = rrule(AllConfig(), foo_ndc2, 2.0; y=4.0)
232232
@test r2 == string(6.0)
233233
@test pb2(NoTangent()) == (NoTangent(), NoTangent())

0 commit comments

Comments
 (0)