Skip to content

Commit 6df028a

Browse files
authored
Merge pull request #489 from mcabbott/reverse
Add rules for `reverse`, `circshift`
2 parents 2ba06a9 + eb3570d commit 6df028a

File tree

2 files changed

+92
-11
lines changed

2 files changed

+92
-11
lines changed

src/rulesets/Base/array.jl

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -262,19 +262,55 @@ function rrule(::typeof(hvcat), rows, values::Union{AbstractArray, Number}...)
262262
end
263263

264264
#####
265-
##### `fill`
265+
##### `reverse`
266266
#####
267267

268-
function rrule(::typeof(fill), value::Any, dims::Tuple{Vararg{Int}})
269-
function fill_pullback(Ȳ)
270-
return (NoTangent(), sum(Ȳ), NoTangent())
268+
# 1-dim case allows start/stop, N-dim case takes dims keyword
269+
# whose defaults changed in Julia 1.6... just pass them all through:
270+
271+
function frule((_, xdot), ::typeof(reverse), x::AbstractArray, args...; kw...)
272+
return reverse(x, args...; kw...), reverse(xdot, args...; kw...)
273+
end
274+
275+
function rrule(::typeof(reverse), x::AbstractArray, args...; kw...)
276+
project = ProjectTo(x)
277+
nots = map(_ -> NoTangent(), args)
278+
function reverse_pullback(dy)
279+
dx = @thunk project(reverse(unthunk(dy), args...; kw...))
280+
return (NoTangent(), dx, nots...)
271281
end
272-
return fill(value, dims), fill_pullback
282+
return reverse(x, args...; kw...), reverse_pullback
283+
end
284+
285+
#####
286+
##### `circshift`
287+
#####
288+
289+
function frule((_, xdot), ::typeof(circshift), x::AbstractArray, shifts)
290+
return circshift(x, shifts), circshift(xdot, shifts)
273291
end
274292

275-
function rrule(::typeof(fill), value::Any, dims::Int...)
276-
function fill_pullback(Ȳ)
277-
return (NoTangent(), sum(Ȳ), ntuple(_->NoTangent(), length(dims))...)
293+
function rrule(::typeof(circshift), x::AbstractArray, shifts)
294+
project = ProjectTo(x)
295+
function circshift_pullback(dy)
296+
dx = @thunk project(circshift(unthunk(dy), map(-, shifts)))
297+
# Note that circshift! is useless for InplaceableThunk, as it overwrites completely
298+
return (NoTangent(), dx, NoTangent())
278299
end
279-
return fill(value, dims), fill_pullback
300+
return circshift(x, shifts), circshift_pullback
301+
end
302+
303+
#####
304+
##### `fill`
305+
#####
306+
307+
function frule((_, xdot), ::typeof(fill), x::Any, dims...)
308+
return fill(x, dims...), fill(xdot, dims...)
309+
end
310+
311+
function rrule(::typeof(fill), x::Any, dims...)
312+
project = x isa Union{Number, AbstractArray{<:Number}} ? ProjectTo(x) : identity
313+
nots = map(_ -> NoTangent(), dims)
314+
fill_pullback(Ȳ) = (NoTangent(), project(sum(Ȳ)), nots...)
315+
return fill(x, dims...), fill_pullback
280316
end

test/rulesets/Base/array.jl

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,52 @@ end
101101
test_rrule(hvcat, 1, rand(3)', transpose(rand(3)) rand(1,3); check_inferred=VERSION>v"1.1")
102102
end
103103

104+
@testset "reverse" begin
105+
# Forward
106+
test_frule(reverse, rand(5))
107+
test_frule(reverse, rand(5), 2, 4)
108+
test_frule(reverse, rand(5), fkwargs=(dims=1,))
109+
110+
test_frule(reverse, rand(3,4), fkwargs=(dims=2,))
111+
if VERSION >= v"1.6"
112+
test_frule(reverse, rand(3,4))
113+
test_frule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),))
114+
end
115+
116+
# Reverse
117+
test_rrule(reverse, rand(5))
118+
test_rrule(reverse, rand(5), 2, 4)
119+
test_rrule(reverse, rand(5), fkwargs=(dims=1,))
120+
121+
test_rrule(reverse, rand(3,4), fkwargs=(dims=2,))
122+
if VERSION >= v"1.6"
123+
test_rrule(reverse, rand(3,4))
124+
test_rrule(reverse, rand(3,4,5), fkwargs=(dims=(1,3),))
125+
126+
# Structured
127+
y, pb = rrule(reverse, Diagonal([1,2,3]))
128+
@test unthunk(pb(rand(3,3))[2]) isa Diagonal
129+
end
130+
end
131+
132+
@testset "circshift" begin
133+
# Forward
134+
test_frule(circshift, rand(10), 1)
135+
test_frule(circshift, rand(10), (1,))
136+
test_frule(circshift, rand(3,4), (-7,2))
137+
138+
# Reverse
139+
test_rrule(circshift, rand(10), 1)
140+
test_rrule(circshift, rand(10) .+ im, -2)
141+
test_rrule(circshift, rand(10), (1,))
142+
test_rrule(circshift, rand(3,4), (-7,2))
143+
end
144+
104145
@testset "fill" begin
105-
test_rrule(fill, 44.0, 4; check_inferred=false)
106-
test_rrule(fill, 2.0, (3, 3, 3))
146+
test_frule(fill, 12.3, 4)
147+
test_frule(fill, 5.0, (6, 7))
148+
149+
test_rrule(fill, 44.4, 4)
150+
test_rrule(fill, 55 + 0.5im, 5)
151+
test_rrule(fill, 3.3, (3, 3, 3))
107152
end

0 commit comments

Comments
 (0)