@@ -67,11 +67,11 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
67
67
end
68
68
69
69
"""
70
- test_frule(f, inputs. ..; kwargs...)
70
+ test_frule(f, args ..; kwargs...)
71
71
72
72
# Arguments
73
73
- `f`: Function for which the `frule` should be tested.
74
- - `inputs ` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
74
+ - `args ` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
75
75
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
76
76
- `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
77
77
Non-differentiable arguments, such as indices, should have `ẋ` set as `NoTangent()`.
87
87
"""
88
88
function test_frule (
89
89
f,
90
- inputs ... ;
90
+ args ... ;
91
91
output_tangent= Auto (),
92
92
fdm= _fdm,
93
93
check_inferred:: Bool = true ,
@@ -99,18 +99,18 @@ function test_frule(
99
99
# To simplify some of the calls we make later lets group the kwargs for reuse
100
100
isapprox_kwargs = (; rtol= rtol, atol= atol, kwargs... )
101
101
102
- @testset " test_frule: $f on $(_string_typeof (inputs )) " begin
102
+ @testset " test_frule: $f on $(_string_typeof (args )) " begin
103
103
_ensure_not_running_on_functor (f, " test_frule" )
104
104
105
- xẋs = auto_primal_and_tangent .(inputs )
105
+ xẋs = auto_primal_and_tangent .(args )
106
106
xs = primal .(xẋs)
107
107
ẋs = tangent .(xẋs)
108
108
if check_inferred && _is_inferrable (f, deepcopy (xs)... ; deepcopy (fkwargs)... )
109
109
_test_inferred (frule, (NoTangent (), deepcopy (ẋs)... ), f, deepcopy (xs)... ; deepcopy (fkwargs)... )
110
110
end
111
111
res = frule ((NoTangent (), deepcopy (ẋs)... ), f, deepcopy (xs)... ; deepcopy (fkwargs)... )
112
112
res === nothing && throw (MethodError (frule, typeof ((f, xs... ))))
113
- res isa Tuple || error ( " The frule should return (y, ∂y), not $res ." )
113
+ @test_msg " The frule should return (y, ∂y), not $res ." res isa Tuple{Any,Any}
114
114
Ω_ad, dΩ_ad = res
115
115
Ω = f (deepcopy (xs)... ; deepcopy (fkwargs)... )
116
116
test_approx (Ω_ad, Ω; isapprox_kwargs... )
@@ -135,11 +135,11 @@ function test_frule(
135
135
end
136
136
137
137
"""
138
- test_rrule(f, inputs ...; kwargs...)
138
+ test_rrule(f, args ...; kwargs...)
139
139
140
140
# Arguments
141
141
- `f`: Function to which rule should be applied.
142
- - `inputs ` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
142
+ - `args ` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
143
143
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
144
144
- `x̄`: currently accumulated cotangent, will be generated automatically if not provided
145
145
Non-differentiable arguments, such as indices, should have `x̄` set as `NoTangent()`.
155
155
"""
156
156
function test_rrule (
157
157
f,
158
- inputs ... ;
158
+ args ... ;
159
159
output_tangent= Auto (),
160
160
fdm= _fdm,
161
161
check_inferred:: Bool = true ,
@@ -167,11 +167,11 @@ function test_rrule(
167
167
# To simplify some of the calls we make later lets group the kwargs for reuse
168
168
isapprox_kwargs = (; rtol= rtol, atol= atol, kwargs... )
169
169
170
- @testset " test_rrule: $f on $(_string_typeof (inputs )) " begin
170
+ @testset " test_rrule: $f on $(_string_typeof (args )) " begin
171
171
_ensure_not_running_on_functor (f, " test_rrule" )
172
172
173
173
# Check correctness of evaluation.
174
- xx̄s = auto_primal_and_tangent .(inputs )
174
+ xx̄s = auto_primal_and_tangent .(args )
175
175
xs = primal .(xx̄s)
176
176
accumulated_x̄ = tangent .(xx̄s)
177
177
if check_inferred && _is_inferrable (f, xs... ; fkwargs... )
@@ -191,6 +191,8 @@ function test_rrule(
191
191
∂self = ∂s[1 ]
192
192
x̄s_ad = ∂s[2 : end ]
193
193
@test ∂self === NoTangent () # No internal fields
194
+ msg = " The pullback should return 1 cotangent for each primal input."
195
+ @test_msg msg length (x̄s_ad) == length (args)
194
196
195
197
# Correctness testing via finite differencing.
196
198
# TODO : remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
0 commit comments