Skip to content

Commit a514f7d

Browse files
committed
add some more testsets
1 parent 6b4579b commit a514f7d

File tree

1 file changed

+312
-2
lines changed

1 file changed

+312
-2
lines changed

test/gradcheck.jl

Lines changed: 312 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
using Test
1111
using ChainRulesCore
1212
using Diffractor
13-
using Distributed: pmap
13+
using Distributed: CachingPool, pmap, workers
1414
using FiniteDifferences
1515
using LinearAlgebra
1616

@@ -43,6 +43,8 @@ jacobicheck(f, dims...) = jacobicheck(f, randn.(Float64, dims)...)
4343
@test jacobicheck(identity, (4,5)) # one random matrix
4444
@test jacobicheck(+, 3, 3) # two random vectors
4545

46+
isZero(x) = x isa AbstractZero
47+
4648
# Zygote's misnamed hobbit function:
4749
function pullback(f, x...)
4850
y, b = Diffractor.∂⃖{1}()(f, x...)
@@ -161,7 +163,7 @@ end
161163
x = rand(3)
162164
z = [1, 2, 3, 3]
163165
y139(x, z) = dot(ones(4), x[z])
164-
# Evaluated: ([1.0 0.0 0.0; 1.0 0.0 0.0; 2.0 0.0 0.0], NoTangent()) == ([1, 1, 2], NoTangent())
166+
# wrong gradient: Evaluated: ([1.0 0.0 0.0; 1.0 0.0 0.0; 2.0 0.0 0.0], NoTangent()) == ([1, 1, 2], NoTangent())
165167
@test_broken gradient(y139, x, z) == ([1, 1, 2], NoTangent())
166168

