1
1
2
- using ChainRules: tuplecast, unzip # tuplemap,
2
+ using ChainRules: tuplecast, unzip, tuplemap
3
3
4
4
@testset " tuplecast.jl" begin
5
- @testset " basics: $(sprint (show, fun)) " for fun in [tuplecast, unzip ∘ broadcast] # [ tuplemap, tuplecast, unzip∘map, unzip∘broadcast]
5
+ @testset " basics: $(sprint (show, fun)) " for fun in [tuplemap, tuplecast, unzip∘ map, unzip∘ broadcast]
6
6
@test_throws Exception fun (sqrt, 1 : 3 )
7
7
8
8
@test fun (tuple, 1 : 3 , 4 : 6 ) == ([1 , 2 , 3 ], [4 , 5 , 6 ])
@@ -16,32 +16,69 @@ using ChainRules: tuplecast, unzip # tuplemap,
16
16
else
17
17
@test fun (tuple, [1 ,2 ,3 ], [4 5 ]) == ([1 1 ; 2 2 ; 3 3 ], [4 5 ; 4 5 ; 4 5 ])
18
18
end
19
+
20
+ if fun == tuplemap
21
+ @test_broken fun (tuple, (1 ,2 ,3 ), (4 ,5 ,6 )) == ((1 , 2 , 3 ), (4 , 5 , 6 ))
22
+ elseif fun == unzip∘ map
23
+ @test fun (tuple, (1 ,2 ,3 ), (4 ,5 ,6 )) == ((1 , 2 , 3 ), (4 , 5 , 6 ))
24
+ else
25
+ @test fun (tuple, (1 ,2 ,3 ), (4 ,5 ,6 )) == ((1 , 2 , 3 ), (4 , 5 , 6 ))
26
+ @test fun (tuple, (1 ,2 ,3 ), (7 ,)) == ((1 , 2 , 3 ), (7 , 7 , 7 ))
27
+ @test fun (tuple, (1 ,2 ,3 ), 8 ) == ((1 , 2 , 3 ), (8 , 8 , 8 ))
28
+ end
29
+ @test fun (tuple, (1 ,2 ,3 ), [4 ,5 ,6 ]) == ([1 , 2 , 3 ], [4 , 5 , 6 ]) # mix tuple & vector
19
30
end
31
+
32
+ @testset " rrules" begin
33
+ # These exist to allow for second derivatives
20
34
21
- # tuplemap(tuple, (1,2,3), (4,5,6)) == ([1, 2, 3], [4, 5, 6])
35
+ # test_rrule(collect∘tuplecast, tuple, [1,2,3.], [4,5,6.], collectheck_inferred=false) # return type Tuple{NoTangent, NoTangent, Vector{Float64}, Vector{Float64}} does not match inferred return type NTuple{4, Any}
36
+
37
+ y1, bk1 = rrule (CFG, tuplecast, tuple, [1 ,2 ,3.0 ], [4 ,5 ,6.0 ])
38
+ @test y1 == ([1 , 2 , 3 ], [4 , 5 , 6 ])
39
+ @test bk1 (([1 ,10 ,100.0 ], [7 ,8 ,9.0 ]))[3 ] ≈ [1 ,10 ,100 ]
40
+
41
+ # bk1(([1,10,100.0], NoTangent())) # DimensionMismatch in FiniteDifferences
42
+
43
+ y2, bk2 = rrule (CFG, tuplecast, tuple, [1 ,2 ,3.0 ], [4 5.0 ], 6.0 )
44
+ @test y2 == ([1 1 ; 2 2 ; 3 3 ], [4 5 ; 4 5 ; 4 5 ], [6 6 ; 6 6 ; 6 6 ])
45
+ @test bk2 (y2)[5 ] ≈ 36
22
46
47
+ y4, bk4 = rrule (CFG, tuplemap, tuple, [1 ,2 ,3.0 ], [4 ,5 ,6.0 ])
48
+ @test y4 == ([1 , 2 , 3 ], [4 , 5 , 6 ])
49
+ @test bk4 (([1 ,10 ,100.0 ], [7 ,8 ,9.0 ]))[3 ] ≈ [1 ,10 ,100 ]
50
+ end
51
+
23
52
@testset " unzip" begin
24
53
@test unzip ([(1 ,2 ), (3 ,4 ), (5 ,6 )]) == ([1 , 3 , 5 ], [2 , 4 , 6 ])
54
+ @test unzip (Any[(1 ,2 ), (3 ,4 ), (5 ,6 )]) == ([1 , 3 , 5 ], [2 , 4 , 6 ])
55
+
25
56
@test unzip ([(nothing ,2 ), (3 ,4 ), (5 ,6 )]) == ([nothing , 3 , 5 ], [2 , 4 , 6 ])
26
57
@test unzip ([(missing ,2 ), (missing ,4 ), (missing ,6 )])[2 ] isa Base. ReinterpretArray
27
58
59
+ @test unzip ([(1 ,), (3 ,), (5 ,)]) == ([1 , 3 , 5 ],)
60
+ @test unzip ([(1 ,), (3 ,), (5 ,)])[1 ] isa Base. ReinterpretArray
61
+
62
+ @test unzip (((1 ,2 ), (3 ,4 ), (5 ,6 ))) == ((1 , 3 , 5 ), (2 , 4 , 6 ))
63
+
64
+ # test_rrule(unzip, [(1,2), (3,4), (5.0,6.0)], check_inferred=false) # DimensionMismatch: second dimension of A, 6, does not match length of x, 2
65
+
28
66
y, bk = rrule (unzip, [(1 ,2 ), (3 ,4 ), (5 ,6 )])
29
67
@test y == ([1 , 3 , 5 ], [2 , 4 , 6 ])
30
68
@test bk (Tangent {Tuple} ([1 ,1 ,1 ], [10 ,100 ,1000 ]))[2 ] isa Vector{<: Tangent{<:Tuple} }
31
- end
32
-
33
- @testset " rrules" begin
34
- # These exist to allow for second derivatives
35
69
36
- # test_rrule(collect∘tuplecast, tuple, [1,2,3.], [4,5,6.], check_inferred=false)
37
- y1, bk1 = rrule (CFG, tuplecast, tuple, [1 ,2 ,3.0 ], [4 ,5 ,6.0 ])
38
- @test y1 == ([1 , 2 , 3 ], [4 , 5 , 6 ])
39
- @test bk1 (([1 ,10 ,100.0 ], [7 ,8 ,9.0 ]))[3 ] ≈ [1 ,10 ,100 ]
70
+ y3, bk3 = rrule (unzip, [(1 ,ZeroTangent ()), (3 ,ZeroTangent ()), (5 ,ZeroTangent ())])
71
+ @test y3 == ([1 , 3 , 5 ], [ZeroTangent (), ZeroTangent (), ZeroTangent ()])
72
+ dx3 = bk3 (Tangent {Tuple} ([1 ,1 ,1 ], [10 ,100 ,1000 ]))[2 ]
73
+ @test dx3 isa Vector{<: Tangent{<:Tuple} }
74
+ @test Tuple (dx3[1 ]) == (1.0 , NoTangent ())
40
75
41
- y2, bk2 = rrule (CFG, tuplecast, tuple, [1 ,2 ,3.0 ], [4 5.0 ], 6.0 )
42
- @test y2 == ([1 1 ; 2 2 ; 3 3 ], [4 5 ; 4 5 ; 4 5 ], [6 6 ; 6 6 ; 6 6 ])
43
- @test bk2 (y2)[5 ] ≈ 36
44
-
45
- test_rrule (unzip, [(1.0 , 2.0 ), (3.0 , 4.0 ), (5.0 , 6.0 )], check_inferred= false )
76
+ y5, bk5 = rrule (unzip, ((1 ,2 ), (3 ,4 ), (5 ,6 )))
77
+ @test y5 == ((1 , 3 , 5 ), (2 , 4 , 6 ))
78
+ @test bk5 (y5)[2 ] isa Tangent{<: Tuple }
79
+ @test Tuple (bk5 (y5)[2 ][2 ]) == (3 , 4 )
80
+ dx5 = bk5 (((1 ,10 ,100 ), ZeroTangent ()))
81
+ @test dx5[2 ] isa Tangent{<: Tuple }
82
+ @test Tuple (dx5[2 ][2 ]) == (10 , ZeroTangent ())
46
83
end
47
84
end
0 commit comments