@@ -217,16 +217,24 @@ end
217
217
# DimensionMismatch("second dimension of A, 6, does not match length of x, 5")
218
218
# Probably similar to https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/234 (about Broadcasted not Generator)
219
219
220
- test_rrule (collect∘ eachrow, rand (5 ))
221
- test_rrule (collect∘ eachrow, rand (3 , 4 ))
220
+ # Inference on 1.6 sometimes fails, so don't enforce there.
221
+ test_rrule (collect ∘ eachrow, rand (5 ); check_inferred= (VERSION >= v " 1.7" ))
222
+ test_rrule (collect ∘ eachrow, rand (3 , 4 ); check_inferred= (VERSION >= v " 1.7" ))
222
223
223
- test_rrule (collect∘ eachcol, rand (3 , 4 ))
224
- @test_skip test_rrule (collect∘ eachcol, Diagonal (rand (5 ))) # works locally!
224
+ test_rrule (collect ∘ eachcol, rand (3 , 4 ); check_inferred = ( VERSION >= v " 1.7 " ))
225
+ @test_skip test_rrule (collect ∘ eachcol, Diagonal (rand (5 ))) # works locally!
225
226
226
227
if VERSION >= v " 1.7"
227
228
# On 1.6, ComposedFunction doesn't take keywords. Only affects this testing strategy, not real use.
228
- test_rrule (collect∘ eachslice, rand (3 , 4 , 5 ); fkwargs = (; dims = 3 ))
229
- test_rrule (collect∘ eachslice, rand (3 , 4 , 5 ); fkwargs = (; dims = (2 ,)))
229
+ test_rrule (collect ∘ eachslice, rand (3 , 4 , 5 ); fkwargs= (; dims= 3 ))
230
+ test_rrule (collect ∘ eachslice, rand (3 , 4 , 5 ); fkwargs= (; dims= (2 ,)))
231
+
232
+ test_rrule (
233
+ collect ∘ eachslice,
234
+ FooTwoField .(rand (3 , 4 , 5 ), rand (3 , 4 , 5 ));
235
+ check_inferred= false ,
236
+ fkwargs= (; dims= 3 ),
237
+ )
230
238
end
231
239
232
240
# Make sure pulling back an array that mixes some AbstractZeros in works right
235
243
@test back ([1 : 3 , ZeroTangent (), 7 : 9 , NoTangent ()])[2 ] isa Matrix{Float64}
236
244
@test back ([ZeroTangent (), ZeroTangent (), NoTangent (), NoTangent ()]) == (NoTangent (), [0 0 0 0 ; 0 0 0 0 ; 0 0 0 0 ])
237
245
246
+ _, back = ChainRules. rrule (
247
+ eachslice, FooTwoField .(rand (2 , 3 , 2 ), rand (2 , 3 , 2 )); dims= 3
248
+ )
249
+ @test back ([fill (Tangent {Any} (; x= 0.0 , y= 1.0 ), 2 , 3 ), fill (ZeroTangent (), 2 , 3 )]) == (
250
+ NoTangent (),
251
+ cat (fill (Tangent {Any} (; x= 0.0 , y= 1.0 ), 2 , 3 ), fill (ZeroTangent (), 2 , 3 ); dims= 3 ),
252
+ )
253
+
238
254
# Second derivative rule
239
255
test_rrule (ChainRules.∇eachslice, [rand (4 ) for _ in 1 : 3 ], rand (3 , 4 ), Val (1 ))
240
256
test_rrule (ChainRules.∇eachslice, [rand (3 ) for _ in 1 : 4 ], rand (3 , 4 ), Val (2 ))
241
- test_rrule (ChainRules.∇eachslice, [rand (2 , 3 ) for _ in 1 : 4 ], rand (2 , 3 , 4 ), Val (3 ), check_inferred= false )
257
+ test_rrule (
258
+ ChainRules.∇eachslice,
259
+ [rand (2 , 3 ) for _ in 1 : 4 ],
260
+ rand (2 , 3 , 4 ),
261
+ Val (3 );
262
+ check_inferred= (VERSION >= v " 1.7" ),
263
+ )
242
264
end
0 commit comments