@@ -343,28 +343,71 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
343
343
344
344
Y = A \ B
345
345
346
- Atf = factorize (A' )
347
-
348
346
function backslash_pullback (ȳ)
349
347
Ȳ = unthunk (ȳ)
348
+
349
+ Ȳf = Ȳ
350
350
@static if VERSION >= v " 1.9"
351
351
# Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358
352
- Ȳ isa AbstractArray || Ȳ = [Ȳ]
352
+ if ! isa (Ȳ, AbstractArray)
353
+ Ȳf = [Ȳ]
354
+ end
355
+ end
356
+ Yf = Y
357
+ @static if VERSION >= v " 1.9"
358
+ # Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358
359
+ if ! isa (Y, AbstractArray)
360
+ Yf = [Y]
361
+ end
353
362
end
354
- Atf = factorize (A ' )
363
+ # @info "vars" typeof(Ȳ) typeof(Y) typeof(Yf) typeof(A) typeof(B )
355
364
∂A = @thunk begin
356
- B̄ = Atf \ Ȳ
365
+ B̄ = A ' \ Ȳf
357
366
Ā = - B̄ * Y'
358
- Ā = add!! (Ā, ((B - A * Y) * B̄' ) / Atf)
359
- Ā = add!! (Ā, Atf \ Y * (Ȳ' - B̄' A))
367
+ t = (B - A * Y) * B̄'
368
+ @static if VERSION >= v " 1.9"
369
+ # Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358
370
+ if ! isa (t, AbstractArray)
371
+ t = [t]
372
+ end
373
+ end
374
+ Ā = add!! (Ā, t / A' )
375
+ Ā = add!! (Ā, A' \ Yf * (Ȳ' - B̄' A))
360
376
project_A (Ā)
361
377
end
362
- ∂B = @thunk project_B (Atf \ Ȳ )
378
+ ∂B = @thunk project_B (A ' \ Ȳf )
363
379
return NoTangent (), ∂A, ∂B
364
380
end
365
381
return Y, backslash_pullback
366
382
end
367
383
384
+ @static if VERSION >= v " 1.9"
385
+ # Need to ensure things are not scalar since since https://github.com/JuliaLang/julia/pull/44358
386
+ _maybe_descalar (x) = x isa AbstractArray ? x : [x]
387
+ else
388
+ _maybe_descalar (x) = x
389
+ end
390
+
391
+ function rrule (A:: AbstractVecOrMat{<:Real} , B:: AbstractVecOrMat{<:Real} )
392
+ Y = A \ B
393
+
394
+
395
+ function backslash_pullback (ȳ)
396
+ Ȳ = unthunk (ȳ)
397
+
398
+ ∂A = @thunk begin
399
+ B̄ = A' \ _maybe_descalar (Ȳ)
400
+ Ā = - B̄ * Y'
401
+ Ā += _maybe_descalar ((B - A * Y) * B̄' ) / A'
402
+ Ā += (A' \ _maybe_descalar (Y)) * (Ȳ' - B̄' A)
403
+ (Ā)
404
+ end
405
+ ∂B = @thunk (A' \ _maybe_descalar (Ȳ))
406
+ return ∂A, ∂B
407
+ end
408
+ return Y, backslash_pullback
409
+ end
410
+
368
411
# ####
369
412
# #### `\`, `/` matrix-scalar_rule
370
413
# ####
0 commit comments