Skip to content

Commit 77748b3

Browse files
authored
Merge pull request #651 from JuliaDiff/ox/non_diff_ambig
Fix ambig with configured frule created by non_differentiable rule
2 parents 17a83d4 + 161e5f1 commit 77748b3

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.19.0"
3+
version = "1.19.1"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/rule_definition_tools.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,13 +403,13 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke)
403403
function (::Core.kwftype(typeof(ChainRulesCore.frule)))(
404404
@nospecialize($kwargs::Any),
405405
frule::typeof(ChainRulesCore.frule),
406-
@nospecialize(::Any),
406+
@nospecialize(::Tuple),
407407
$(map(esc, primal_sig_parts)...),
408408
)
409409
return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent())
410410
end
411411
function ChainRulesCore.frule(
412-
@nospecialize(::Any), $(map(esc, primal_sig_parts)...)
412+
@nospecialize(::Tuple), $(map(esc, primal_sig_parts)...)
413413
)
414414
$(__source__)
415415
# Julia functions always only have 1 output, so return a single NoTangent()

test/rule_definition_tools.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,24 @@ end
215215
@test pullback(4.5) == (NoTangent(), NoTangent(), NoTangent())
216216
end
217217

218+
@testset "interactions with configs" begin
219+
struct AllConfig <: RuleConfig{Union{HasForwardsMode,NoReverseMode}} end
220+
221+
foo_ndc1(x) = string(x)
222+
@non_differentiable foo_ndc1(x)
223+
@test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc1, 2.0) == (string(2.0), NoTangent())
224+
r1, pb1 = rrule(AllConfig(), foo_ndc1, 2.0)
225+
@test r1 == string(2.0)
226+
@test pb1(NoTangent()) == (NoTangent(), NoTangent())
227+
228+
foo_ndc2(x; y=0) = string(x + y)
229+
@non_differentiable foo_ndc2(x)
230+
@test frule(AllConfig(), (NoTangent(), NoTangent()), foo_ndc2, 2.0; y=4.0) == (string(6.0), NoTangent())
231+
r2, pb2 = rrule(AllConfig(), foo_ndc2, 2.0; y=4.0)
232+
@test r2 == string(6.0)
233+
@test pb2(NoTangent()) == (NoTangent(), NoTangent())
234+
end
235+
218236
@testset "Not supported (Yet)" begin
219237
# Where clauses are not supported.
220238
@test_macro_throws(

0 commit comments

Comments
 (0)