Skip to content

Commit 50d9d03

Browse files
authored
Merge pull request #718 from JuliaDiff/ox/1.9fixes
Fix for julia 1.9
2 parents 11c230c + fb8cacd commit 50d9d03

File tree

2 files changed

+57
-7
lines changed

2 files changed

+57
-7
lines changed

src/rulesets/Base/arraymath.jl

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,20 +342,70 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
342342
project_B = ProjectTo(B)
343343

344344
Y = A \ B
345+
345346
function backslash_pullback(ȳ)
346347
= unthunk(ȳ)
348+
349+
Ȳf =
350+
@static if VERSION >= v"1.9"
351+
# Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358
352+
if !isa(Ȳ, AbstractArray)
353+
Ȳf = [Ȳ]
354+
end
355+
end
356+
Yf = Y
357+
@static if VERSION >= v"1.9"
358+
# Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358
359+
if !isa(Y, AbstractArray)
360+
Yf = [Y]
361+
end
362+
end
363+
#@info "vars" typeof(Ȳ) typeof(Y) typeof(Yf) typeof(A) typeof(B)
347364
∂A = @thunk begin
348-
= A' \
365+
= A' \ Ȳf
349366
= -* Y'
350-
= add!!(Ā, (B - A * Y) *' / A')
351-
= add!!(Ā, A' \ Y * (Ȳ' -'A))
367+
t = (B - A * Y) *'
368+
@static if VERSION >= v"1.9"
369+
# Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358
370+
if !isa(t, AbstractArray)
371+
t = [t]
372+
end
373+
end
374+
= add!!(Ā, t / A')
375+
= add!!(Ā, A' \ Yf * (Ȳ' -'A))
352376
project_A(Ā)
353377
end
354-
∂B = @thunk project_B(A' \ )
378+
∂B = @thunk project_B(A' \ Ȳf)
355379
return NoTangent(), ∂A, ∂B
356380
end
357381
return Y, backslash_pullback
382+
end
383+
384+
@static if VERSION >= v"1.9"
385+
# Need to ensure things are not scalar since since https://github.com/JuliaLang/julia/pull/44358
386+
_maybe_descalar(x) = x isa AbstractArray ? x : [x]
387+
else
388+
_maybe_descalar(x) = x
389+
end
390+
391+
function rrule(A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
392+
Y = A \ B
393+
394+
395+
function backslash_pullback(ȳ)
396+
= unthunk(ȳ)
358397

398+
∂A = @thunk begin
399+
= A' \ _maybe_descalar(Ȳ)
400+
= -* Y'
401+
+= _maybe_descalar((B - A * Y) *') / A'
402+
+= (A' \ _maybe_descalar(Y)) * (Ȳ' -'A)
403+
(Ā)
404+
end
405+
∂B = @thunk (A' \ _maybe_descalar(Ȳ))
406+
return ∂A, ∂B
407+
end
408+
return Y, backslash_pullback
359409
end
360410

361411
#####

test/rulesets/Base/arraymath.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
@testset "arraymath.jl" begin
22
@testset "inv(::Matrix{$T})" for T in (Float64, ComplexF64)
33
B = generate_well_conditioned_matrix(T, 3)
4-
if VERSION >= v"1.7"
4+
if v"1.7" <= VERSION < v"1.9"
55
@gpu test_frule(inv, B)
66
@gpu test_rrule(inv, B)
77
else
@@ -167,12 +167,12 @@
167167
@testset "Matrix $f Vector" begin
168168
X = randn(10, 4)
169169
y = randn(10)
170-
test_rrule(f, X, y)
170+
test_rrule(f, X, y; check_inferred=false)
171171
end
172172
@testset "Vector $f Matrix" begin
173173
x = randn(10)
174174
Y = randn(10, 4)
175-
test_rrule(f, x, Y; output_tangent=Transpose(rand(4)))
175+
test_rrule(f, x, Y; output_tangent=Transpose(rand(4)), check_inferred=false)
176176
end
177177
else
178178
A = rand(2, 4)

0 commit comments

Comments
 (0)