Skip to content

Commit ccd4196

Browse files
committed
just arrayify scalar, and also prefactorize A'
1 parent c75d920 commit ccd4196

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

src/rulesets/Base/arraymath.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -342,21 +342,24 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
342342
project_B = ProjectTo(B)
343343

344344
Y = A \ B
345-
# Ever since https://github.com/JuliaLang/julia/pull/44358
346-
# we need to use `pinv` rather than `/` to support both the cases of Y being scalar and array
347-
# See also https://github.com/JuliaLang/julia/issues/28827 which would improve this
345+
346+
Atf = factorize(A')
347+
348348
function backslash_pullback(ȳ)
349349
= unthunk(ȳ)
350-
Ati = pinv(A')
350+
@static if VERSION >= v"1.9"
351+
# Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358
352+
isa AbstractArray ||= [Ȳ]
353+
end
354+
Atf = factorize(A')
351355
∂A = @thunk begin
352-
353-
= Ati *
356+
= Atf \
354357
= -* Y'
355-
= add!!(Ā, ((B - A * Y) *') * Ati)
356-
= add!!(Ā, Ati * Y * (Ȳ' -'A))
358+
= add!!(Ā, ((B - A * Y) *') / Atf)
359+
= add!!(Ā, Atf \ Y * (Ȳ' -'A))
357360
project_A(Ā)
358361
end
359-
∂B = @thunk project_B(Ati * Ȳ)
362+
∂B = @thunk project_B(Atf \ Ȳ)
360363
return NoTangent(), ∂A, ∂B
361364
end
362365
return Y, backslash_pullback

0 commit comments

Comments
 (0)