Skip to content

Commit d5741e1

Browse files
committed
Don't ProjectTo in reverse or circshift
1 parent ddc4677 commit d5741e1

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.11.1"
3+
version = "1.11.2"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/array.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,9 @@ function frule((_, xdot), ::typeof(reverse), x::AbstractArray, args...; kw...)
303303
end
304304

305305
function rrule(::typeof(reverse), x::AbstractArray, args...; kw...)
306-
project = ProjectTo(x)
307306
nots = map(_ -> NoTangent(), args)
308307
function reverse_pullback(dy)
309-
dx = @thunk project(reverse(unthunk(dy), args...; kw...))
308+
dx = @thunk reverse(unthunk(dy), args...; kw...)
310309
return (NoTangent(), dx, nots...)
311310
end
312311
return reverse(x, args...; kw...), reverse_pullback
@@ -321,9 +320,8 @@ function frule((_, xdot), ::typeof(circshift), x::AbstractArray, shifts)
321320
end
322321

323322
function rrule(::typeof(circshift), x::AbstractArray, shifts)
324-
project = ProjectTo(x)
325323
function circshift_pullback(dy)
326-
dx = @thunk project(circshift(unthunk(dy), map(-, shifts)))
324+
dx = @thunk circshift(unthunk(dy), map(-, shifts))
327325
# Note that circshift! is useless for InplaceableThunk, as it overwrites completely
328326
return (NoTangent(), dx, NoTangent())
329327
end

test/rulesets/Base/array.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ end
171171

172172
# Structured
173173
y, pb = rrule(reverse, Diagonal([1,2,3]))
174-
@test unthunk(pb(rand(3,3))[2]) isa Diagonal
174+
# We only preserve structure in this case if given structured tangent (no ProjectTo)
175+
@test unthunk(pb(Diagonal([1.1, 2.1, 3.1]))[2]) isa Diagonal
176+
@test unthunk(pb(rand(3, 3))[2]) isa AbstractArray
175177
end
176178
end
177179

0 commit comments

Comments
 (0)