Skip to content

Commit e0ecaee

Browse files
authored
Compare axes with === before reshaping (#480)
* test axes === axes * comment etc * v1.7.3
1 parent 32c7dbf commit e0ecaee

File tree

3 files changed

+30
-4
lines changed

3 files changed

+30
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "1.7.2"
3+
version = "1.7.3"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/projection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ end
218218
function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
219219
# First deal with shape. The rule is that we reshape to add or remove trivial dimensions
220220
# like dx = ones(4,1), where x = ones(4), but throw an error on dx = ones(1,4) etc.
221-
dy = if axes(dx) == project.axes
221+
dy = if axes(dx) === project.axes
222222
dx
223223
else
224224
for d in 1:max(M, length(project.axes))

test/projection.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using ChainRulesCore, Test
22
using LinearAlgebra, SparseArrays
3-
using OffsetArrays, BenchmarkTools
3+
using OffsetArrays, StaticArrays, BenchmarkTools
44

55
# Like ForwardDiff.jl's Dual
66
struct Dual{T<:Real} <: Real
@@ -295,7 +295,7 @@ struct NoSuperType end
295295
#####
296296

297297
@testset "OffsetArrays" begin
298-
# While there is no code for this, the rule that it checks axes(x) == axes(dx) else
298+
# While there is no code for this, the rule that it checks axes(x) === axes(dx) else
299299
# reshape means that it restores offsets. (It throws an error on nontrivial size mismatch.)
300300

301301
poffv = ProjectTo(OffsetArray(rand(3), 0:2))
@@ -304,8 +304,34 @@ struct NoSuperType end
304304

305305
@test axes(poffv(OffsetArray(rand(3), 0:2))) == (0:2,)
306306
@test axes(poffv(OffsetArray(rand(3, 1), 0:2, 0:0))) == (0:2,)
307+
308+
pvec3 = ProjectTo([1, 2, 3])
309+
@test axes(pvec3(OffsetArray(rand(3), 0:2))) == (1:3,)
310+
@test pvec3(OffsetArray(rand(3), 0:2)) isa Vector # relies on axes === axes test
311+
@test pvec3(OffsetArray(rand(3,1), 0:2, 0:0)) isa Vector
307312
end
308313

314+
#####
315+
##### `StaticArrays`
316+
#####
317+
318+
@testset "StaticArrays" begin
319+
# There is no code for this, but when argument isa StaticArray, axes(x) === axes(dx)
320+
# implies a check, and reshape will wrap a Vector into a static SizedVector:
321+
pstat = ProjectTo(SA[1, 2, 3])
322+
@test axes(pstat(rand(3))) === (SOneTo(3),)
323+
324+
# This recurses into structured arrays:
325+
pst = ProjectTo(transpose(SA[1, 2, 3]))
326+
@test axes(pst(rand(1,3))) === (SOneTo(1), SOneTo(3))
327+
@test pst(rand(1,3)) isa Transpose
328+
329+
# When the argument is an ordinary Array, static gradients are allowed to pass,
330+
# like FillArrays. Collecting to an Array would cost a copy.
331+
pvec3 = ProjectTo([1, 2, 3])
332+
@test pvec3(SA[1, 2, 3]) isa StaticArray
333+
end
334+
309335
#####
310336
##### `ChainRulesCore`
311337
#####

0 commit comments

Comments
 (0)