Skip to content

Commit c41137e

Browse files
authored
Work around LinearAlgebra.jl breakage in 1.11.2. (#2585)
1 parent 3f51328 commit c41137e

File tree

3 files changed

+26
-8
lines changed

3 files changed

+26
-8
lines changed

lib/cublas/linalg.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,19 @@ LinearAlgebra.generic_trimatmul!(c::StridedCuVector{T}, uploc, isunitc, tfun::Fu
202202
LinearAlgebra.generic_trimatdiv!(C::StridedCuVector{T}, uploc, isunitc, tfun::Function, A::StridedCuMatrix{T}, B::StridedCuVector{T}) where {T<:CublasFloat} =
203203
trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, A, C === B ? C : copyto!(C, B))
204204

205+
# work around upstream breakage from JuliaLang/julia#55547
206+
@static if VERSION == v"1.11.2"
207+
const CuUpperOrUnitUpperTriangular = LinearAlgebra.UpperOrUnitUpperTriangular{
208+
<:Any,<:Union{<:CuArray, Adjoint{<:Any, <:CuArray}, Transpose{<:Any, <:CuArray}}}
209+
const CuLowerOrUnitLowerTriangular = LinearAlgebra.LowerOrUnitLowerTriangular{
210+
<:Any,<:Union{<:CuArray, Adjoint{<:Any, <:CuArray}, Transpose{<:Any, <:CuArray}}}
211+
212+
LinearAlgebra.istriu(::CuUpperOrUnitUpperTriangular) = true
213+
LinearAlgebra.istril(::CuUpperOrUnitUpperTriangular) = false
214+
LinearAlgebra.istriu(::CuLowerOrUnitLowerTriangular) = false
215+
LinearAlgebra.istril(::CuLowerOrUnitLowerTriangular) = true
216+
end
217+
205218

206219
#
207220
# BLAS 3

lib/cusparse/interfaces.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,19 @@ for SparseMatrixType in (:CuSparseMatrixBSR,)
243243
end
244244
end # SparseMatrixType loop
245245

246+
# work around upstream breakage from JuliaLang/julia#55547
247+
@static if VERSION == v"1.11.2"
248+
const CuSparseUpperOrUnitUpperTriangular = LinearAlgebra.UpperOrUnitUpperTriangular{
249+
<:Any,<:Union{<:AbstractCuSparseArray, Adjoint{<:Any, <:AbstractCuSparseArray}, Transpose{<:Any, <:AbstractCuSparseArray}}}
250+
const CuSparseLowerOrUnitLowerTriangular = LinearAlgebra.LowerOrUnitLowerTriangular{
251+
<:Any,<:Union{<:AbstractCuSparseArray, Adjoint{<:Any, <:AbstractCuSparseArray}, Transpose{<:Any, <:AbstractCuSparseArray}}}
252+
253+
LinearAlgebra.istriu(::CuSparseUpperOrUnitUpperTriangular) = true
254+
LinearAlgebra.istril(::CuSparseUpperOrUnitUpperTriangular) = false
255+
LinearAlgebra.istriu(::CuSparseLowerOrUnitLowerTriangular) = false
256+
LinearAlgebra.istril(::CuSparseLowerOrUnitLowerTriangular) = true
257+
end
258+
246259
for SparseMatrixType in (:CuSparseMatrixCOO, :CuSparseMatrixCSR, :CuSparseMatrixCSC)
247260
@eval begin
248261
function LinearAlgebra.generic_trimatdiv!(C::DenseCuVector{T}, uploc, isunitc, tfun::Function, A::$SparseMatrixType{T}, B::DenseCuVector{T}) where {T<:BlasFloat}

test/libraries/cublas.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,6 @@ end
478478
y = UpperTriangular(A) * x
479479
@test y Array(dy)
480480
end
481-
if VERSION != v"1.11.2" # https://github.com/JuliaLang/julia/pull/55764
482481
@testset "lmul!(::UpperTriangular{Adjoint})" begin
483482
dy = copy(dx)
484483
lmul!(adjoint(UpperTriangular(dA)), dy)
@@ -497,7 +496,6 @@ end
497496
y = LowerTriangular(A) * x
498497
@test y Array(dy)
499498
end
500-
end
501499
@testset "lmul!(::LowerTriangular{Adjoint})" begin
502500
dy = copy(dx)
503501
lmul!(adjoint(LowerTriangular(dA)), dy)
@@ -544,7 +542,6 @@ end
544542
@test y h_y
545543
end
546544

547-
if VERSION != v"1.11.2" # https://github.com/JuliaLang/julia/pull/55764
548545
@testset "ldiv!(::UpperTriangular)" begin
549546
A = copy(sA)
550547
dA = CuArray(A)
@@ -597,7 +594,6 @@ end
597594
@testset "inv($TR)" for TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
598595
@test testf(x -> inv(TR(x)), rand(elty, m, m))
599596
end
600-
end
601597

602598
A = rand(elty,m,m)
603599
x = rand(elty,m)
@@ -1036,7 +1032,6 @@ end
10361032
end
10371033
end
10381034

1039-
if VERSION != v"1.11.2" # https://github.com/JuliaLang/julia/pull/55764
10401035
@testset "triangular ldiv!" begin
10411036
A = triu(rand(elty, m, m))
10421037
B = rand(elty, m,m)
@@ -1051,7 +1046,6 @@ end
10511046
@test C Array(dC)
10521047
end
10531048
end
1054-
end
10551049

10561050
let A = rand(elty, m,m), B = triu(rand(elty, m, m)), alpha = rand(elty)
10571051
dA = CuArray(A)
@@ -1081,7 +1075,6 @@ end
10811075
end
10821076
end
10831077

1084-
if VERSION != v"1.11.2" # https://github.com/JuliaLang/julia/pull/55764
10851078
@testset "triangular rdiv!" begin
10861079
A = rand(elty, m,m)
10871080
B = triu(rand(elty, m, m))
@@ -1096,7 +1089,6 @@ end
10961089
@test C Array(dC)
10971090
end
10981091
end
1099-
end
11001092

11011093
@testset "Diagonal rdiv!" begin
11021094
A = rand(elty, m,m)

0 commit comments

Comments
 (0)