Skip to content

Commit 203fce0

Browse files
authored
replace NO_FIELDS with NoTangent() (#155)
1 parent c405d75 commit 203fce0

File tree

6 files changed

+51
-51
lines changed

6 files changed

+51
-51
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.6.15"
3+
version = "0.7.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1212

1313
[compat]
14-
ChainRulesCore = "0.9.44"
14+
ChainRulesCore = "0.10"
1515
Compat = "3"
1616
FiniteDifferences = "0.12"
1717
julia = "1"

docs/Manifest.toml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1111

1212
[[ChainRulesCore]]
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
14-
git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb"
14+
git-tree-sha1 = "5d64be50ea9b43a89b476be773e125cef03c7cd5"
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.9.44"
16+
version = "0.10.1"
1717

1818
[[ChainRulesTestUtils]]
1919
deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"]
2020
path = ".."
2121
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
22-
version = "0.6.13"
22+
version = "0.7.0"
2323

2424
[[Compat]]
2525
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
@@ -57,9 +57,9 @@ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
5757

5858
[[FiniteDifferences]]
5959
deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "StaticArrays"]
60-
git-tree-sha1 = "8662836e29702fdfdb1b90cbe4162e31b94f1e51"
60+
git-tree-sha1 = "f8c8e287c1d68abc2719ad58fb39de9f6c0d71b1"
6161
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
62-
version = "0.12.7"
62+
version = "0.12.10"
6363

6464
[[IOCapture]]
6565
deps = ["Logging"]
@@ -167,9 +167,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
167167

168168
[[StaticArrays]]
169169
deps = ["LinearAlgebra", "Random", "Statistics"]
170-
git-tree-sha1 = "a1f226ebe197578c25fcf948bfff3d0d12f2ff20"
170+
git-tree-sha1 = "42378d3bab8b4f57aa1ca443821b752850592668"
171171
uuid = "90137ffa-7385-5640-81b9-e52037218182"
172-
version = "1.2.1"
172+
version = "1.2.2"
173173

174174
[[Statistics]]
175175
deps = ["LinearAlgebra", "SparseArrays"]

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ and `rrule`
3737
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
3838
y = two2three(x1, x2)
3939
function two2three_pullback(Ȳ)
40-
return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3])
40+
return (NoTangent(), 2.0*Ȳ[2], 3.0*Ȳ[3])
4141
end
4242
return y, two2three_pullback
4343
end