167169
# https://github.com/FluxML/Zygote.jl/issues/376
@@ -348,8 +350,301 @@ end
348350
@test_broken out, pb = pullback(map, build_foo(5.0), randn(5))
349351
@test_skip pb(Δ)[2] isa Vector{ZeroTangent}
350352
end
353+
354+
# Check that map infers correctly. pmap still doesn't infer.
355+
@testset "map inference" begin
356+
@testset "$name" for (name, f, ȳ, xs) in [
357+
("unary empty vector", sin, Float64[], (Float64[], )),
358+
("unary vector", sin, randn(3), (randn(3), )),
359+
("unary empty tuple", sin, (), ((), )),
360+
("unary tuple", sin, (randn(), randn()), ((randn(), randn()), )),
361+
("binary empty vector", +, Float64[], (Float64[], Float64[])),
362+
("binary vector", +, randn(2), (randn(2), randn(2))),
363+
]
364+
@inferred pullback(map, f, xs...)
365+
y, pb = pullback(map, f, xs...)
366+
@inferred pb(ȳ)
367+
end
368+
369+
# these are broken
370+
@test_skip @testset "$name" for (name, f, ȳ, xs) in [
371+
# MethodError: reducing over an empty collection is not allowed; consider supplying `init` to the reducer
372+
("binary empty tuple", +, (), ((), ())),
373+
# return type Tuple{NoTangent, Tangent{...}...} does not match inferred
374+
# return type Tuple{NoTangent, {Union{NoTangent, Tangent{...}}}}
375+
("binary tuple", +, (randn(), randn()), ((randn(), randn()), (randn(), randn()))),
376+
]
377+
@inferred pullback(map, f, xs...)
378+
y, pb = pullback(map, f, xs...)
379+
@inferred pb(ȳ)
380+
end
381+
end
382+
383+
@testset "map and tuples" begin
384+
# arrays of tuples, ChainRules's Tangent should not escape
385+
# MethodError: no method matching one(::Tuple{Int64, Int64})
386+
@test_broken gradient(x -> sum(map(first, x)), [(1,2), (3,4)]) == ([(1.0, nothing), (1.0, nothing)],)
387+
T = Tangent{Tuple{Int64, Int64}}
388+
@test gradient(x -> sum(first, x), [(1,2), (3,4)]) == (T[T(1.0, ZeroTangent()), T(1.0, ZeroTangent())],)
389+
390+
@test gradient(x -> map(+, x, (1,2,3))[1], (4,5,6)) == (Tangent{Tuple{Int,Int,Int}}(1.0, ZeroTangent(), ZeroTangent()),)
391+
# MethodError: no method matching copy(::Nothing)
392+
@test_broken gradient(x -> map(+, x, [1,2,3])[1], (4,5,6)) == ((1.0, 0.0, 0.0),)
393+
@test_broken gradient(x -> map(+, x, (1,2,3))[1], [4,5,6]) == ([1,0,0],)
394+
395+
# mismatched lengths, should zip
396+
# MethodError: no method matching copy(::Nothing)
397+
@test_broken gradient(x -> map(+, x, [1,2,3,99])[1], (4,5,6)) == ((1.0, 0.0, 0.0),)
398+
@test_broken gradient(x -> map(+, x, [1,2,3])[1], (4,5,6,99)) == ((1.0, 0.0, 0.0, nothing),)
399+
end
400+
401+
@testset "Alternative Pmap Dispatch" begin
402+
cache_and_map(f,xs...) = pmap(f, CachingPool(workers()), xs...; batch_size = 1)
403+
# BoundsError: attempt to access 0-element Core.Compiler.UnitRange{Int64} at index [0]
404+
@test_broken jacobicheck(xs -> cache_and_map(x -> x^2, xs), rand(2,3))
405+
@test_broken jacobicheck((xss...) -> cache_and_map((xs...) -> sqrt(sum(xs.^2)), xss...), [rand(5) for _ in 1:6]...)
406+
@test_broken jacobicheck(y -> cache_and_map(x->x*y, 1:5), 3)
407+
end
408+
409+
@testset "Stateful Map" begin
410+
s = 0
411+
f(x) = (s += x)
412+
# Tuple field type cannot be Union{}
413+
@test_broken gradient(x -> sum(f.(x)), 1:10) == (10:-1:1,)
414+
s = 0
415+
# MethodError: no method matching copy(::Nothing)
416+
@test_broken gradient(x -> sum(map(f, x)), 1:10) == (10:-1:1,)
417+
end
418+
419+
@testset "vararg map" begin
420+
# early stop
421+
# MethodError: no method matching length(::InplaceableThunk{...})
422+
if VERSION >= v"1.5"
423+
# In Julia 1.4 and earlier, map(*,rand(5),[1,2,3]) is a DimensionMismatch
424+
@test_broken gradient(x -> sum(map(*,x,[1,2,3])), rand(5)) == ([1,2,3,0,0],)
425+
end
426+
@test_broken gradient(x -> sum(map(*,x,(1,2,3))), rand(5)) == ([1,2,3,0,0],)
427+
@test_broken gradient(x -> sum(map(*,x,[1,2,3])), Tuple(rand(5))) == ((1.0, 2.0, 3.0, nothing, nothing),)
428+
429+
# mixed shapes
430+
# MethodError: no method matching length(::InplaceableThunk{...})
431+
@test_broken gradient((x,y) -> sum(map(*,x,y)), [1,2,3,4], [1 2; 3 4]) == ([1,3,2,4], [1 3; 2 4])
432+
@test_broken gradient((x,y) -> sum(map(*,x,y)), [1,2,3], [1 2; 3 4]) == ([1,3,2], [1 3; 2 0])
433+
@test_broken gradient((x,y) -> sum(map(*,x,y)), (1,2,3), [1 2; 3 4]) == ((1,3,2), [1 3; 2 0])
434+
@test_broken gradient((x,y) -> sum(map(*,x,y)), [1,2,3,4,5], [1 2; 3 4]) == ([1,3,2,4,0], [1 3; 2 4])
435+
@test_broken gradient((x,y) -> sum(map(*,x,y)), (1,2,3,4,5), [1 2; 3 4]) == ((1,3,2,4,nothing), [1 3; 2 4])
436+
end
437+
438+
@testset "map: issue 1374" begin
439+
# https://github.com/FluxML/Zygote.jl/issues/1374
440+
struct Affine1374
441+
W
442+
b
443+
end
444+
(m::Affine1374)(x) = [sum(x.*r) for r in eachrow(m.W)] + m.b
445+
m = Affine1374(zeros(3,3), zeros(3,1))
446+
x = [ 1.0, 2.0, 3.0]
447+
y = [-1.0, -2.0, -3.0]
448+
l1374(y,ŷ) = sum(abs2.(y - ŷ))/2
449+
@test_broken gradient(m -> l1374(y,m(x)), m)[1].W ≈ [1 2 3; 2 4 6; 3 6 9]
450+
end
351451
end
352452

