Skip to content

Commit 53d5f9e

Browse files
committed
Make frule return a scalar
1 parent 55dcefe commit 53d5f9e

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

src/rule_definition_tools.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,17 @@ macro scalar_rule(call, maybe_setup, partials...)
208208
∂s = partials[output_i].args
209209
propagation_expr(𝒟, Δs, ∂s)
210210
end
211-
211+
if n_outputs > 1
212+
# For forward-mode we only return a tuple if output actually a tuple.
213+
pushforward_returns = Expr(:tuple, pushforward_returns...)
214+
else
215+
pushforward_returns = pushforward_returns[1]
216+
end
212217
quote
213218
# _ is the input derivative w.r.t. function internals. since we do not
214219
# allow closures/functors with @scalar_rule, it is always ignored
215220
function $(propagator_name(f, :pushforward))(_, $(Δs...))
216-
return $(Expr(:tuple, pushforward_returns...))
221+
$pushforward_returns
217222
end
218223
end
219224
end

test/rules.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
3131

3232
frx, cool_pushforward = frule(cool, 1)
3333
@test frx == 2
34-
@test cool_pushforward(NamedTuple(), 1) == (1,)
34+
@test cool_pushforward(NamedTuple(), 1) == 1
3535
rrx, cool_pullback = rrule(cool, 1)
3636
self, rr1 = cool_pullback(1)
3737
@test self == NO_FIELDS
@@ -54,11 +54,11 @@ end
5454

5555
Δ = One()
5656
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
57-
@test df === (x + x,)
57+
@test df === x + x
5858

5959
Δ = rand(Complex{Int64})
6060
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
61-
@test df === (Δ * (x + x),)
61+
@test df === Δ * (x + x)
6262
end
6363

6464
@testset "complex input" begin
@@ -67,11 +67,11 @@ end
6767
@test f === abs2(z)
6868

6969
df = @inferred myabs2_pushforward(NamedTuple(), One())
70-
@test df === (Wirtinger(z', z),)
70+
@test df === Wirtinger(z', z)
7171

7272
Δ = rand(Complex{Int64})
7373
df = @inferred myabs2_pushforward(NamedTuple(), Δ)
74-
@test df === (Wirtinger* z', Δ * z),)
74+
@test df === Wirtinger* z', Δ * z)
7575
end
7676
end
7777

@@ -132,8 +132,8 @@ end
132132
fx, f_pushforward = res
133133
df(Δx, Δp) = f_pushforward(NamedTuple(), Δx, Δp)
134134

135-
df_dx, = df(One(), Zero())
136-
df_dp,= df(Zero(), One())
135+
df_dx = df(One(), Zero())
136+
df_dp = df(Zero(), One())
137137
@test fx == f(x, p) # Check we still get the normal value, right
138138
@test extern(df_dx) isa expected_type_df_dx
139139
@test extern(df_dp) isa expected_type_df_dp

0 commit comments

Comments
 (0)