Skip to content

Commit dd1f30a

Browse files
authored
Merge pull request #752 from JuliaDiff/ox/muladd_zero_tangent
muladd rule for zero tangent
2 parents 01ea92b + 81c6e8c commit dd1f30a

File tree

3 files changed

+14
-3
lines changed

3 files changed

+14
-3
lines changed

src/rulesets/Base/arraymath.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
351351

352352
function backslash_pullback(ȳ)
353353
= unthunk(ȳ)
354-
354+
355355
Ȳf =
356356
@static if VERSION >= v"1.9"
357357
# Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358
@@ -360,7 +360,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
360360
end
361361
end
362362
Yf = Y
363-
@static if VERSION >= v"1.9"
363+
@static if VERSION >= v"1.9"
364364
# Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358
365365
if !isa(Y, AbstractArray)
366366
Yf = [Y]
@@ -371,7 +371,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
371371
= A' \ Ȳf
372372
= -* Y'
373373
t = (B - A * Y) *'
374-
@static if VERSION >= v"1.9"
374+
@static if VERSION >= v"1.9"
375375
# Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358
376376
if !isa(t, AbstractArray)
377377
t = [t]

src/rulesets/Base/base.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ end
9494

9595
@scalar_rule fma(x, y, z) (y, x, true)
9696
@scalar_rule muladd(x, y, z) (y, x, true)
97+
@scalar_rule muladd(x::Union{Number, ZeroTangent}, y::Union{Number, ZeroTangent}, z::Union{Number, ZeroTangent}) (y, x, true)
9798
@scalar_rule rem2pi(x, r::RoundingMode) (true, NoTangent())
9899
@scalar_rule(
99100
mod(x, y),

test/rulesets/Base/base.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,16 @@
153153
test_rrule(muladd, 10randn(), randn(), randn())
154154
end
155155

156+
@testset "muladd ZeroTangent" begin
157+
test_frule(muladd, 2.0, 3.0, ZeroTangent())
158+
test_frule(muladd, 2.0, ZeroTangent(), 4.0)
159+
test_frule(muladd, ZeroTangent(), 3.0, 4.0)
160+
161+
test_rrule(muladd, 2.0, 3.0, ZeroTangent())
162+
test_rrule(muladd, 2.0, ZeroTangent(), 4.0)
163+
test_rrule(muladd, ZeroTangent(), 3.0, 4.0)
164+
end
165+
156166
@testset "fma" begin
157167
test_frule(fma, 10randn(), randn(), randn())
158168
test_rrule(fma, 10randn(), randn(), randn())

0 commit comments

Comments
 (0)