Skip to content

Commit 9db4c9c

Browse files
authored
Merge pull request #534 from JuliaDiff/ox/sumbr
Handle sum(f,xs) where f returns a Vector
2 parents d76945e + 7b4a3f1 commit 9db4c9c

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,10 @@ function rrule(
7474
project = ProjectTo(xs)
7575

7676
function sum_pullback(ȳ)
77-
call(f, x) = f(x) # we need to broadcast this to handle dims kwarg
78-
f̄_and_x̄s = call.(pullbacks, ȳ)
77+
call(f, x) = f(x)
78+
# if dims is :, then need only left-handed only broadcast
79+
broadcast_ȳ = dims isa Colon ? (ȳ,) :
80+
f̄_and_x̄s = call.(pullbacks, broadcast_ȳ)
7981
# no point thunking as most of work is in f̄_and_x̄s which we need to compute for both
8082
= if fieldcount(typeof(f)) === 0 # Then don't need to worry about derivative wrt f
8183
NoTangent()

test/rulesets/Base/mapreduce.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,15 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end
9292
test_rrule(sum, sqrt, randn(5,5) .> 0; fkwargs=(;dims=1))
9393
# ... and Bool produced by function
9494
@test_skip test_rrule(sum, iszero, randn(5)) # DimensionMismatch("second dimension of A, 1, does not match length of x, 0")
95+
96+
97+
# Functions that return a Vector
98+
# see https://github.com/FluxML/Zygote.jl/issues/1074
99+
test_rrule(sum, make_two_vec, [1.0, 3.0, 5.0, 7.0])
100+
test_rrule(sum, make_two_vec, [1.0 2.0; 3.0 4.0])
101+
test_rrule(sum, make_two_vec, [1.0 2.0; 3.0 4.0]; fkwargs=(;dims=2))
102+
test_rrule(sum, make_two_vec, [1.0 2.0; 3.0 4.0]; fkwargs=(;dims=1))
103+
test_rrule(sum, make_two_vec, [1.0 2.0; 3.0 4.0]; fkwargs=(;dims=(3, 4)))
95104
end
96105

97106
# https://github.com/JuliaDiff/ChainRules.jl/issues/522

test/test_helpers.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,18 @@ end
1111
# NoRules - has no rules defined
1212
struct NoRules; end
1313

14+
"A function that outputs a vector from a scalar for testing"
15+
make_two_vec(x) = [x, x]
16+
function ChainRulesCore.rrule(::typeof(make_two_vec), x)
17+
make_two_vec_pullback(ȳ) = (NoTangent(), sum(ȳ))
18+
return make_two_vec(x), make_two_vec_pullback
19+
end
20+
1421
@testset "test_helpers.jl" begin
15-
@testset "Multiplier functor test-helper" begin
22+
@testset "Multiplier functor" begin
1623
test_rrule(Multiplier(4.0), 3.0)
1724
end
18-
end
25+
@testset "make_two_vec" begin
26+
test_rrule(make_two_vec, 1.5)
27+
end
28+
end

0 commit comments

Comments
 (0)