41
41
end
42
42
function ChainRulesCore. rrule (:: typeof (identity), x)
43
43
function identity_pullback (ȳ)
44
- return (NO_FIELDS , ȳ)
44
+ return (NoTangent () , ȳ)
45
45
end
46
46
return x, identity_pullback
47
47
end
67
67
x̄_ret = InplaceableThunk (
68
68
@thunk (ȳ), ā -> (inplace_used = true ; ā .+ = ȳ)
69
69
)
70
- return (NO_FIELDS , x̄_ret)
70
+ return (NoTangent () , x̄_ret)
71
71
end
72
72
return identity (x), identity_pullback
73
73
end
93
93
function my_identity_pullback (ȳ)
94
94
# only the in-place part is incorrect
95
95
x̄_ret = InplaceableThunk (@thunk (ȳ), ā -> ā .+ = 200 .* ȳ)
96
- return (NO_FIELDS , x̄_ret)
96
+ return (NoTangent () , x̄_ret)
97
97
end
98
98
return my_identity (x), my_identity_pullback
99
99
end
106
106
@testset " check inferred" begin
107
107
ChainRulesCore. frule ((_, Δx), :: typeof (f_inferrable), x) = (x, Δx)
108
108
function ChainRulesCore. rrule (:: typeof (f_inferrable), x)
109
- f_inferrable_pullback (Δy) = (NO_FIELDS , Δy)
109
+ f_inferrable_pullback (Δy) = (NoTangent () , Δy)
110
110
return x, f_inferrable_pullback
111
111
end
112
112
123
123
return (x, x > 0 ? Float64 (Δx) : Float32 (Δx))
124
124
end
125
125
function ChainRulesCore. rrule (:: typeof (f_noninferrable_frule), x)
126
- f_noninferrable_frule_pullback (Δy) = (NO_FIELDS , Δy)
126
+ f_noninferrable_frule_pullback (Δy) = (NoTangent () , Δy)
127
127
return x, f_noninferrable_frule_pullback
128
128
end
129
129
@@ -144,10 +144,10 @@ end
144
144
ChainRulesCore. frule ((_, Δx), :: typeof (f_noninferrable_rrule), x) = (x, Δx)
145
145
function ChainRulesCore. rrule (:: typeof (f_noninferrable_rrule), x)
146
146
if x > 0
147
- f_noninferrable_rrule_pullback (Δy) = (NO_FIELDS , Δy)
147
+ f_noninferrable_rrule_pullback (Δy) = (NoTangent () , Δy)
148
148
return x, f_noninferrable_rrule_pullback
149
149
else
150
- return x, _ -> (NO_FIELDS , Δy) # this is not hit by the used point
150
+ return x, _ -> (NoTangent () , Δy) # this is not hit by the used point
151
151
end
152
152
end
153
153
167
167
@testset " check not inferred in pullback" begin
168
168
function ChainRulesCore. rrule (:: typeof (f_noninferrable_pullback), x)
169
169
function f_noninferrable_pullback_pullback (Δy)
170
- return (NO_FIELDS , x > 0 ? Float64 (Δy) : Float32 (Δy))
170
+ return (NoTangent () , x > 0 ? Float64 (Δy) : Float32 (Δy))
171
171
end
172
172
return x, f_noninferrable_pullback_pullback
173
173
end
182
182
function ChainRulesCore. rrule (:: typeof (f_noninferrable_thunk), x, y)
183
183
function f_noninferrable_thunk_pullback (Δz)
184
184
∂x = @thunk (x > 0 ? Float64 (Δz) : Float32 (Δz))
185
- return (NO_FIELDS , ∂x, Δz)
185
+ return (NoTangent () , ∂x, Δz)
186
186
end
187
187
return x + y, f_noninferrable_thunk_pullback
188
188
end
198
198
return (x > 0 ? Float64 (x) : Float32 (x), x > 0 ? Float64 (Δx) : Float32 (Δx))
199
199
end
200
200
function ChainRulesCore. rrule (:: typeof (f_inferrable_pullback_only), x)
201
- f_inferrable_pullback_only_pullback (Δy) = (NO_FIELDS , oftype (x, Δy))
201
+ f_inferrable_pullback_only_pullback (Δy) = (NoTangent () , oftype (x, Δy))
202
202
return x > 0 ? Float64 (x) : Float32 (x), f_inferrable_pullback_only_pullback
203
203
end
204
204
test_frule (f_inferrable_pullback_only, 2.0 ; check_inferred= true )
212
212
# define rrule using ChainRulesCore's v0.9.0 convention, conjugating the derivative
213
213
# in the rrule
214
214
function ChainRulesCore. rrule (:: typeof (sinconj), x)
215
- sinconj_pullback (ΔΩ) = (NO_FIELDS , conj (cos (x)) * ΔΩ)
215
+ sinconj_pullback (ΔΩ) = (NoTangent () , conj (cos (x)) * ΔΩ)
216
216
return sin (x), sinconj_pullback
217
217
end
218
218
225
225
ChainRulesCore. frule ((_, dx, dy), :: typeof (fst), x, y) = (x, dx)
226
226
function ChainRulesCore. rrule (:: typeof (fst), x, y)
227
227
function fst_pullback (Δx)
228
- return (NO_FIELDS , Δx, ZeroTangent ())
228
+ return (NoTangent () , Δx, ZeroTangent ())
229
229
end
230
230
return x, fst_pullback
231
231
end
242
242
@testset " single input, multiple output" begin
243
243
simo (x) = (x, 2 x)
244
244
function ChainRulesCore. rrule (simo, x)
245
- simo_pullback ((a, b)) = (NO_FIELDS , a .+ 2 .* b)
245
+ simo_pullback ((a, b)) = (NoTangent () , a .+ 2 .* b)
246
246
return simo (x), simo_pullback
247
247
end
248
248
function ChainRulesCore. frule ((_, ẋ), simo, x)
264
264
ChainRulesCore. frule ((_, dx), :: typeof (first), xs:: Tuple ) = (first (xs), first (dx))
265
265
function ChainRulesCore. rrule (:: typeof (first), x:: Tuple )
266
266
function first_pullback (Δx)
267
- return (NO_FIELDS , Tangent {typeof(x)} (Δx, falses (length (x) - 1 )... ))
267
+ return (NoTangent () , Tangent {typeof(x)} (Δx, falses (length (x) - 1 )... ))
268
268
end
269
269
return first (x), first_pullback
270
270
end
294
294
ChainRulesCore. frule ((_, Δx, _), :: typeof (fsymtest), x, s) = (x, Δx)
295
295
function ChainRulesCore. rrule (:: typeof (fsymtest), x, s)
296
296
function fsymtest_pullback (Δx)
297
- return NO_FIELDS , Δx, NoTangent ()
297
+ return NoTangent () , Δx, NoTangent ()
298
298
end
299
299
return x, fsymtest_pullback
300
300
end
314
314
end
315
315
function ChainRulesCore. rrule (:: typeof (futestkws), x; err= true )
316
316
function futestkws_pullback (Δx)
317
- return (NO_FIELDS , Δx)
317
+ return (NoTangent () , Δx)
318
318
end
319
319
return futestkws (x; err= err), futestkws_pullback
320
320
end
348
348
end
349
349
function ChainRulesCore. rrule (:: typeof (fbtestkws), x, y; err= true )
350
350
function fbtestkws_pullback (Δx)
351
- return (NO_FIELDS , Δx, ZeroTangent ())
351
+ return (NoTangent () , Δx, ZeroTangent ())
352
352
end
353
353
return fbtestkws (x, y; err= err), fbtestkws_pullback
354
354
end
381
381
382
382
function ChainRulesCore. rrule (:: typeof (primalapprox), x)
383
383
function primalapprox_pullback (Δx)
384
- return (NO_FIELDS , Δx)
384
+ return (NoTangent () , Δx)
385
385
end
386
386
return x + sqrt (eps (x)), primalapprox_pullback
387
387
end
@@ -391,21 +391,21 @@ end
391
391
end
392
392
393
393
@testset " frule with mutation" begin
394
- function ChainRulesCore. frule ((_, ẋ ), :: typeof (finplace!), x; y= [1 ])
394
+ function ChainRulesCore. frule ((_, ẋ ), :: typeof (finplace!), x; y= [1 ])
395
395
y[1 ] *= 2
396
396
x .*= y[1 ]
397
- ẋ .*= 2 # hardcoded to match y defined below
398
- return x, ẋ
397
+ ẋ .*= 2 # hardcoded to match y defined below
398
+ return x, ẋ
399
399
end
400
400
401
401
# these pass in tangents explictly so that we can check them after
402
402
x = randn (3 )
403
- ẋ = [4.0 , 5.0 , 6.0 ]
404
- xcopy, ẋcopy = copy (x), copy (ẋ )
403
+ ẋ = [4.0 , 5.0 , 6.0 ]
404
+ xcopy, ẋcopy = copy (x), copy (ẋ )
405
405
y = [1 , 2 ]
406
- test_frule (finplace!, x ⊢ ẋ ; fkwargs= (y= y,))
406
+ test_frule (finplace!, x ⊢ ẋ ; fkwargs= (y= y,))
407
407
@test x == xcopy
408
- @test ẋ == ẋcopy
408
+ @test ẋ == ẋcopy
409
409
@test y == [1 , 2 ]
410
410
end
411
411
450
450
∂iter = TestIterator (
451
451
∂data, Base. IteratorSize (iter), Base. IteratorEltype (iter)
452
452
)
453
- return (NO_FIELDS , ∂iter)
453
+ return (NoTangent () , ∂iter)
454
454
end
455
455
return iterfun (iter), iterfun_pullback
456
456
end
471
471
end
472
472
function ChainRulesCore. rrule (:: typeof (my_identity1), x)
473
473
function identity_pullback (ȳ)
474
- return (NO_FIELDS , ȳ)
474
+ return (NoTangent () , ȳ)
475
475
end
476
476
return 2.5 * x, identity_pullback
477
477
end
487
487
end
488
488
function ChainRulesCore. rrule (:: typeof (my_identity2), x)
489
489
function identity_pullback (ȳ)
490
- return (NO_FIELDS , 31.8 * ȳ)
490
+ return (NoTangent () , 31.8 * ȳ)
491
491
end
492
492
return x, identity_pullback
493
493
end
505
505
506
506
rev_trouble ((x, y)) = y
507
507
function ChainRulesCore. rrule (:: typeof (rev_trouble), (x, y):: P ) where {P}
508
- rev_trouble_pullback (ȳ) = (NO_FIELDS , Tangent {P} (ZeroTangent (), ȳ))
508
+ rev_trouble_pullback (ȳ) = (NoTangent () , Tangent {P} (ZeroTangent (), ȳ))
509
509
return y, rev_trouble_pullback
510
510
end
511
511
test_rrule (rev_trouble, (3 , 3.0 ) ⊢ Tangent {Tuple{Int,Float64}} (ZeroTangent (), 1.0 ))
517
517
function foo_pullback (Δy)
518
518
da = zeros (size (a))
519
519
da[i] = Δy
520
- return NO_FIELDS , da, ZeroTangent ()
520
+ return NoTangent () , da, ZeroTangent ()
521
521
end
522
522
return foo (a, i), foo_pullback
523
523
end
0 commit comments