Skip to content

Commit 7faaf5d

Browse files
authored
Rule for map(f, ::Tuple...) (#642)
* map for tuples, take 1 * fixup * sum(f, ::Tuple) rule * fix tests for map * test for mismatched lengths * add a warning for mistmatched lengths * inference test etc * unthunk, rename, etc * version
1 parent 8073c7c commit 7faaf5d

File tree

5 files changed

+69
-1
lines changed

5 files changed

+69
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.38.0"
3+
version = "1.39.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/base.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,44 @@ function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{3})
200200
cube_pullback(dy) = (NoTangent(), NoTangent(), ProjectTo(x)(3 * x2 * dy), NoTangent())
201201
return x2 * x, cube_pullback
202202
end
203+
204+
#####
205+
##### `map`
206+
#####
207+
208+
# Ideally reverse mode should always iterate in reverse order. For `map` and broadcasting
209+
# this may matter with a stateful `f`, but in general their order isn't guaranteed anyway,
210+
# so it's unclear how much effort should be spent on that. But `map` on Tuples normally
211+
# gets unrolled, so perhaps it does guarantee order, and reversing it should be cheap too.
212+
213+
function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tuple...) where {F}
214+
length_y = minimum(length, xs)
215+
hobbits = ntuple(length_y) do i
216+
args = getindex.(xs, i)
217+
rrule_via_ad(config, f, args...)
218+
end
219+
y = map(first, hobbits)
220+
num_xs = Val(length(xs))
221+
paddings = map(x -> ntuple(Returns(NoTangent()), (length(x) - length_y)), xs)
222+
all(isempty, paddings) || @error """map(f, xs::Tuple...) does not allow mistmatched lengths!
223+
But its `rrule` does; when JuliaLang/julia #42216 is fixed this warning should be removed."""
224+
function map_pullback(dy_raw)
225+
dy = unthunk(dy_raw)
226+
# We want to call the pullbacks in `rrule_via_ad` in reverse sequence to the forward pass:
227+
backevals = ntuple(length_y) do i
228+
rev_i = length_y - i + 1
229+
last(hobbits[rev_i])(dy[rev_i])
230+
end |> reverse
231+
# This df doesn't infer, could test Base.issingletontype(F), but it's not the only inference problem.
232+
df = ProjectTo(f)(sum(first, backevals))
233+
# Now unzip that. Because `map` like `zip` should when any `x` stops, some `dx`s may need padding.
234+
# Although in fact, `map(+, (1,2), (3,4,5))` is an error... https://github.com/JuliaLang/julia/issues/42216
235+
dxs = ntuple(num_xs) do k
236+
dx_short = map(bv -> bv[k+1], backevals)
237+
ProjectTo(xs[k])((dx_short..., paddings[k]...)) # ProjectTo makes the Tangent for us
238+
end
239+
return (NoTangent(), df, dxs...)
240+
end
241+
map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...)
242+
return y, map_pullback
243+
end

src/rulesets/Base/mapreduce.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,17 @@ end
6262
##### `sum(f, x)`
6363
#####
6464

65+
function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f::F, xs::Tuple) where {F}
66+
fxs, unmap = rrule(config, map, f, xs)
67+
y, unsum = rrule(config, sum, fxs)
68+
function sum_pullback_f(dy)
69+
_, dfxs = unsum(dy)
70+
_, df, dxs = unmap(dfxs)
71+
(NoTangent(), df, dxs)
72+
end
73+
y, sum_pullback_f
74+
end
75+
6576
function rrule(
6677
config::RuleConfig{>:HasReverseMode},
6778
::typeof(sum),

test/rulesets/Base/base.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,4 +216,16 @@
216216
@test frule(NoRules, 1.0) === nothing
217217
@test rrule(NoRules, 1.0) === nothing
218218
end
219+
220+
@testset "map(f, ::Tuple...)" begin
221+
test_rrule(map, identity, (1.0, 2.0), check_inferred=false)
222+
test_rrule(map, +, (1.0, 2.0), (3.0, 4.0), check_inferred=false)
223+
test_rrule(map, make_two_vec, (4.0, 5.0 + 6im), check_inferred=false)
224+
test_rrule(map, Multiplier(rand() + im), Tuple(rand(3)), check_inferred=false)
225+
226+
if try map(+, (1,), (2,3)); true catch e; false end
227+
# True when https://github.com/JuliaLang/julia/issues/42216 has been fixed
228+
test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false)
229+
end
230+
end
219231
end

test/rulesets/Base/mapreduce.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
6767
end
6868
end # sum abs2
6969

70+
@testset "sum(f, xs::Tuple)" begin
71+
test_rrule(sum, sqrt, Tuple(rand(3)), check_inferred=false)
72+
end
73+
7074
@testset "sum(f, xs)" begin
7175
# This calls back into AD
7276
test_rrule(sum, abs, [-4.0, 2.0, 2.0])

0 commit comments

Comments
 (0)