@@ -342,20 +342,70 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
342
342
project_B = ProjectTo (B)
343
343
344
344
Y = A \ B
345
+
345
346
function backslash_pullback (ȳ)
346
347
Ȳ = unthunk (ȳ)
348
+
349
+ Ȳf = Ȳ
350
+ @static if VERSION >= v " 1.9"
351
+ # Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358
352
+ if ! isa (Ȳ, AbstractArray)
353
+ Ȳf = [Ȳ]
354
+ end
355
+ end
356
+ Yf = Y
357
+ @static if VERSION >= v " 1.9"
358
+ # Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358
359
+ if ! isa (Y, AbstractArray)
360
+ Yf = [Y]
361
+ end
362
+ end
363
+ # @info "vars" typeof(Ȳ) typeof(Y) typeof(Yf) typeof(A) typeof(B)
347
364
∂A = @thunk begin
348
- B̄ = A' \ Ȳ
365
+ B̄ = A' \ Ȳf
349
366
Ā = - B̄ * Y'
350
- Ā = add!! (Ā, (B - A * Y) * B̄' / A' )
351
- Ā = add!! (Ā, A' \ Y * (Ȳ' - B̄' A))
367
+ t = (B - A * Y) * B̄'
368
+ @static if VERSION >= v " 1.9"
369
+ # Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358
370
+ if ! isa (t, AbstractArray)
371
+ t = [t]
372
+ end
373
+ end
374
+ Ā = add!! (Ā, t / A' )
375
+ Ā = add!! (Ā, A' \ Yf * (Ȳ' - B̄' A))
352
376
project_A (Ā)
353
377
end
354
- ∂B = @thunk project_B (A' \ Ȳ )
378
+ ∂B = @thunk project_B (A' \ Ȳf )
355
379
return NoTangent (), ∂A, ∂B
356
380
end
357
381
return Y, backslash_pullback
382
+ end
383
+
384
+ @static if VERSION >= v " 1.9"
385
+ # Need to ensure things are not scalar since since https://github.com/JuliaLang/julia/pull/44358
386
+ _maybe_descalar (x) = x isa AbstractArray ? x : [x]
387
+ else
388
+ _maybe_descalar (x) = x
389
+ end
390
+
391
+ function rrule (A:: AbstractVecOrMat{<:Real} , B:: AbstractVecOrMat{<:Real} )
392
+ Y = A \ B
393
+
394
+
395
+ function backslash_pullback (ȳ)
396
+ Ȳ = unthunk (ȳ)
358
397
398
+ ∂A = @thunk begin
399
+ B̄ = A' \ _maybe_descalar (Ȳ)
400
+ Ā = - B̄ * Y'
401
+ Ā += _maybe_descalar ((B - A * Y) * B̄' ) / A'
402
+ Ā += (A' \ _maybe_descalar (Y)) * (Ȳ' - B̄' A)
403
+ (Ā)
404
+ end
405
+ ∂B = @thunk (A' \ _maybe_descalar (Ȳ))
406
+ return ∂A, ∂B
407
+ end
408
+ return Y, backslash_pullback
359
409
end
360
410
361
411
# ####
0 commit comments