@@ -47,7 +47,7 @@ jacobicheck(f, dims...) = jacobicheck(f, randn.(Float64, dims)...)
47
47
isZero (x) = x isa AbstractZero
48
48
49
49
# Zygote's misnamed hobbit function:
50
- function pullback (f, x... )
50
+ function value_and_pullback (f, x... )
51
51
y, b = Diffractor.∂⃖ {1} ()(f, x... )
52
52
back (dy) = map (unthunk, Base. tail (b (dy)))
53
53
y, back
171
171
172
172
# https://github.com/FluxML/Zygote.jl/issues/376
173
173
174
- _, back = pullback (x-> x[1 ]* im, randn (2 ))
174
+ _, back = value_and_pullback (x-> x[1 ]* im, randn (2 ))
175
175
@test back (1.0 )[1 ] == real ([- im, 0 ]) == [0 , 0 ]
176
176
177
177
# _droplike
@@ -187,10 +187,10 @@ end
187
187
@test_broken gradient (x -> sum (Float32[1 , x] .+ x), 4 ) == (3.0f0 ,)
188
188
189
189
# Ensure that nothings work with numeric types.
190
- _, back = pullback (getindex, randn (4 ), [1 ])
190
+ _, back = value_and_pullback (getindex, randn (4 ), [1 ])
191
191
@test back ([ZeroTangent ()]) == (zeros (4 ), NoTangent ())
192
192
# Ensure that nothings work with non-numeric types.
193
- _, back = pullback (getindex, [randn (2 ) for _ in 1 : 3 ], [1 ])
193
+ _, back = value_and_pullback (getindex, [randn (2 ) for _ in 1 : 3 ], [1 ])
194
194
@test back ([ZeroTangent ()]) == (NoTangent (), NoTangent ())
195
195
end
196
196
246
246
@test jacobicheck (x -> permutedims (x, [3 ,1 ,2 ]), rand (4 ,5 ,6 ))
247
247
@test jacobicheck (x -> PermutedDimsArray (x, (3 ,1 ,2 )), rand (4 ,5 ,6 ))
248
248
let
249
- y, back = pullback (permutedims, randn (3 ))
249
+ y, back = value_and_pullback (permutedims, randn (3 ))
250
250
@test first (back (randn (1 , 3 ))) isa Vector
251
251
end
252
252
end
@@ -311,36 +311,36 @@ end
311
311
312
312
@testset " Tuple adjoint" begin
313
313
x = randn (3 )
314
- _, pb = pullback (x -> map (abs2, x), x)
314
+ _, pb = value_and_pullback (x -> map (abs2, x), x)
315
315
Δy = randn (3 )
316
316
@test first (pb ((Δy... , ))) ≈ first (pb (Δy))
317
317
end
318
318
319
319
@testset " empty tuples" begin
320
- out, pb = pullback (map, - , ())
320
+ out, pb = value_and_pullback (map, - , ())
321
321
@test pb (out) === (NoTangent (), NoTangent ())
322
322
323
- out, pb = pullback (map, + , (), ())
324
- # MethodError: reducing over an empty collection is not allowed, ChainRules.var"#map_pullback #1234"{typeof(+), Tuple{Tuple{}, Tuple{}},
323
+ out, pb = value_and_pullback (map, + , (), ())
324
+ # MethodError: reducing over an empty collection is not allowed, ChainRules.var"#map_value_and_pullback #1234"{typeof(+), Tuple{Tuple{}, Tuple{}},
325
325
@test_broken pb (()) === (ZeroTangent (), ZeroTangent (), ZeroTangent ())
326
326
327
327
function build_foo (z)
328
328
foo (x) = x * z
329
329
return foo
330
330
end
331
- out, pb = pullback (map, build_foo (5.0 ), ())
331
+ out, pb = value_and_pullback (map, build_foo (5.0 ), ())
332
332
@test pb (()) === (NoTangent (), NoTangent ())
333
333
end
334
334
335
335
@testset " Vector{Nothing} cotangent" begin
336
336
Δ = fill (ZeroTangent (), 5 )
337
337
338
338
# Unary stateless
339
- out, pb = pullback (map, - , randn (5 ))
339
+ out, pb = value_and_pullback (map, - , randn (5 ))
340
340
@test pb (Δ)[2 ] isa Vector{ZeroTangent}
341
341
342
342
# Binary stateless
343
- out, pb = pullback (map, + , randn (5 ), randn (5 ))
343
+ out, pb = value_and_pullback (map, + , randn (5 ), randn (5 ))
344
344
@test pb (Δ)[2 ] isa Vector{ZeroTangent}
345
345
@test pb (Δ)[3 ] isa Vector{ZeroTangent}
346
346
350
350
return foo
351
351
end
352
352
# AssertionError: Base.issingletontype(typeof(f))
353
- @test_broken out, pb = pullback (map, build_foo (5.0 ), randn (5 ))
353
+ @test_broken out, pb = value_and_pullback (map, build_foo (5.0 ), randn (5 ))
354
354
@test_skip pb (Δ)[2 ] isa Vector{ZeroTangent}
355
355
end
356
356
364
364
(" binary empty vector" , + , Float64[], (Float64[], Float64[])),
365
365
(" binary vector" , + , randn (2 ), (randn (2 ), randn (2 ))),
366
366
]
367
- @inferred pullback (map, f, xs... )
368
- y, pb = pullback (map, f, xs... )
367
+ @inferred value_and_pullback (map, f, xs... )
368
+ y, pb = value_and_pullback (map, f, xs... )
369
369
@inferred pb (ȳ)
370
370
end
371
371
377
377
# return type Tuple{NoTangent, {Union{NoTangent, Tangent{...}}}}
378
378
(" binary tuple" , + , (randn (), randn ()), ((randn (), randn ()), (randn (), randn ()))),
379
379
]
380
- @inferred pullback (map, f, xs... )
381
- y, pb = pullback (map, f, xs... )
380
+ @inferred value_and_pullback (map, f, xs... )
381
+ y, pb = value_and_pullback (map, f, xs... )
382
382
@inferred pb (ȳ)
383
383
end
384
384
end
0 commit comments