Skip to content

Commit e05886a

Browse files
Kenooxinabox
authored andcommitted
Don't use the array muladd rule for ZeroTangent
1 parent 01ea92b commit e05886a

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/rulesets/Base/arraymath.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
351351

352352
function backslash_pullback(ȳ)
353353
= unthunk(ȳ)
354-
354+
355355
Ȳf =
356356
@static if VERSION >= v"1.9"
357357
# Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358
@@ -360,7 +360,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
360360
end
361361
end
362362
Yf = Y
363-
@static if VERSION >= v"1.9"
363+
@static if VERSION >= v"1.9"
364364
# Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358
365365
if !isa(Y, AbstractArray)
366366
Yf = [Y]
@@ -371,7 +371,7 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
371371
= A' \ Ȳf
372372
= -* Y'
373373
t = (B - A * Y) *'
374-
@static if VERSION >= v"1.9"
374+
@static if VERSION >= v"1.9"
375375
# Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358
376376
if !isa(t, AbstractArray)
377377
t = [t]

src/rulesets/Base/base.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ end
9494

9595
@scalar_rule fma(x, y, z) (y, x, true)
9696
@scalar_rule muladd(x, y, z) (y, x, true)
97+
@scalar_rule muladd(x::Union{Number, ZeroTangent}, y::Union{Number, ZeroTangent}, z::Union{Number, ZeroTangent}) (y, x, true)
9798
@scalar_rule rem2pi(x, r::RoundingMode) (true, NoTangent())
9899
@scalar_rule(
99100
mod(x, y),

0 commit comments

Comments
 (0)