453+
@testset "sort" begin
454+
@test jacobicheck(sort, 5)
455+
correct = [
456+
[2,3,1],
457+
[1, 2, 3],
458+
[1,2,3],
459+
[2,1,3],
460+
[1,3,2],
461+
[3,2,1]
462+
]
463+
for i = 1:3
464+
@test gradient(v->sort(v)[i], [3.,1,2])[1][correct[1][i]] == 1
465+
@test gradient(v->sort(v)[i], [1.,2,3])[1][correct[2][i]] == 1
466+
end
467+
for i = 1:3
468+
# Rewrite reached intrinsic function bitcast. Missing rule?
469+
@test_broken gradient(v->sort(v,by=x->x%10)[i], [11,2,99])[1][correct[3][i]] == 1
470+
@test_broken gradient(v->sort(v,by=x->x%10)[i], [2,11,99])[1][correct[4][i]] == 1
471+
@test_broken gradient(v->sort(v,rev=true)[i], [3.,1,2])[1][correct[5][i]] == 1
472+
@test_broken gradient(v->sort(v,rev=true)[i], [1.,2,3])[1][correct[6][i]] == 1
473+
end
474+
end
475+
476+
@testset "filter" begin
477+
@test jacobicheck(xs -> filter(x -> x > 0.5, xs), rand(20))
478+
479+
@test gradient(x -> sum(log, filter(iseven, x)), 1:10) ==
480+
(map(x -> iseven(x) ? 1/x : 0, 1:10),)
481+
@test gradient(x -> sum(abs2, im .+ filter(iseven, x)), 1:10) ==
482+
(map(x -> iseven(x) ? 2x : 0, 1:10),)
483+
# (map(x -> iseven(x) ? 2x+2im : 0, 1:10),)
484+
end
485+
486+
@testset "maximum" begin
487+
@test jacobicheck(maximum, rand(2, 3))
488+
489+
# MethodError: no method matching copy(::Nothing)
490+
@test_broken jacobicheck(x -> maximum(x, dims=1), rand(2, 3))
491+
@test_broken jacobicheck(x -> maximum(x, dims=3), rand(2, 3, 4))
492+
@test_broken jacobicheck(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
493+
494+
@test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9]
495+
end
496+
497+
@testset "minimum" begin
498+
@test jacobicheck(minimum, rand(2, 3))
499+
500+
# MethodError: no method matching copy(::Nothing)
501+
@test_broken jacobicheck(x -> minimum(x, dims=1), rand(2, 3))
502+
@test_broken jacobicheck(x -> minimum(x, dims=2), rand(2, 3))
503+
end
504+
505+
@testset "dropdims" begin # https://github.com/JuliaDiff/Diffractor.jl/issues/72
506+
# TypeError: in typeassert, expected Int64, got a value of type Nothing
507+
@test_broken jacobicheck(x -> dropdims(x, dims = 3), rand(2, 2, 1, 2))
508+
@test_broken jacobicheck(x -> dropdims(x, dims = (2, 3)), rand(2, 1, 1, 3))
509+
end
510+
511+
@testset "vcat" begin
512+
# Scalar
513+
@test gradient((x,y) -> sum(vcat(x,y)), 1,2) == (1,1)
514+
@test gradient((x,y) -> sum([x; y; x]), 1,2) == (2,1)
515+
516+
# Scalar + Vector
517+
@test gradient(x -> sum(vcat(x, 1, x)), rand(3)) == ([2,2,2],)
518+
@test gradient((x,y) -> sum(vcat(x, y, y)), rand(3), 4) == ([1,1,1], 2)
519+
520+
# Vector-only.
521+
@test jacobicheck(vcat, randn(10))
522+
@test jacobicheck(x -> vcat(x, [1,2,3], x), randn(3))
523+
524+
# Matrix-Vector
525+
@test jacobicheck(x-> vcat(x, [1,2,3]), rand(2,1))
526+
@test jacobicheck(x-> vcat(x, ones(3,1)), rand(2))
527+
end
528+
529+
@testset "hcat" begin
530+
# Scalar
531+
@test gradient((x,y) -> sum(hcat(x,y)), 1,2) == (1,1)
532+
@test gradient((x,y) -> sum([x y]), 1,2) == (1,1)
533+
@test gradient((a,b,c,d) -> sum(sqrt, [a b;c d]), 1,1,1,4) == (0.5, 0.5, 0.5, 0.25)
534+
535+
# Vector-only
536+
@test jacobicheck(hcat, rand(3))
537+
@test jacobicheck(x -> hcat(x, [1,2,3]), rand(3))
538+
539+
# Matrix-only
540+
@test jacobicheck(hcat, rand(3,4))
541+
@test jacobicheck(x -> hcat(x, [1 2; 3 4], x), rand(2,2))
542+
543+
# Matrix-Scalar
544+
@test gradient((x,y) -> sum(hcat(x, y)), 1, [2 3 4]) == (1, [1 1 1])
545+
@test gradient(x -> sum(hcat(1, x, 2)), transpose([3,4,5]))[1] isa Transpose
546+
@test gradient(x -> sum(hcat(1, x, 2)), [3,4,5]')[1] isa Adjoint
547+
end
548+
549+
@testset "hvcat" begin
550+
@test gradient(xs -> hvcat((2,2),xs...)[1,1], [1,2,3,4])[1] == [1,0,0,0]
551+
@test gradient(xs -> hvcat((2,2),xs...)[2,1], [1,2,3,4])[1] == [0,0,1,0]
552+
@test gradient(xs -> hvcat((2,2),xs...)[1,2], [1,2,3,4])[1] == [0,1,0,0]
553+
@test gradient(xs -> hvcat((2,2),xs...)[2,2], [1,2,3,4])[1] == [0,0,0,1]
554+
# https://github.com/FluxML/Zygote.jl/issues/513
555+
@test gradient(x -> hvcat((2,2),1,2,3,x)[4], 4.0) == (1.0,)
556+
end
557+
558+
@testset "cat(...; dims = $dim)" for dim in 1:3
559+
# Rewrite reached intrinsic function bitcast. Missing rule?
560+
561+
catdim = (x...) -> cat(x..., dims = dim)
562+
@test_broken jacobicheck(catdim, rand(4,1))
563+
@test_broken jacobicheck(catdim, rand(5), rand(5,1))
564+
@test_broken jacobicheck(catdim, rand(2,5), rand(2,5), rand(2,5))
565+
566+
catdimval = (x...) -> cat(x...; dims = Val(dim))
567+
@test_broken jacobicheck(catdimval, rand(5), rand(5))
568+
@test_broken jacobicheck(catdimval, rand(2,5), rand(2,5,1))
569+
570+
# one empty
571+
dim == 1 || continue
572+
@test_broken jacobicheck(catdim, rand(0,5,3), rand(2,5,3))
573+
end
574+
575+
@testset "one(s) and zero(s)" begin
576+
# should these be ZeroTangent or NoTangent?
577+
@test gradient(x->sum(ones(size(x))), randn(5))[1] === NoTangent()
578+
@test_broken gradient(x->sum(one(x)), randn(3, 3))[1] === NoTangent()
579+
@test gradient(x->sum(zeros(size(x))), randn(7))[1] === NoTangent()
580+
@test_broken gradient(x->sum(zero(x)), randn(3))[1] === NoTangent()
581+
end
582+
583+
@testset "fma and muladd" begin
584+
@test gradcheck(x -> fma(x[1], x[2], x[3]), [2.0, 3.0, 5.0])
585+
@test gradcheck(x -> muladd(x[1], x[2], x[3]), [2.0, 3.0, 5.0])
586+
end
587+
588+
@testset "broadcast" begin
589+
@test gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1]
590+
591+
# mixing arrays & Ref(array)
592+
a = rand(3)
593+
b = rand(2,2)
594+
@test jacobicheck(x -> sum(diag.((x,) .* a)), b)
595+
@test jacobicheck(x -> sum(diag.(Ref(x) .* a)), b)
596+
@test jacobicheck(x -> sum(diag.([x] .* a)), b)
597+
598+
# tests for https://github.com/FluxML/Zygote.jl/issues/724
599+
x1 = rand(3, 3)
600+
@test gradient(x -> sum(x .== 0.5), x1) |> only |> isZero
601+
# MethodError: no method matching copy(::Nothing)
602+
@test_broken gradient(x -> sum(x .* (x .== maximum(x, dims=1))), x1)[1] == (x1 .== maximum(x1, dims=1))
603+
604+
# tests for un-broadcasting *, / via scalar rules
605+
@test all(gradient((x,y) -> sum(x .* y), [1,2], 5) .≈ ([5, 5], 3))
606+
@test all(gradient((x,y) -> sum(x .* y), 5, [1,2]) .≈ (3, [5, 5]))
607+
@test all(gradient((x,y) -> sum(x .* y), [1,2], [3 4 5]) .≈ ([12, 12], [3 3 3]))
608+
@test all(gradient((x,y) -> sum(x ./ y), [1,2], 5) .≈ ([0.2, 0.2], -0.12))
609+
610+
@test_skip begin
611+
using SparseArrays # not loaded at present
612+
# https://github.com/FluxML/Zygote.jl/pull/1171
613+
sm = sprand(5, 5, 0.5)
614+
@test gradient(x -> sum(abs2, Float32.(x)), sm)[1] ≈ gradient(x -> sum(abs2, x), Matrix{Float32}(sm))[1]
615+
@test_broken gradient(x -> real(sum(ComplexF32.(x) .+ 1 .+ im)), sm)[1] isa SparseMatrixCSC{Float64} # MethodError: no method matching zero(::Type{Any}), in ProjectTo(xs::SparseMatrixCSC{Any, Int64})
616+
end
617+
618+
# https://github.com/FluxML/Zygote.jl/issues/1178
619+
function f1179(x)
620+
fs = Ref.(x)
621+
getindex.(fs)
622+
end
623+
# wrong gradient: Evaluated: ([1.0, 1.0],) == ([2.0, 2.0],)
624+
@test_broken gradient(sum∘f1179, ones(2)) == ([2.0, 2.0],) # MethodError: no method matching one(::Base.RefValue{Float64})
625+
end
626+
627+
@testset "array +,-" begin
628+
A, B = randn(3, 4, 5), randn(3, 4, 5)
629+
@test jacobicheck(+, B)
630+
@test jacobicheck(+, A, B)
631+
# wrong gradient
632+
@test_broken jacobicheck(+, A, B, A)
633+
@test jacobicheck(-, A)
634+
# in typeassert, expected Int64, got a value of type Nothing
635+
@test_broken jacobicheck(-, A, B)
636+
end
637+
638+
639+
640+
641+
642+
643+
644+
645+
646+
647+
353648

354649

355650

@@ -375,6 +670,21 @@ end
375670

376671
# FIXME: complex numbers; put somewhere
377672
@test gradcheck((a,b)->sum(reim(acosh(complex(a[1], b[1])))), [-2.0], [1.0])
673+
#@testset "$f(::AbstractArray)" for f in (real, conj, imag)
674+
# rng, N = MersenneTwister(123456), 3
675+
# Ts = (Float64, ComplexF64)
676+
# @testset "$f(::Array{$IT})" for IT in Ts
677+
# A = randn(IT, N, N)
678+
# y, back = Zygote.pullback(f, A)
679+
# y2, back2 = Zygote.pullback(x->f.(x), A)
680+
# @test y == y2
681+
# @testset "back(::Array{$BT})" for BT in Ts
682+
# ȳ = randn(BT, N, N)
683+
# @test back(ȳ)[1] == back2(ȳ)[1]
684+
# end
685+
# end
686+
#end
687+
378688

379689
# FIXME: misc tests
380690
@test jacobicheck(x -> x', rand(5))

0 commit comments

Comments
 (0)