Skip to content

Commit 984dacc

Browse files
author
Miha Zgubic
committed
Revert "Test the number of outputs in frule and rrule are correct (#168)"
This reverts commit 2bf5e3b.
1 parent 2bf5e3b commit 984dacc

File tree

4 files changed

+16
-34
lines changed

4 files changed

+16
-34
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.7.2"
3+
version = "0.7.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ julia> using ChainRulesTestUtils;
6060
6161
julia> test_frule(two2three, 3.33, -7.77);
6262
Test Summary: | Pass Total
63-
test_frule: two2three on Float64,Float64 | 6 6
63+
test_frule: two2three on Float64,Float64 | 5 5
6464
```
6565

6666
### Testing the `rrule`
@@ -71,7 +71,7 @@ The call will test the `rrule` for function `f` at the point `x`, and similarly
7171
```jldoctest ex; output = false
7272
julia> test_rrule(two2three, 3.33, -7.77);
7373
Test Summary: | Pass Total
74-
test_rrule: two2three on Float64,Float64 | 7 7
74+
test_rrule: two2three on Float64,Float64 | 6 6
7575
```
7676

7777
## Scalar example
@@ -98,11 +98,11 @@ call.
9898
```jldoctest ex; output = false
9999
julia> test_scalar(relu, 0.5);
100100
Test Summary: | Pass Total
101-
test_scalar: relu at 0.5 | 9 9
101+
test_scalar: relu at 0.5 | 7 7
102102
103103
julia> test_scalar(relu, -0.5);
104104
Test Summary: | Pass Total
105-
test_scalar: relu at -0.5 | 9 9
105+
test_scalar: relu at -0.5 | 7 7
106106
```
107107

108108
## Specifying Tangents

src/testers.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
6767
end
6868

6969
"""
70-
test_frule(f, args..; kwargs...)
70+
test_frule(f, inputs...; kwargs...)
7171
7272
# Arguments
7373
- `f`: Function for which the `frule` should be tested.
74-
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
74+
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
7575
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
7676
- `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
7777
Non-differentiable arguments, such as indices, should have `ẋ` set as `NoTangent()`.
@@ -87,7 +87,7 @@ end
8787
"""
8888
function test_frule(
8989
f,
90-
args...;
90+
inputs...;
9191
output_tangent=Auto(),
9292
fdm=_fdm,
9393
check_inferred::Bool=true,
@@ -99,18 +99,18 @@ function test_frule(
9999
# To simplify some of the calls we make later lets group the kwargs for reuse
100100
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
101101

102-
@testset "test_frule: $f on $(_string_typeof(args))" begin
102+
@testset "test_frule: $f on $(_string_typeof(inputs))" begin
103103
_ensure_not_running_on_functor(f, "test_frule")
104104

105-
xẋs = auto_primal_and_tangent.(args)
105+
xẋs = auto_primal_and_tangent.(inputs)
106106
xs = primal.(xẋs)
107107
ẋs = tangent.(xẋs)
108108
if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...)
109109
_test_inferred(frule, (NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
110110
end
111111
res = frule((NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
112112
res === nothing && throw(MethodError(frule, typeof((f, xs...))))
113-
@test_msg "The frule should return (y, ∂y), not $res." res isa Tuple{Any,Any}
113+
res isa Tuple || error("The frule should return (y, ∂y), not $res.")
114114
Ω_ad, dΩ_ad = res
115115
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
116116
test_approx(Ω_ad, Ω; isapprox_kwargs...)
@@ -135,11 +135,11 @@ function test_frule(
135135
end
136136

137137
"""
138-
test_rrule(f, args...; kwargs...)
138+
test_rrule(f, inputs...; kwargs...)
139139
140140
# Arguments
141141
- `f`: Function to which rule should be applied.
142-
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
142+
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
143143
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
144144
- `x̄`: currently accumulated cotangent, will be generated automatically if not provided
145145
Non-differentiable arguments, such as indices, should have `x̄` set as `NoTangent()`.
@@ -155,7 +155,7 @@ end
155155
"""
156156
function test_rrule(
157157
f,
158-
args...;
158+
inputs...;
159159
output_tangent=Auto(),
160160
fdm=_fdm,
161161
check_inferred::Bool=true,
@@ -167,11 +167,11 @@ function test_rrule(
167167
# To simplify some of the calls we make later lets group the kwargs for reuse
168168
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
169169

170-
@testset "test_rrule: $f on $(_string_typeof(args))" begin
170+
@testset "test_rrule: $f on $(_string_typeof(inputs))" begin
171171
_ensure_not_running_on_functor(f, "test_rrule")
172172

173173
# Check correctness of evaluation.
174-
xx̄s = auto_primal_and_tangent.(args)
174+
xx̄s = auto_primal_and_tangent.(inputs)
175175
xs = primal.(xx̄s)
176176
accumulated_x̄ = tangent.(xx̄s)
177177
if check_inferred && _is_inferrable(f, xs...; fkwargs...)
@@ -191,8 +191,6 @@ function test_rrule(
191191
∂self = ∂s[1]
192192
x̄s_ad = ∂s[2:end]
193193
@test ∂self === NoTangent() # No internal fields
194-
msg = "The pullback should return 1 cotangent for each primal input."
195-
@test_msg msg length(x̄s_ad) == length(args)
196194

197195
# Correctness testing via finite differencing.
198196
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113

test/testers.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -495,22 +495,6 @@ end
495495
@test fails(() -> test_frule(my_identity2, 2.2))
496496
@test fails(() -> test_rrule(my_identity2, 2.2))
497497
end
498-
499-
@testset "wrong number of outputs #167" begin
500-
foo(x, y) = x + 2y
501-
502-
function ChainRulesCore.frule((_, ẋ, ẏ), ::typeof(foo), x, y)
503-
return foo(x, y), ẋ + 2ẏ, NoTangent() # extra derivative
504-
end
505-
506-
function ChainRulesCore.rrule(::typeof(foo), x, y)
507-
foo_pullback(dz) = NoTangent(), dz # missing derivative
508-
return foo(x,y), foo_pullback
509-
end
510-
511-
@test fails(() -> test_frule(foo, 2.1, 2.1))
512-
@test fails(() -> test_rrule(foo, 21.0, 32.0))
513-
end
514498
end
515499

516500
@testset "Tuple primal that is not equal to differential backing" begin

0 commit comments

Comments
 (0)