Skip to content

Commit 27c272d

Browse files
authored
rrule for many-arg * (#547)
* rrule for many-arg * * tweak * unrelated, remove a dot from muladd * bump version * Revert "unrelated, remove a dot from muladd" This reverts commit eb5b761. * fixup * skip inference on 1.0
1 parent ce78d3d commit 27c272d

File tree

4 files changed

+45
-8
lines changed

4 files changed

+45
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.14.0"
3+
version = "1.15.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/fastmath_able.jl

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,19 +231,46 @@ let
231231
# Optimized version of `Δx .* y .+ x .* Δy`. Also, it is potentially more
232232
# accurate on machines with FMA instructions, since there are only two
233233
# rounding operations, one in `muladd/fma` and the other in `*`.
234-
∂xy = muladd.(Δx, y, x .* Δy)
234+
∂xy = muladd(Δx, y, x * Δy)
235235
return x * y, ∂xy
236236
end
237+
frule((_, Δx), ::typeof(*), x::Number) = x, Δx
237238

238239
function rrule(::typeof(*), x::Number, y::Number)
239-
project_x = ProjectTo(x)
240-
project_y = ProjectTo(y)
241-
function times_pullback(Ω̇)
240+
function times_pullback2(Ω̇)
242241
ΔΩ = unthunk(Ω̇)
243-
return (NoTangent(), project_x(ΔΩ * y'), project_y(x' * ΔΩ))
242+
return (NoTangent(), ProjectTo(x)(ΔΩ * y'), ProjectTo(y)(x' * ΔΩ))
243+
end
244+
return x * y, times_pullback2
245+
end
246+
# While 3-arg * calls 2-arg *, this is currently slow in Zygote:
247+
# https://github.com/JuliaDiff/ChainRules.jl/issues/544
248+
function rrule(::typeof(*), x::Number, y::Number, z::Number)
249+
function times_pullback3(Ω̇)
250+
ΔΩ = unthunk(Ω̇)
251+
return (
252+
NoTangent(),
253+
ProjectTo(x)(ΔΩ * y' * z'),
254+
ProjectTo(y)(x' * ΔΩ * z'),
255+
ProjectTo(z)(x' * y' * ΔΩ),
256+
)
257+
end
258+
return x * y * z, times_pullback3
259+
end
260+
# Instead of this recursive rule for N args, you could write the generic case
261+
# directly, by nesting ntuples, but this didn't infer well:
262+
# https://github.com/JuliaDiff/ChainRules.jl/pull/547/commits/3558951c9f1b3c70e7135fd61d29cc3b96a68dea
263+
function rrule(::typeof(*), x::Number, y::Number, z::Number, more::Number...)
264+
Ω3, back3 = rrule(*, x, y, z)
265+
Ω4, back4 = rrule(*, Ω3, more...)
266+
function times_pullback4(Ω̇)
267+
Δ4 = back4(unthunk(Ω̇)) # (0, ΔΩ3, Δmore...)
268+
Δ3 = back3(Δ4[2]) # (0, Δx, Δy, Δz)
269+
return (Δ3..., Δ4[3:end]...)
244270
end
245-
return x * y, times_pullback
271+
return Ω4, times_pullback4
246272
end
273+
rrule(::typeof(*), x::Number) = rrule(identity, x)
247274
end # fastable_ast
248275

249276
# Rewrite everything to use fast_math functions, including the type-constraints

test/rulesets/Base/base.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,15 @@
104104
test_frule(*, x, y)
105105
test_rrule(*, x, y)
106106
end
107+
@testset "*($x, $y, ...)" for x in test_points, y in test_points
108+
# This promotion is only for FiniteDifferences, the rules allow mixtures:
109+
x, y = Base.promote(x, y)
110+
111+
# Inference fails on 1.0, passes on 1.6
112+
test_rrule(*, x, y, x+y; check_inferred=VERSION>v"1.5")
113+
test_rrule(*, x, y, 17x, 23y; check_inferred=VERSION>v"1.5")
114+
test_rrule(*, x, y, 7x, 3y, x+y+pi; check_inferred=VERSION>v"1.5")
115+
end
107116
end
108117

109118
@testset "ldexp" begin

test/rulesets/Base/fastmath_able.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ const FASTABLE_AST = quote
120120
test_scalar(+, x)
121121
test_scalar(-, x)
122122
test_scalar(atan, x)
123+
test_scalar(*, x)
123124
end
124125
end
125126

@@ -132,7 +133,7 @@ const FASTABLE_AST = quote
132133
test_rrule(f, (rand(0:10) + .6rand() + .2) * base, base)
133134
end
134135

135-
@testset "$f(x::$T, y::$T)" for f in (/, +, -, hypot), T in (Float64, ComplexF64)
136+
@testset "$f(x::$T, y::$T)" for f in (/, +, -, *, hypot), T in (Float64, ComplexF64)
136137
test_frule(f, 10rand(T), rand(T))
137138
test_rrule(f, 10rand(T), rand(T))
138139
end

0 commit comments

Comments
 (0)