src/testers.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ function test_frule(
106106
xs = primal.(xẋs)
107107
ẋs = tangent.(xẋs)
108108
if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...)
109-
_test_inferred(frule, (NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
109+
_test_inferred(frule, (NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
110110
end
111-
res = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
111+
res = frule((NoTangent(), deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
112112
res === nothing && throw(MethodError(frule, typeof((f, xs...))))
113113
res isa Tuple || error("The frule should return (y, ∂y), not $res.")
114114
Ω_ad, dΩ_ad = res
@@ -190,7 +190,7 @@ function test_rrule(
190190
∂s isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.")
191191
∂self = ∂s[1]
192192
x̄s_ad = ∂s[2:end]
193-
@test ∂self === NO_FIELDS # No internal fields
193+
@test ∂self === NoTangent() # No internal fields
194194

195195
# Correctness testing via finite differencing.
196196
# TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113

test/deprecated.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333
end
3434
function ChainRulesCore.rrule(::typeof(identity), x)
3535
function identity_pullback(ȳ)
36-
return (NO_FIELDS, ȳ)
36+
return (NoTangent(), ȳ)
3737
end
3838
return x, identity_pullback
3939
end
@@ -53,7 +53,7 @@ end
5353
# define rrule using ChainRulesCore's v0.9.0 convention, conjugating the derivative
5454
# in the rrule
5555
function ChainRulesCore.rrule(::typeof(sinconj), x)
56-
sinconj_pullback(ΔΩ) = (NO_FIELDS, conj(cos(x)) * ΔΩ)
56+
sinconj_pullback(ΔΩ) = (NoTangent(), conj(cos(x)) * ΔΩ)
5757
return sin(x), sinconj_pullback
5858
end
5959

@@ -66,7 +66,7 @@ end
6666
ChainRulesCore.frule((_, dx, dy), ::typeof(fst), x, y) = (x, dx)
6767
function ChainRulesCore.rrule(::typeof(fst), x, y)
6868
function fst_pullback(Δx)
69-
return (NO_FIELDS, Δx, ZeroTangent())
69+
return (NoTangent(), Δx, ZeroTangent())
7070
end
7171
return x, fst_pullback
7272
end
@@ -83,7 +83,7 @@ end
8383
@testset "single input, multiple output" begin
8484
simo(x) = (x, 2x)
8585
function ChainRulesCore.rrule(simo, x)
86-
simo_pullback((a, b)) = (NO_FIELDS, a .+ 2 .* b)
86+
simo_pullback((a, b)) = (NoTangent(), a .+ 2 .* b)
8787
return simo(x), simo_pullback
8888
end
8989
function ChainRulesCore.frule((_, ẋ), simo, x)
@@ -106,7 +106,7 @@ end
106106
ChainRulesCore.frule((_, dx), ::typeof(first), xs::Tuple) = (first(xs), first(dx))
107107
function ChainRulesCore.rrule(::typeof(first), x::Tuple)
108108
function first_pullback(Δx)
109-
return (NO_FIELDS, Tangent{typeof(x)}(Δx, falses(length(x) - 1)...))
109+
return (NoTangent(), Tangent{typeof(x)}(Δx, falses(length(x) - 1)...))
110110
end
111111
return first(x), first_pullback
112112
end
@@ -142,7 +142,7 @@ end
142142
ChainRulesCore.frule((_, Δx, _), ::typeof(fsymtest), x, s) = (x, Δx)
143143
function ChainRulesCore.rrule(::typeof(fsymtest), x, s)
144144
function fsymtest_pullback(Δx)
145-
return NO_FIELDS, Δx, NoTangent()
145+
return NoTangent(), Δx, NoTangent()
146146
end
147147
return x, fsymtest_pullback
148148
end
@@ -164,7 +164,7 @@ end
164164
end
165165
function ChainRulesCore.rrule(::typeof(futestkws), x; err=true)
166166
function futestkws_pullback(Δx)
167-
return (NO_FIELDS, Δx)
167+
return (NoTangent(), Δx)
168168
end
169169
return futestkws(x; err=err), futestkws_pullback
170170
end
@@ -198,7 +198,7 @@ end
198198
end
199199
function ChainRulesCore.rrule(::typeof(fbtestkws), x, y; err=true)
200200
function fbtestkws_pullback(Δx)
201-
return (NO_FIELDS, Δx, ZeroTangent())
201+
return (NoTangent(), Δx, ZeroTangent())
202202
end
203203
return fbtestkws(x, y; err=err), fbtestkws_pullback
204204
end

test/testers.jl

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ end
4141
end
4242
function ChainRulesCore.rrule(::typeof(identity), x)
4343
function identity_pullback(ȳ)
44-
return (NO_FIELDS, ȳ)
44+
return (NoTangent(), ȳ)
4545
end
4646
return x, identity_pullback
4747
end
@@ -67,7 +67,7 @@ end
6767
x̄_ret = InplaceableThunk(
6868
@thunk(ȳ), ā -> (inplace_used = true; ā .+= ȳ)
6969
)
70-
return (NO_FIELDS, x̄_ret)
70+
return (NoTangent(), x̄_ret)
7171
end
7272
return identity(x), identity_pullback
7373
end
@@ -93,7 +93,7 @@ end
9393
function my_identity_pullback(ȳ)
9494
# only the in-place part is incorrect
9595
x̄_ret = InplaceableThunk(@thunk(ȳ), ā -> ā .+= 200 .* ȳ)
96-
return (NO_FIELDS, x̄_ret)
96+
return (NoTangent(), x̄_ret)
9797
end
9898
return my_identity(x), my_identity_pullback
9999
end
@@ -106,7 +106,7 @@ end
106106
@testset "check inferred" begin
107107
ChainRulesCore.frule((_, Δx), ::typeof(f_inferrable), x) = (x, Δx)
108108
function ChainRulesCore.rrule(::typeof(f_inferrable), x)
109-
f_inferrable_pullback(Δy) = (NO_FIELDS, Δy)
109+
f_inferrable_pullback(Δy) = (NoTangent(), Δy)
110110
return x, f_inferrable_pullback
111111
end
112112

@@ -123,7 +123,7 @@ end
123123
return (x, x > 0 ? Float64(Δx) : Float32(Δx))
124124
end
125125
function ChainRulesCore.rrule(::typeof(f_noninferrable_frule), x)
126-
f_noninferrable_frule_pullback(Δy) = (NO_FIELDS, Δy)
126+
f_noninferrable_frule_pullback(Δy) = (NoTangent(), Δy)
127127
return x, f_noninferrable_frule_pullback
128128
end
129129

@@ -144,10 +144,10 @@ end
144144
ChainRulesCore.frule((_, Δx), ::typeof(f_noninferrable_rrule), x) = (x, Δx)
145145
function ChainRulesCore.rrule(::typeof(f_noninferrable_rrule), x)
146146
if x > 0
147-
f_noninferrable_rrule_pullback(Δy) = (NO_FIELDS, Δy)
147+
f_noninferrable_rrule_pullback(Δy) = (NoTangent(), Δy)
148148
return x, f_noninferrable_rrule_pullback
149149
else
150-
return x, _ -> (NO_FIELDS, Δy) # this is not hit by the used point
150+
return x, _ -> (NoTangent(), Δy) # this is not hit by the used point
151151
end
152152
end
153153

@@ -167,7 +167,7 @@ end
167167
@testset "check not inferred in pullback" begin
168168
function ChainRulesCore.rrule(::typeof(f_noninferrable_pullback), x)
169169
function f_noninferrable_pullback_pullback(Δy)
170-
return (NO_FIELDS, x > 0 ? Float64(Δy) : Float32(Δy))
170+
return (NoTangent(), x > 0 ? Float64(Δy) : Float32(Δy))
171171
end
172172
return x, f_noninferrable_pullback_pullback
173173
end
@@ -182,7 +182,7 @@ end
182182
function ChainRulesCore.rrule(::typeof(f_noninferrable_thunk), x, y)
183183
function f_noninferrable_thunk_pullback(Δz)
184184
∂x = @thunk(x > 0 ? Float64(Δz) : Float32(Δz))
185-
return (NO_FIELDS, ∂x, Δz)
185+
return (NoTangent(), ∂x, Δz)
186186
end
187187
return x + y, f_noninferrable_thunk_pullback
188188
end
@@ -198,7 +198,7 @@ end
198198
return (x > 0 ? Float64(x) : Float32(x), x > 0 ? Float64(Δx) : Float32(Δx))
199199
end
200200
function ChainRulesCore.rrule(::typeof(f_inferrable_pullback_only), x)
201-
f_inferrable_pullback_only_pullback(Δy) = (NO_FIELDS, oftype(x, Δy))
201+
f_inferrable_pullback_only_pullback(Δy) = (NoTangent(), oftype(x, Δy))
202202
return x > 0 ? Float64(x) : Float32(x), f_inferrable_pullback_only_pullback
203203
end
204204
test_frule(f_inferrable_pullback_only, 2.0; check_inferred=true)
@@ -212,7 +212,7 @@ end
212212
# define rrule using ChainRulesCore's v0.9.0 convention, conjugating the derivative
213213
# in the rrule
214214
function ChainRulesCore.rrule(::typeof(sinconj), x)
215-
sinconj_pullback(ΔΩ) = (NO_FIELDS, conj(cos(x)) * ΔΩ)
215+
sinconj_pullback(ΔΩ) = (NoTangent(), conj(cos(x)) * ΔΩ)
216216
return sin(x), sinconj_pullback
217217
end
218218

@@ -225,7 +225,7 @@ end
225225
ChainRulesCore.frule((_, dx, dy), ::typeof(fst), x, y) = (x, dx)
226226
function ChainRulesCore.rrule(::typeof(fst), x, y)
227227
function fst_pullback(Δx)
228-
return (NO_FIELDS, Δx, ZeroTangent())
228+
return (NoTangent(), Δx, ZeroTangent())
229229
end
230230
return x, fst_pullback
231231
end
@@ -242,7 +242,7 @@ end
242242
@testset "single input, multiple output" begin
243243
simo(x) = (x, 2x)
244244
function ChainRulesCore.rrule(simo, x)
245-
simo_pullback((a, b)) = (NO_FIELDS, a .+ 2 .* b)
245+
simo_pullback((a, b)) = (NoTangent(), a .+ 2 .* b)
246246
return simo(x), simo_pullback
247247
end
248248
function ChainRulesCore.frule((_, ẋ), simo, x)
@@ -264,7 +264,7 @@ end
264264
ChainRulesCore.frule((_, dx), ::typeof(first), xs::Tuple) = (first(xs), first(dx))
265265
function ChainRulesCore.rrule(::typeof(first), x::Tuple)
266266
function first_pullback(Δx)
267-
return (NO_FIELDS, Tangent{typeof(x)}(Δx, falses(length(x) - 1)...))
267+
return (NoTangent(), Tangent{typeof(x)}(Δx, falses(length(x) - 1)...))
268268
end
269269
return first(x), first_pullback
270270
end
@@ -294,7 +294,7 @@ end
294294
ChainRulesCore.frule((_, Δx, _), ::typeof(fsymtest), x, s) = (x, Δx)
295295
function ChainRulesCore.rrule(::typeof(fsymtest), x, s)
296296
function fsymtest_pullback(Δx)
297-
return NO_FIELDS, Δx, NoTangent()
297+
return NoTangent(), Δx, NoTangent()
298298
end
299299
return x, fsymtest_pullback
300300
end
@@ -314,7 +314,7 @@ end
314314
end
315315
function ChainRulesCore.rrule(::typeof(futestkws), x; err=true)
316316
function futestkws_pullback(Δx)
317-
return (NO_FIELDS, Δx)
317+
return (NoTangent(), Δx)
318318
end
319319
return futestkws(x; err=err), futestkws_pullback
320320
end
@@ -348,7 +348,7 @@ end
348348
end
349349
function ChainRulesCore.rrule(::typeof(fbtestkws), x, y; err=true)
350350
function fbtestkws_pullback(Δx)
351-
return (NO_FIELDS, Δx, ZeroTangent())
351+
return (NoTangent(), Δx, ZeroTangent())
352352
end
353353
return fbtestkws(x, y; err=err), fbtestkws_pullback
354354
end
@@ -381,7 +381,7 @@ end
381381

382382
function ChainRulesCore.rrule(::typeof(primalapprox), x)
383383
function primalapprox_pullback(Δx)
384-
return (NO_FIELDS, Δx)
384+
return (NoTangent(), Δx)
385385
end
386386
return x + sqrt(eps(x)), primalapprox_pullback
387387
end
@@ -391,21 +391,21 @@ end
391391
end
392392

393393
@testset "frule with mutation" begin
394-
function ChainRulesCore.frule((_, ), ::typeof(finplace!), x; y=[1])
394+
function ChainRulesCore.frule((_, ), ::typeof(finplace!), x; y=[1])
395395
y[1] *= 2
396396
x .*= y[1]
397-
.*= 2 # hardcoded to match y defined below
398-
return x,
397+
.*= 2 # hardcoded to match y defined below
398+
return x,
399399
end
400400

401401
# these pass in tangents explictly so that we can check them after
402402
x = randn(3)
403-
= [4.0, 5.0, 6.0]
404-
xcopy, ẋcopy = copy(x), copy()
403+
= [4.0, 5.0, 6.0]
404+
xcopy, ẋcopy = copy(x), copy()
405405
y = [1, 2]
406-
test_frule(finplace!, x ; fkwargs=(y=y,))
406+
test_frule(finplace!, x ; fkwargs=(y=y,))
407407
@test x == xcopy
408-
@test == ẋcopy
408+
@test == ẋcopy
409409
@test y == [1, 2]
410410
end
411411

@@ -450,7 +450,7 @@ end
450450
∂iter = TestIterator(
451451
∂data, Base.IteratorSize(iter), Base.IteratorEltype(iter)
452452
)
453-
return (NO_FIELDS, ∂iter)
453+
return (NoTangent(), ∂iter)
454454
end
455455
return iterfun(iter), iterfun_pullback
456456
end
@@ -471,7 +471,7 @@ end
471471
end
472472
function ChainRulesCore.rrule(::typeof(my_identity1), x)
473473
function identity_pullback(ȳ)
474-
return (NO_FIELDS, ȳ)
474+
return (NoTangent(), ȳ)
475475
end
476476
return 2.5 * x, identity_pullback
477477
end
@@ -487,7 +487,7 @@ end
487487
end
488488
function ChainRulesCore.rrule(::typeof(my_identity2), x)
489489
function identity_pullback(ȳ)
490-
return (NO_FIELDS, 31.8 * ȳ)
490+
return (NoTangent(), 31.8 * ȳ)
491491
end
492492
return x, identity_pullback
493493
end
@@ -505,7 +505,7 @@ end
505505

506506
rev_trouble((x, y)) = y
507507
function ChainRulesCore.rrule(::typeof(rev_trouble), (x, y)::P) where {P}
508-
rev_trouble_pullback(ȳ) = (NO_FIELDS, Tangent{P}(ZeroTangent(), ȳ))
508+
rev_trouble_pullback(ȳ) = (NoTangent(), Tangent{P}(ZeroTangent(), ȳ))
509509
return y, rev_trouble_pullback
510510
end
511511
test_rrule(rev_trouble, (3, 3.0) Tangent{Tuple{Int,Float64}}(ZeroTangent(), 1.0))
@@ -517,7 +517,7 @@ end
517517
function foo_pullback(Δy)
518518
da = zeros(size(a))
519519
da[i] = Δy
520-
return NO_FIELDS, da, ZeroTangent()
520+
return NoTangent(), da, ZeroTangent()
521521
end
522522
return foo(a, i), foo_pullback
523523
end

0 commit comments

Comments
 (0)