|
10 | 10 | using Test
|
11 | 11 | using ChainRulesCore
|
12 | 12 | using Diffractor
|
13 |
| -using Distributed: pmap |
| 13 | +using Distributed: CachingPool, pmap, workers |
14 | 14 | using FiniteDifferences
|
15 | 15 | using LinearAlgebra
|
16 | 16 |
|
@@ -43,6 +43,8 @@ jacobicheck(f, dims...) = jacobicheck(f, randn.(Float64, dims)...)
|
43 | 43 | @test jacobicheck(identity, (4,5)) # one random matrix
|
44 | 44 | @test jacobicheck(+, 3, 3) # two random vectors
|
45 | 45 |
|
| 46 | +isZero(x) = x isa AbstractZero |
| 47 | + |
46 | 48 | # Zygote's misnamed hobbit function:
|
47 | 49 | function pullback(f, x...)
|
48 | 50 | y, b = Diffractor.∂⃖{1}()(f, x...)
|
|
161 | 163 | x = rand(3)
|
162 | 164 | z = [1, 2, 3, 3]
|
163 | 165 | y139(x, z) = dot(ones(4), x[z])
|
164 |
| - # Evaluated: ([1.0 0.0 0.0; 1.0 0.0 0.0; 2.0 0.0 0.0], NoTangent()) == ([1, 1, 2], NoTangent()) |
| 166 | + # wrong gradient: Evaluated: ([1.0 0.0 0.0; 1.0 0.0 0.0; 2.0 0.0 0.0], NoTangent()) == ([1, 1, 2], NoTangent()) |
165 | 167 | @test_broken gradient(y139, x, z) == ([1, 1, 2], NoTangent())
|
166 | 168 |
|
167 | 169 | # https://github.com/FluxML/Zygote.jl/issues/376
|
@@ -348,8 +350,301 @@ end
|
348 | 350 | @test_broken out, pb = pullback(map, build_foo(5.0), randn(5))
|
349 | 351 | @test_skip pb(Δ)[2] isa Vector{ZeroTangent}
|
350 | 352 | end
|
| 353 | + |
| 354 | + # Check that map infers correctly. pmap still doesn't infer. |
| 355 | + @testset "map inference" begin |
| 356 | + @testset "$name" for (name, f, ȳ, xs) in [ |
| 357 | + ("unary empty vector", sin, Float64[], (Float64[], )), |
| 358 | + ("unary vector", sin, randn(3), (randn(3), )), |
| 359 | + ("unary empty tuple", sin, (), ((), )), |
| 360 | + ("unary tuple", sin, (randn(), randn()), ((randn(), randn()), )), |
| 361 | + ("binary empty vector", +, Float64[], (Float64[], Float64[])), |
| 362 | + ("binary vector", +, randn(2), (randn(2), randn(2))), |
| 363 | + ] |
| 364 | + @inferred pullback(map, f, xs...) |
| 365 | + y, pb = pullback(map, f, xs...) |
| 366 | + @inferred pb(ȳ) |
| 367 | + end |
| 368 | + |
| 369 | + # these are broken |
| 370 | + @test_skip @testset "$name" for (name, f, ȳ, xs) in [ |
| 371 | + # MethodError: reducing over an empty collection is not allowed; consider supplying `init` to the reducer |
| 372 | + ("binary empty tuple", +, (), ((), ())), |
| 373 | + # return type Tuple{NoTangent, Tangent{...}...} does not match inferred |
| 374 | + # return type Tuple{NoTangent, {Union{NoTangent, Tangent{...}}}} |
| 375 | + ("binary tuple", +, (randn(), randn()), ((randn(), randn()), (randn(), randn()))), |
| 376 | + ] |
| 377 | + @inferred pullback(map, f, xs...) |
| 378 | + y, pb = pullback(map, f, xs...) |
| 379 | + @inferred pb(ȳ) |
| 380 | + end |
| 381 | + end |
| 382 | + |
| 383 | + @testset "map and tuples" begin |
| 384 | + # arrays of tuples, ChainRules's Tangent should not escape |
| 385 | + # MethodError: no method matching one(::Tuple{Int64, Int64}) |
| 386 | + @test_broken gradient(x -> sum(map(first, x)), [(1,2), (3,4)]) == ([(1.0, nothing), (1.0, nothing)],) |
| 387 | + T = Tangent{Tuple{Int64, Int64}} |
| 388 | + @test gradient(x -> sum(first, x), [(1,2), (3,4)]) == (T[T(1.0, ZeroTangent()), T(1.0, ZeroTangent())],) |
| 389 | + |
| 390 | + @test gradient(x -> map(+, x, (1,2,3))[1], (4,5,6)) == (Tangent{Tuple{Int,Int,Int}}(1.0, ZeroTangent(), ZeroTangent()),) |
| 391 | + # MethodError: no method matching copy(::Nothing) |
| 392 | + @test_broken gradient(x -> map(+, x, [1,2,3])[1], (4,5,6)) == ((1.0, 0.0, 0.0),) |
| 393 | + @test_broken gradient(x -> map(+, x, (1,2,3))[1], [4,5,6]) == ([1,0,0],) |
| 394 | + |
| 395 | + # mismatched lengths, should zip |
| 396 | + # MethodError: no method matching copy(::Nothing) |
| 397 | + @test_broken gradient(x -> map(+, x, [1,2,3,99])[1], (4,5,6)) == ((1.0, 0.0, 0.0),) |
| 398 | + @test_broken gradient(x -> map(+, x, [1,2,3])[1], (4,5,6,99)) == ((1.0, 0.0, 0.0, nothing),) |
| 399 | + end |
| 400 | + |
| 401 | + @testset "Alternative Pmap Dispatch" begin |
| 402 | + cache_and_map(f,xs...) = pmap(f, CachingPool(workers()), xs...; batch_size = 1) |
| 403 | + # BoundsError: attempt to access 0-element Core.Compiler.UnitRange{Int64} at index [0] |
| 404 | + @test_broken jacobicheck(xs -> cache_and_map(x -> x^2, xs), rand(2,3)) |
| 405 | + @test_broken jacobicheck((xss...) -> cache_and_map((xs...) -> sqrt(sum(xs.^2)), xss...), [rand(5) for _ in 1:6]...) |
| 406 | + @test_broken jacobicheck(y -> cache_and_map(x->x*y, 1:5), 3) |
| 407 | + end |
| 408 | + |
| 409 | + @testset "Stateful Map" begin |
| 410 | + s = 0 |
| 411 | + f(x) = (s += x) |
| 412 | + # Tuple field type cannot be Union{} |
| 413 | + @test_broken gradient(x -> sum(f.(x)), 1:10) == (10:-1:1,) |
| 414 | + s = 0 |
| 415 | + # MethodError: no method matching copy(::Nothing) |
| 416 | + @test_broken gradient(x -> sum(map(f, x)), 1:10) == (10:-1:1,) |
| 417 | + end |
| 418 | + |
| 419 | + @testset "vararg map" begin |
| 420 | + # early stop |
| 421 | + # MethodError: no method matching length(::InplaceableThunk{...}) |
| 422 | + if VERSION >= v"1.5" |
| 423 | + # In Julia 1.4 and earlier, map(*,rand(5),[1,2,3]) is a DimensionMismatch |
| 424 | + @test_broken gradient(x -> sum(map(*,x,[1,2,3])), rand(5)) == ([1,2,3,0,0],) |
| 425 | + end |
| 426 | + @test_broken gradient(x -> sum(map(*,x,(1,2,3))), rand(5)) == ([1,2,3,0,0],) |
| 427 | + @test_broken gradient(x -> sum(map(*,x,[1,2,3])), Tuple(rand(5))) == ((1.0, 2.0, 3.0, nothing, nothing),) |
| 428 | + |
| 429 | + # mixed shapes |
| 430 | + # MethodError: no method matching length(::InplaceableThunk{...}) |
| 431 | + @test_broken gradient((x,y) -> sum(map(*,x,y)), [1,2,3,4], [1 2; 3 4]) == ([1,3,2,4], [1 3; 2 4]) |
| 432 | + @test_broken gradient((x,y) -> sum(map(*,x,y)), [1,2,3], [1 2; 3 4]) == ([1,3,2], [1 3; 2 0]) |
| 433 | + @test_broken gradient((x,y) -> sum(map(*,x,y)), (1,2,3), [1 2; 3 4]) == ((1,3,2), [1 3; 2 0]) |
| 434 | + @test_broken gradient((x,y) -> sum(map(*,x,y)), [1,2,3,4,5], [1 2; 3 4]) == ([1,3,2,4,0], [1 3; 2 4]) |
| 435 | + @test_broken gradient((x,y) -> sum(map(*,x,y)), (1,2,3,4,5), [1 2; 3 4]) == ((1,3,2,4,nothing), [1 3; 2 4]) |
| 436 | + end |
| 437 | + |
| 438 | + @testset "map: issue 1374" begin |
| 439 | + # https://github.com/FluxML/Zygote.jl/issues/1374 |
| 440 | + struct Affine1374 |
| 441 | + W |
| 442 | + b |
| 443 | + end |
| 444 | + (m::Affine1374)(x) = [sum(x.*r) for r in eachrow(m.W)] + m.b |
| 445 | + m = Affine1374(zeros(3,3), zeros(3,1)) |
| 446 | + x = [ 1.0, 2.0, 3.0] |
| 447 | + y = [-1.0, -2.0, -3.0] |
| 448 | + l1374(y,ŷ) = sum(abs2.(y - ŷ))/2 |
| 449 | + @test_broken gradient(m -> l1374(y,m(x)), m)[1].W ≈ [1 2 3; 2 4 6; 3 6 9] |
| 450 | + end |
351 | 451 | end
|
352 | 452 |
|
| 453 | +@testset "sort" begin |
| 454 | + @test jacobicheck(sort, 5) |
| 455 | + correct = [ |
| 456 | + [2,3,1], |
| 457 | + [1, 2, 3], |
| 458 | + [1,2,3], |
| 459 | + [2,1,3], |
| 460 | + [1,3,2], |
| 461 | + [3,2,1] |
| 462 | + ] |
| 463 | + for i = 1:3 |
| 464 | + @test gradient(v->sort(v)[i], [3.,1,2])[1][correct[1][i]] == 1 |
| 465 | + @test gradient(v->sort(v)[i], [1.,2,3])[1][correct[2][i]] == 1 |
| 466 | + end |
| 467 | + for i = 1:3 |
| 468 | + # Rewrite reached intrinsic function bitcast. Missing rule? |
| 469 | + @test_broken gradient(v->sort(v,by=x->x%10)[i], [11,2,99])[1][correct[3][i]] == 1 |
| 470 | + @test_broken gradient(v->sort(v,by=x->x%10)[i], [2,11,99])[1][correct[4][i]] == 1 |
| 471 | + @test_broken gradient(v->sort(v,rev=true)[i], [3.,1,2])[1][correct[5][i]] == 1 |
| 472 | + @test_broken gradient(v->sort(v,rev=true)[i], [1.,2,3])[1][correct[6][i]] == 1 |
| 473 | + end |
| 474 | +end |
| 475 | + |
| 476 | +@testset "filter" begin |
| 477 | + @test jacobicheck(xs -> filter(x -> x > 0.5, xs), rand(20)) |
| 478 | + |
| 479 | + @test gradient(x -> sum(log, filter(iseven, x)), 1:10) == |
| 480 | + (map(x -> iseven(x) ? 1/x : 0, 1:10),) |
| 481 | + @test gradient(x -> sum(abs2, im .+ filter(iseven, x)), 1:10) == |
| 482 | + (map(x -> iseven(x) ? 2x : 0, 1:10),) |
| 483 | + # (map(x -> iseven(x) ? 2x+2im : 0, 1:10),) |
| 484 | +end |
| 485 | + |
| 486 | +@testset "maximum" begin |
| 487 | + @test jacobicheck(maximum, rand(2, 3)) |
| 488 | + |
| 489 | + # MethodError: no method matching copy(::Nothing) |
| 490 | + @test_broken jacobicheck(x -> maximum(x, dims=1), rand(2, 3)) |
| 491 | + @test_broken jacobicheck(x -> maximum(x, dims=3), rand(2, 3, 4)) |
| 492 | + @test_broken jacobicheck(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4)) |
| 493 | + |
| 494 | + @test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9] |
| 495 | +end |
| 496 | + |
| 497 | +@testset "minimum" begin |
| 498 | + @test jacobicheck(minimum, rand(2, 3)) |
| 499 | + |
| 500 | + # MethodError: no method matching copy(::Nothing) |
| 501 | + @test_broken jacobicheck(x -> minimum(x, dims=1), rand(2, 3)) |
| 502 | + @test_broken jacobicheck(x -> minimum(x, dims=2), rand(2, 3)) |
| 503 | +end |
| 504 | + |
| 505 | +@testset "dropdims" begin # https://github.com/JuliaDiff/Diffractor.jl/issues/72 |
| 506 | + # TypeError: in typeassert, expected Int64, got a value of type Nothing |
| 507 | + @test_broken jacobicheck(x -> dropdims(x, dims = 3), rand(2, 2, 1, 2)) |
| 508 | + @test_broken jacobicheck(x -> dropdims(x, dims = (2, 3)), rand(2, 1, 1, 3)) |
| 509 | +end |
| 510 | + |
| 511 | +@testset "vcat" begin |
| 512 | + # Scalar |
| 513 | + @test gradient((x,y) -> sum(vcat(x,y)), 1,2) == (1,1) |
| 514 | + @test gradient((x,y) -> sum([x; y; x]), 1,2) == (2,1) |
| 515 | + |
| 516 | + # Scalar + Vector |
| 517 | + @test gradient(x -> sum(vcat(x, 1, x)), rand(3)) == ([2,2,2],) |
| 518 | + @test gradient((x,y) -> sum(vcat(x, y, y)), rand(3), 4) == ([1,1,1], 2) |
| 519 | + |
| 520 | + # Vector-only. |
| 521 | + @test jacobicheck(vcat, randn(10)) |
| 522 | + @test jacobicheck(x -> vcat(x, [1,2,3], x), randn(3)) |
| 523 | + |
| 524 | + # Matrix-Vector |
| 525 | + @test jacobicheck(x-> vcat(x, [1,2,3]), rand(2,1)) |
| 526 | + @test jacobicheck(x-> vcat(x, ones(3,1)), rand(2)) |
| 527 | +end |
| 528 | + |
| 529 | +@testset "hcat" begin |
| 530 | + # Scalar |
| 531 | + @test gradient((x,y) -> sum(hcat(x,y)), 1,2) == (1,1) |
| 532 | + @test gradient((x,y) -> sum([x y]), 1,2) == (1,1) |
| 533 | + @test gradient((a,b,c,d) -> sum(sqrt, [a b;c d]), 1,1,1,4) == (0.5, 0.5, 0.5, 0.25) |
| 534 | + |
| 535 | + # Vector-only |
| 536 | + @test jacobicheck(hcat, rand(3)) |
| 537 | + @test jacobicheck(x -> hcat(x, [1,2,3]), rand(3)) |
| 538 | + |
| 539 | + # Matrix-only |
| 540 | + @test jacobicheck(hcat, rand(3,4)) |
| 541 | + @test jacobicheck(x -> hcat(x, [1 2; 3 4], x), rand(2,2)) |
| 542 | + |
| 543 | + # Matrix-Scalar |
| 544 | + @test gradient((x,y) -> sum(hcat(x, y)), 1, [2 3 4]) == (1, [1 1 1]) |
| 545 | + @test gradient(x -> sum(hcat(1, x, 2)), transpose([3,4,5]))[1] isa Transpose |
| 546 | + @test gradient(x -> sum(hcat(1, x, 2)), [3,4,5]')[1] isa Adjoint |
| 547 | +end |
| 548 | + |
| 549 | +@testset "hvcat" begin |
| 550 | + @test gradient(xs -> hvcat((2,2),xs...)[1,1], [1,2,3,4])[1] == [1,0,0,0] |
| 551 | + @test gradient(xs -> hvcat((2,2),xs...)[2,1], [1,2,3,4])[1] == [0,0,1,0] |
| 552 | + @test gradient(xs -> hvcat((2,2),xs...)[1,2], [1,2,3,4])[1] == [0,1,0,0] |
| 553 | + @test gradient(xs -> hvcat((2,2),xs...)[2,2], [1,2,3,4])[1] == [0,0,0,1] |
| 554 | + # https://github.com/FluxML/Zygote.jl/issues/513 |
| 555 | + @test gradient(x -> hvcat((2,2),1,2,3,x)[4], 4.0) == (1.0,) |
| 556 | +end |
| 557 | + |
| 558 | +@testset "cat(...; dims = $dim)" for dim in 1:3 |
| 559 | + # Rewrite reached intrinsic function bitcast. Missing rule? |
| 560 | + |
| 561 | + catdim = (x...) -> cat(x..., dims = dim) |
| 562 | + @test_broken jacobicheck(catdim, rand(4,1)) |
| 563 | + @test_broken jacobicheck(catdim, rand(5), rand(5,1)) |
| 564 | + @test_broken jacobicheck(catdim, rand(2,5), rand(2,5), rand(2,5)) |
| 565 | + |
| 566 | + catdimval = (x...) -> cat(x...; dims = Val(dim)) |
| 567 | + @test_broken jacobicheck(catdimval, rand(5), rand(5)) |
| 568 | + @test_broken jacobicheck(catdimval, rand(2,5), rand(2,5,1)) |
| 569 | + |
| 570 | + # one empty |
| 571 | + dim == 1 || continue |
| 572 | + @test_broken jacobicheck(catdim, rand(0,5,3), rand(2,5,3)) |
| 573 | +end |
| 574 | + |
| 575 | +@testset "one(s) and zero(s)" begin |
| 576 | + # should these be ZeroTangent or NoTangent? |
| 577 | + @test gradient(x->sum(ones(size(x))), randn(5))[1] === NoTangent() |
| 578 | + @test_broken gradient(x->sum(one(x)), randn(3, 3))[1] === NoTangent() |
| 579 | + @test gradient(x->sum(zeros(size(x))), randn(7))[1] === NoTangent() |
| 580 | + @test_broken gradient(x->sum(zero(x)), randn(3))[1] === NoTangent() |
| 581 | +end |
| 582 | + |
| 583 | +@testset "fma and muladd" begin |
| 584 | + @test gradcheck(x -> fma(x[1], x[2], x[3]), [2.0, 3.0, 5.0]) |
| 585 | + @test gradcheck(x -> muladd(x[1], x[2], x[3]), [2.0, 3.0, 5.0]) |
| 586 | +end |
| 587 | + |
| 588 | +@testset "broadcast" begin |
| 589 | + @test gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1] |
| 590 | + |
| 591 | + # mixing arrays & Ref(array) |
| 592 | + a = rand(3) |
| 593 | + b = rand(2,2) |
| 594 | + @test jacobicheck(x -> sum(diag.((x,) .* a)), b) |
| 595 | + @test jacobicheck(x -> sum(diag.(Ref(x) .* a)), b) |
| 596 | + @test jacobicheck(x -> sum(diag.([x] .* a)), b) |
| 597 | + |
| 598 | + # tests for https://github.com/FluxML/Zygote.jl/issues/724 |
| 599 | + x1 = rand(3, 3) |
| 600 | + @test gradient(x -> sum(x .== 0.5), x1) |> only |> isZero |
| 601 | + # MethodError: no method matching copy(::Nothing) |
| 602 | + @test_broken gradient(x -> sum(x .* (x .== maximum(x, dims=1))), x1)[1] == (x1 .== maximum(x1, dims=1)) |
| 603 | + |
| 604 | + # tests for un-broadcasting *, / via scalar rules |
| 605 | + @test all(gradient((x,y) -> sum(x .* y), [1,2], 5) .≈ ([5, 5], 3)) |
| 606 | + @test all(gradient((x,y) -> sum(x .* y), 5, [1,2]) .≈ (3, [5, 5])) |
| 607 | + @test all(gradient((x,y) -> sum(x .* y), [1,2], [3 4 5]) .≈ ([12, 12], [3 3 3])) |
| 608 | + @test all(gradient((x,y) -> sum(x ./ y), [1,2], 5) .≈ ([0.2, 0.2], -0.12)) |
| 609 | + |
| 610 | + @test_skip begin |
| 611 | + using SparseArrays # not loaded at present |
| 612 | + # https://github.com/FluxML/Zygote.jl/pull/1171 |
| 613 | + sm = sprand(5, 5, 0.5) |
| 614 | + @test gradient(x -> sum(abs2, Float32.(x)), sm)[1] ≈ gradient(x -> sum(abs2, x), Matrix{Float32}(sm))[1] |
| 615 | + @test_broken gradient(x -> real(sum(ComplexF32.(x) .+ 1 .+ im)), sm)[1] isa SparseMatrixCSC{Float64} # MethodError: no method matching zero(::Type{Any}), in ProjectTo(xs::SparseMatrixCSC{Any, Int64}) |
| 616 | + end |
| 617 | + |
| 618 | + # https://github.com/FluxML/Zygote.jl/issues/1178 |
| 619 | + function f1179(x) |
| 620 | + fs = Ref.(x) |
| 621 | + getindex.(fs) |
| 622 | + end |
| 623 | + # wrong gradient: Evaluated: ([1.0, 1.0],) == ([2.0, 2.0],) |
| 624 | + @test_broken gradient(sum∘f1179, ones(2)) == ([2.0, 2.0],) # MethodError: no method matching one(::Base.RefValue{Float64}) |
| 625 | +end |
| 626 | + |
| 627 | +@testset "array +,-" begin |
| 628 | + A, B = randn(3, 4, 5), randn(3, 4, 5) |
| 629 | + @test jacobicheck(+, B) |
| 630 | + @test jacobicheck(+, A, B) |
| 631 | + # wrong gradient |
| 632 | + @test_broken jacobicheck(+, A, B, A) |
| 633 | + @test jacobicheck(-, A) |
| 634 | + # in typeassert, expected Int64, got a value of type Nothing |
| 635 | + @test_broken jacobicheck(-, A, B) |
| 636 | +end |
| 637 | + |
| 638 | + |
| 639 | + |
| 640 | + |
| 641 | + |
| 642 | + |
| 643 | + |
| 644 | + |
| 645 | + |
| 646 | + |
| 647 | + |
353 | 648 |
|
354 | 649 |
|
355 | 650 |
|
|
375 | 670 |
|
376 | 671 | # FIXME: complex numbers; put somewhere
|
377 | 672 | @test gradcheck((a,b)->sum(reim(acosh(complex(a[1], b[1])))), [-2.0], [1.0])
|
| 673 | +#@testset "$f(::AbstractArray)" for f in (real, conj, imag) |
| 674 | +# rng, N = MersenneTwister(123456), 3 |
| 675 | +# Ts = (Float64, ComplexF64) |
| 676 | +# @testset "$f(::Array{$IT})" for IT in Ts |
| 677 | +# A = randn(IT, N, N) |
| 678 | +# y, back = Zygote.pullback(f, A) |
| 679 | +# y2, back2 = Zygote.pullback(x->f.(x), A) |
| 680 | +# @test y == y2 |
| 681 | +# @testset "back(::Array{$BT})" for BT in Ts |
| 682 | +# ȳ = randn(BT, N, N) |
| 683 | +# @test back(ȳ)[1] == back2(ȳ)[1] |
| 684 | +# end |
| 685 | +# end |
| 686 | +#end |
| 687 | + |
378 | 688 |
|
379 | 689 | # FIXME: misc tests
|
380 | 690 | @test jacobicheck(x -> x', rand(5))
|
|
0 commit comments