Skip to content

Commit fb8cacd

Browse files
committed
only do minimal change to rule for \ to convert to array
Also make second Y not scalar; more coercing some things into arrays some of the time; cleaner def with a helper function
1 parent ccd4196 commit fb8cacd

File tree

2 files changed

+53
-10
lines changed

2 files changed

+53
-10
lines changed

src/rulesets/Base/arraymath.jl

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -343,28 +343,71 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
343343

344344
Y = A \ B
345345

346-
Atf = factorize(A')
347-
348346
function backslash_pullback(ȳ)
349347
= unthunk(ȳ)
348+
349+
Ȳf =
350350
@static if VERSION >= v"1.9"
351351
# Need to ensure Ȳ is an array since https://github.com/JuliaLang/julia/pull/44358
352-
isa AbstractArray ||= [Ȳ]
352+
if !isa(Ȳ, AbstractArray)
353+
Ȳf = [Ȳ]
354+
end
355+
end
356+
Yf = Y
357+
@static if VERSION >= v"1.9"
358+
# Need to ensure Yf is an array since https://github.com/JuliaLang/julia/pull/44358
359+
if !isa(Y, AbstractArray)
360+
Yf = [Y]
361+
end
353362
end
354-
Atf = factorize(A')
363+
#@info "vars" typeof(Ȳ) typeof(Y) typeof(Yf) typeof(A) typeof(B)
355364
∂A = @thunk begin
356-
= Atf \
365+
= A' \ Ȳf
357366
= -* Y'
358-
= add!!(Ā, ((B - A * Y) *') / Atf)
359-
= add!!(Ā, Atf \ Y * (Ȳ' -'A))
367+
t = (B - A * Y) *'
368+
@static if VERSION >= v"1.9"
369+
# Need to ensure t is an array since https://github.com/JuliaLang/julia/pull/44358
370+
if !isa(t, AbstractArray)
371+
t = [t]
372+
end
373+
end
374+
= add!!(Ā, t / A')
375+
= add!!(Ā, A' \ Yf * (Ȳ' -'A))
360376
project_A(Ā)
361377
end
362-
∂B = @thunk project_B(Atf \ )
378+
∂B = @thunk project_B(A' \ Ȳf)
363379
return NoTangent(), ∂A, ∂B
364380
end
365381
return Y, backslash_pullback
366382
end
367383

384+
@static if VERSION >= v"1.9"
385+
# Need to ensure things are not scalar since https://github.com/JuliaLang/julia/pull/44358
386+
_maybe_descalar(x) = x isa AbstractArray ? x : [x]
387+
else
388+
_maybe_descalar(x) = x
389+
end
390+
391+
function rrule(A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
392+
Y = A \ B
393+
394+
395+
function backslash_pullback(ȳ)
396+
= unthunk(ȳ)
397+
398+
∂A = @thunk begin
399+
= A' \ _maybe_descalar(Ȳ)
400+
= -* Y'
401+
+= _maybe_descalar((B - A * Y) *') / A'
402+
+= (A' \ _maybe_descalar(Y)) * (Ȳ' -'A)
403+
(Ā)
404+
end
405+
∂B = @thunk (A' \ _maybe_descalar(Ȳ))
406+
return ∂A, ∂B
407+
end
408+
return Y, backslash_pullback
409+
end
410+
368411
#####
369412
##### `\`, `/` matrix-scalar_rule
370413
#####

test/rulesets/Base/arraymath.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,12 @@
167167
@testset "Matrix $f Vector" begin
168168
X = randn(10, 4)
169169
y = randn(10)
170-
test_rrule(f, X, y)
170+
test_rrule(f, X, y; check_inferred=false)
171171
end
172172
@testset "Vector $f Matrix" begin
173173
x = randn(10)
174174
Y = randn(10, 4)
175-
test_rrule(f, x, Y; output_tangent=Transpose(rand(4)))
175+
test_rrule(f, x, Y; output_tangent=Transpose(rand(4)), check_inferred=false)
176176
end
177177
else
178178
A = rand(2, 4)

0 commit comments

Comments
 (0)