68
68
69
69
@testset " reshape" begin
70
70
# Forward
71
- test_frule (reshape, rand (4 , 3 ), 2 , :)
71
+ @gpu test_frule (reshape, rand (4 , 3 ), 2 , :)
72
72
test_frule (reshape, rand (4 , 3 ), axes (rand (6 , 2 )))
73
73
@test_skip test_frule (reshape, Diagonal (rand (4 )), 2 , :) # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/239
74
74
75
75
# Reverse
76
- test_rrule (reshape, rand (4 , 5 ), (2 , 10 ))
76
+ @gpu test_rrule (reshape, rand (4 , 5 ), (2 , 10 ))
77
77
test_rrule (reshape, rand (4 , 5 ), 2 , 10 )
78
78
test_rrule (reshape, rand (4 , 5 ), 2 , :)
79
79
test_rrule (reshape, rand (4 , 5 ), axes (rand (10 , 2 )))
98
98
99
99
@testset " permutedims + PermutedDimsArray" begin
100
100
# Forward
101
- test_frule (permutedims, rand (5 ))
102
- test_frule (permutedims, rand (3 , 4 ), (2 , 1 ))
101
+ @gpu test_frule (permutedims, rand (5 ))
102
+ @gpu test_frule (permutedims, rand (3 , 4 ), (2 , 1 ))
103
103
test_frule (permutedims!, rand (4 ,3 ), rand (3 , 4 ), (2 , 1 ))
104
104
test_frule (PermutedDimsArray, rand (3 , 4 , 5 ), (3 , 1 , 2 ))
105
105
106
106
# Reverse
107
- test_rrule (permutedims, rand (5 ))
108
- test_rrule (permutedims, rand (3 , 4 ), (2 , 1 ))
107
+ @gpu test_rrule (permutedims, rand (5 ))
108
+ @gpu test_rrule (permutedims, rand (3 , 4 ), (2 , 1 ))
109
109
test_rrule (permutedims, Diagonal (rand (5 )), (2 , 1 ))
110
110
# Note BTW that permutedims(Diagonal(rand(5))) does not use the rule at all
111
111
@@ -127,12 +127,12 @@ end
127
127
test_rrule (repeat, rand (4 , ))
128
128
test_rrule (repeat, rand (4 , 5 ))
129
129
test_rrule (repeat, rand (4 , 5 ); fkwargs = (outer= (1 ,2 ),))
130
- test_rrule (repeat, rand (4 , 5 ); fkwargs = (inner= (1 ,2 ), outer= (1 ,3 )))
131
- test_rrule (repeat, rand (4 , 5 ); fkwargs = (outer= 2 ,))
130
+ @gpu_broken test_rrule (repeat, rand (4 , 5 ); fkwargs = (inner= (1 ,2 ), outer= (1 ,3 )))
131
+ @gpu_broken test_rrule (repeat, rand (4 , 5 ); fkwargs = (outer= 2 ,))
132
132
133
- test_rrule (repeat, rand (4 , ), 2 )
134
- test_rrule (repeat, rand (4 , 5 ), 2 )
135
- test_rrule (repeat, rand (4 , 5 ), 2 , 3 )
133
+ @gpu test_rrule (repeat, rand (4 , ), 2 )
134
+ @gpu test_rrule (repeat, rand (4 , 5 ), 2 )
135
+ @gpu test_rrule (repeat, rand (4 , 5 ), 2 , 3 )
136
136
test_rrule (repeat, rand (1 ,2 ,3 ), 2 ,3 ,4 ; check_inferred= VERSION > v " 1.6" )
137
137
test_rrule (repeat, rand (0 ,2 ,3 ), 2 ,0 ,4 ; check_inferred= VERSION > v " 1.6" )
138
138
test_rrule (repeat, rand (1 ,1 ,1 ,1 ), 2 ,3 ,4 ,5 ; check_inferred= VERSION > v " 1.6" )
@@ -153,16 +153,16 @@ end
153
153
154
154
@test rrule (repeat, [1 ,2 ,3 ], 4 )[2 ](ones (12 ))[2 ] == [4 ,4 ,4 ]
155
155
@test rrule (repeat, [1 ,2 ,3 ], outer= 4 )[2 ](ones (12 ))[2 ] == [4 ,4 ,4 ]
156
-
157
156
end
158
157
159
158
@testset " hcat" begin
160
159
# forward
161
- test_frule (hcat, randn (3 , 2 ), randn (3 ))
162
- test_frule (hcat, randn (), randn (1 ,3 ))
160
+ @gpu test_frule (hcat, randn (3 , 2 ), randn (3 ))
161
+ @gpu test_frule (hcat, randn (), randn (1 ,3 ))
163
162
164
163
# reverse
165
- test_rrule (hcat, randn (3 , 2 ), randn (3 ), randn (3 , 3 ))
164
+ @gpu test_rrule (hcat, randn (3 , 2 ), randn (3 ), randn (3 , 3 ))
165
+ @gpu test_rrule (hcat, rand (1 ,2 ), rand (), rand (1 ,3 ))
166
166
test_rrule (hcat, rand (), rand (1 ,2 ), rand (1 ,2 ,1 ))
167
167
test_rrule (hcat, rand (3 ,1 ,1 ,2 ), rand (3 ,3 ,1 ,2 ))
168
168
@@ -194,13 +194,14 @@ end
194
194
end
195
195
196
196
@testset " vcat" begin
197
-
198
197
# forward
199
198
test_frule (vcat, randn (), randn (3 ), rand ())
200
- test_frule (vcat, randn (3 , 1 ), randn (3 ))
199
+ @gpu test_frule (vcat, randn (3 ), rand (), randn (3 ))
200
+ @gpu test_frule (vcat, randn (3 , 1 ), randn (3 ))
201
201
202
202
# reverse
203
- test_rrule (vcat, randn (2 , 4 ), randn (1 , 4 ), randn (3 , 4 ))
203
+ @gpu test_rrule (vcat, randn (3 ), rand (), randn (3 ))
204
+ @gpu test_rrule (vcat, randn (2 , 4 ), randn (1 , 4 ), randn (3 , 4 ))
204
205
test_rrule (vcat, rand (), rand ())
205
206
test_rrule (vcat, rand (), rand (3 ), rand (3 ,1 ,1 ))
206
207
test_rrule (vcat, rand (3 ,1 ,2 ), rand (4 ,1 ,2 ))
230
231
test_frule (cat, rand (), rand (2 ,3 ); fkwargs= (dims= (1 ,2 ),))
231
232
232
233
# reverse
233
- test_rrule (cat, rand (2 , 4 ), rand (1 , 4 ); fkwargs= (dims= 1 ,))
234
- test_rrule (cat, rand (2 , 4 ), rand (2 ); fkwargs= (dims= Val (2 ),))
234
+ @gpu test_rrule (cat, rand (2 , 4 ), rand (1 , 4 ); fkwargs= (dims= 1 ,))
235
+ @gpu test_rrule (cat, rand (2 , 4 ), rand (2 ); fkwargs= (dims= Val (2 ),))
235
236
test_rrule (cat, rand (), rand (2 , 3 ); fkwargs= (dims= [1 ,2 ],))
236
237
test_rrule (cat, rand (1 ), rand (3 , 2 , 1 ); fkwargs= (dims= (1 ,2 ),), check_inferred= false ) # infers Tuple{Zero, Vector{Float64}, Any}
237
238
263
264
end
264
265
@testset " Array" begin
265
266
# Forward
266
- test_frule (reverse, rand (5 ))
267
+ @gpu_broken test_frule (reverse, rand (5 ))
267
268
test_frule (reverse, rand (5 ), 2 , 4 )
268
269
test_frule (reverse, rand (5 ), fkwargs= (dims= 1 ,))
269
270
test_frule (reverse, rand (3 ,4 ), fkwargs= (dims= 2 ,))
275
276
test_frule (reverse!, rand (3 ,4 ), fkwargs= (dims= 2 ,))
276
277
277
278
# Reverse
278
- test_rrule (reverse, rand (5 ))
279
+ @gpu_broken test_rrule (reverse, rand (5 ))
279
280
test_rrule (reverse, rand (5 ), 2 , 4 )
280
281
test_rrule (reverse, rand (5 ), fkwargs= (dims= 1 ,))
281
282
@@ -293,15 +294,15 @@ end
293
294
294
295
@testset " circshift" begin
295
296
# Forward
296
- test_frule (circshift, rand (10 ), 1 )
297
+ @gpu test_frule (circshift, rand (10 ), 1 )
297
298
test_frule (circshift, rand (10 ), (1 ,))
298
299
test_frule (circshift, rand (3 ,4 ), (- 7 ,2 ))
299
300
300
301
test_frule (circshift!, rand (10 ), rand (10 ), 1 )
301
302
test_frule (circshift!, rand (3 ,4 ), rand (3 ,4 ), (- 7 ,2 ))
302
303
303
304
# Reverse
304
- test_rrule (circshift, rand (10 ), 1 )
305
+ @gpu test_rrule (circshift, rand (10 ), 1 )
305
306
test_rrule (circshift, rand (10 ) .+ im, - 2 )
306
307
test_rrule (circshift, rand (10 ), (1 ,))
307
308
test_rrule (circshift, rand (3 ,4 ), (- 7 ,2 ))
@@ -379,14 +380,14 @@ end
379
380
# Forward
380
381
test_frule (imum, rand (10 ))
381
382
test_frule (imum, rand (3 ,4 ))
382
- test_frule (imum, rand (3 ,4 ), fkwargs= (dims= 1 ,))
383
+ @gpu_broken test_frule (imum, rand (3 ,4 ), fkwargs= (dims= 1 ,))
383
384
test_frule (imum, [rand (2 ) for _ in 1 : 3 ])
384
385
test_frule (imum, [rand (2 ) for _ in 1 : 3 , _ in 1 : 4 ]; fkwargs= (dims= 1 ,))
385
386
386
387
# Reverse
387
388
test_rrule (imum, rand (10 ))
388
389
test_rrule (imum, rand (3 ,4 ))
389
- test_rrule (imum, rand (3 ,4 ), fkwargs= (dims= 1 ,))
390
+ @gpu_broken test_rrule (imum, rand (3 ,4 ), fkwargs= (dims= 1 ,))
390
391
test_rrule (imum, rand (3 ,4 ,5 ), fkwargs= (dims= (1 ,3 ),))
391
392
392
393
# Arrays of arrays
0 commit comments