@@ -8,10 +8,10 @@ using Statistics: mean, std
8
8
using Random
9
9
# using StatsBase
10
10
11
- gradtest (f, xs:: AbstractArray... ) = gradcheck ((xs... ) -> sum (sin .(f (xs... ))), xs... )
12
- gradtest (f, dims... ) = gradtest (f, rand .(Float64, dims)... )
11
+ gradtest (f, xs:: AbstractArray... ; kw ... ) = gradcheck ((xs... ) -> sum (sin .(f (xs... ))), xs... ; kw ... )
12
+ gradtest (f, dims... ; kw ... ) = gradtest (f, rand .(Float64, dims)... ; kw ... )
13
13
14
- @testset " Tracker " begin # overall testset, rest of the file
14
+ @testset " gradtests 1 " begin
15
15
16
16
@test gradtest ((x, W, b) -> σ .(W* x .+ b), 5 , (2 ,5 ), 2 )
17
17
@test gradtest ((x, W) -> σ .(W* x), 5 , (2 ,5 ))
45
45
@test gradtest (logdet, map ((x) -> x* x' , (rand (4 , 4 ),))[1 ])
46
46
@test gradtest ((x) -> logabsdet (x)[1 ], (4 , 4 ))
47
47
48
+ end # @testset gradtests
49
+
48
50
@testset " indexing & slicing" begin
49
- gradtest (x-> view (x, 1 : 2 , 1 : 2 ), rand (4 , 4 ))
51
+ @test gradtest (x-> view (x, 1 : 2 , 1 : 2 ), rand (4 , 4 ))
50
52
end
51
53
52
54
function promotiontest (f, A, B, C)
53
55
r0 = f (A, B, C)
54
56
r1 = f (param (A), B, C)
55
57
r2 = f (A, param (B), C)
56
- r3 = f (A, B, param (C))
58
+ # r3 = f(A, B, param(C)) # no longer cater to tracked array in 3rd position
57
59
r4 = f (param (A), param (B), param (C))
58
60
59
61
@test ! isa (r0, TrackedArray)
60
- @test all (isa .([r1,r2,r3,r4], TrackedArray))
61
- @test r1 == r2 == r3 == r4
62
+ # @test all(isa.([r1,r2,r3,r4], TrackedArray))
63
+ # @test r1 == r2 == r3 == r4
64
+ @test all (isa .([r1,r2,r4], TrackedArray))
65
+ @test r1 == r2 == r4
62
66
@test r0 == Tracker. data (r4)
63
67
end
64
68
68
72
rvcat (x... ) = reduce (vcat, x)
69
73
rhcat (x... ) = reduce (hcat, x)
70
74
71
- @testset for vcatf in [vcat, cat1, rvcat]
75
+ @testset " 2-arg $vcatf " for vcatf in [vcat, cat1, rvcat]
72
76
@test gradtest (vcatf, rand (5 ), rand (3 ))
73
77
@test gradtest (vcatf, rand (5 ), rand (3 ), rand (8 ))
74
78
@test gradtest (vcatf, rand (5 )' , rand (5 )' )
79
83
end
80
84
81
85
82
- @testset for hcatf in [hcat, cat2, rhcat]
86
+ @testset " 2-arg $hcatf " for hcatf in [hcat, cat2, rhcat]
83
87
@test gradtest (hcatf, rand (5 ), rand (5 ))
84
88
@test gradtest (hcatf, rand (5 )' , rand (5 )' )
85
89
@test gradtest (hcatf, rand (2 ,5 ), rand (2 ,3 ), rand (2 ,8 ))
89
93
@test gradtest (hcatf, rand (5 ), rand (5 ,2 ))
90
94
end
91
95
92
- @testset for catf in [vcat, cat1, rvcat, hcat, cat2, rhcat, (x... ) -> cat (x... , dims = 3 ), (x... ) -> cat (x... , dims = (1 ,2 ))]
96
+ @testset " 1-arg $catf " for catf in [vcat, cat1, rvcat, hcat, cat2, rhcat, (x... ) -> cat (x... , dims = 3 ), (x... ) -> cat (x... , dims = (1 ,2 ))]
93
97
@test gradtest (catf, rand (5 ))
94
98
@test gradtest (catf, rand (5 )' )
95
99
@test gradtest (catf, rand (2 ,5 ))
133
137
@test hcat (1 , param ([1 2 3 ;])) isa TrackedArray
134
138
@test vcat (param (1 ), 2 ) isa TrackedArray
135
139
end
140
+
141
+ @testset " ambiguities" begin
142
+ @test vcat (param ([1 , 2 , 3 ]), [2 ,3 ]) isa TrackedArray
143
+ @test vcat (param ([1 , 2 , 3 ]), [2.0 , 3.0 ]) isa TrackedArray
144
+ @test hcat (param ([1 2 3 ]), [2 , 3 ]' ) isa TrackedArray
145
+ @test hcat (param ([1 2 3 ]), [2.0 , 3.0 ]' ) isa TrackedArray
146
+ end
136
147
137
148
end
138
149
141
152
@test gradtest (x-> x[z], randn (MersenneTwister (123456 ), 3 ))
142
153
end
143
154
155
+ @testset " gradtests 2" begin
156
+
144
157
@test gradtest (x -> permutedims (x, [3 ,1 ,2 ]), rand (4 ,5 ,6 ))
145
158
@test gradtest (x -> PermutedDimsArray (x, [3 ,1 ,2 ]), rand (4 ,5 ,6 ))
146
159
159
172
@test gradtest (kron, rand (5 ,2 ), rand (3 ,2 ), rand (8 ,2 ))
160
173
161
174
@test gradtest (x -> diagm (0 => x), rand (3 ))
175
+ @test gradtest (x -> Matrix (Diagonal (x)), rand (3 ))
162
176
163
177
@test gradtest (W -> inv (log .(W * W)), (5 ,5 ))
164
178
@test gradtest ((A, B) -> A / B , (1 ,5 ), (5 ,5 ))
178
192
gradtest (A -> log .(A * A) \ exp .(B * B), (5 , 5 ))
179
193
end
180
194
195
+ end # @testset "gradtests 2"
196
+
181
197
@testset " mean" begin
182
198
@test gradtest (mean, rand (2 , 3 ))
183
199
208
224
@test gradtest (x -> minimum (x, dims= [1 , 2 ]), rand (2 , 3 , 4 ))
209
225
end
210
226
227
+ @testset " gradtests 3" begin
228
+
211
229
@test gradtest (x -> std (x), rand (5 ,5 ))
212
230
@test gradtest (x -> std (x, dims = 1 ), rand (5 ,5 ))
213
231
@test gradtest (x -> std (x, dims = 1 , corrected = false ), rand (5 ,5 ))
224
242
2 y + x
225
243
end
226
244
245
+ end # @testset "gradtests 3"
246
+
227
247
@testset " transpose" begin
228
248
w = Tracker. TrackedArray (rand (5 ,5 ))
229
249
x = Tracker. TrackedArray (rand (5 ,5 ))
@@ -299,17 +319,15 @@ end
299
319
@test transpose (w)* transpose (x) isa TrackedArray
300
320
end
301
321
302
- @testset " conv" begin
303
- for spatial_rank in (1 , 2 , 3 )
322
+ @testset " conv, $(spatial_rank) d" for spatial_rank in (1 , 2 , 3 )
304
323
x = rand (repeat ([10 ], spatial_rank)... , 3 , 2 )
305
324
w = rand (repeat ([3 ], spatial_rank)... , 3 , 3 )
306
325
cdims = DenseConvDims (x, w)
307
326
@test gradtest ((x, w) -> conv (x, w, cdims), x, w)
308
327
y = conv (x, w, cdims)
309
328
@test gradtest ((y, w) -> ∇conv_data (y, w, cdims), y, w)
310
329
dcdims = DepthwiseConvDims (x, w)
311
- @test gradtest ((x, w) -> depthwiseconv (x, w, dcdims), x, w)
312
- end
330
+ @test_skip gradtest ((x, w) -> depthwiseconv (x, w, dcdims), x, w)
313
331
end
314
332
315
333
@testset " pooling" begin
321
339
end
322
340
end
323
341
324
-
325
342
@test gradtest (x -> Float64 .(x), 5 )
326
343
327
344
@testset " equality & order" begin
480
497
@test size (y) == (5 , 3 )
481
498
end
482
499
483
- end # overall testset
0 commit comments