1
1
using ChainRulesCore, Test
2
2
using LinearAlgebra, SparseArrays
3
- using OffsetArrays, BenchmarkTools
3
+ using OffsetArrays, StaticArrays, BenchmarkTools
4
4
5
5
# Like ForwardDiff.jl's Dual
6
6
struct Dual{T<: Real } <: Real
@@ -295,7 +295,7 @@ struct NoSuperType end
295
295
# ####
296
296
297
297
@testset " OffsetArrays" begin
298
- # While there is no code for this, the rule that it checks axes(x) == axes(dx) else
298
+ # While there is no code for this, the rule that it checks axes(x) === axes(dx) else
299
299
# reshape means that it restores offsets. (It throws an error on nontrivial size mismatch.)
300
300
301
301
poffv = ProjectTo (OffsetArray (rand (3 ), 0 : 2 ))
@@ -304,8 +304,34 @@ struct NoSuperType end
304
304
305
305
@test axes (poffv (OffsetArray (rand (3 ), 0 : 2 ))) == (0 : 2 ,)
306
306
@test axes (poffv (OffsetArray (rand (3 , 1 ), 0 : 2 , 0 : 0 ))) == (0 : 2 ,)
307
+
308
+ pvec3 = ProjectTo ([1 , 2 , 3 ])
309
+ @test axes (pvec3 (OffsetArray (rand (3 ), 0 : 2 ))) == (1 : 3 ,)
310
+ @test pvec3 (OffsetArray (rand (3 ), 0 : 2 )) isa Vector # relies on axes === axes test
311
+ @test pvec3 (OffsetArray (rand (3 ,1 ), 0 : 2 , 0 : 0 )) isa Vector
307
312
end
308
313
314
+ # ####
315
+ # #### `StaticArrays`
316
+ # ####
317
+
318
+ @testset " StaticArrays" begin
319
+ # There is no code for this, but when argument isa StaticArray, axes(x) === axes(dx)
320
+ # implies a check, and reshape will wrap a Vector into a static SizedVector:
321
+ pstat = ProjectTo (SA[1 , 2 , 3 ])
322
+ @test axes (pstat (rand (3 ))) === (SOneTo (3 ),)
323
+
324
+ # This recurses into structured arrays:
325
+ pst = ProjectTo (transpose (SA[1 , 2 , 3 ]))
326
+ @test axes (pst (rand (1 ,3 ))) === (SOneTo (1 ), SOneTo (3 ))
327
+ @test pst (rand (1 ,3 )) isa Transpose
328
+
329
+ # When the argument is an ordinary Array, static gradients are allowed to pass,
330
+ # like FillArrays. Collecting to an Array would cost a copy.
331
+ pvec3 = ProjectTo ([1 , 2 , 3 ])
332
+ @test pvec3 (SA[1 , 2 , 3 ]) isa StaticArray
333
+ end
334
+
309
335
# ####
310
336
# #### `ChainRulesCore`
311
337
# ####
0 commit comments