Skip to content

Commit 95b200b

Browse files
maleadtsimsurace
andauthored
Revamp diag cholesky method (#444)
Co-authored-by: simsurace <simsurace@gmx.net> Co-authored-by: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com>
1 parent 829f433 commit 95b200b

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

src/host/linalg.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,29 +119,35 @@ end
119119

120120
Base.copy(D::Diagonal{T, <:AbstractGPUArray{T, N}}) where {T, N} = Diagonal(copy(D.diag))
121121

122+
_isrealandpositive(x) = isreal(x) && real(x) > 0
123+
122124
if VERSION <= v"1.8-"
123125
function LinearAlgebra.cholesky!(D::Diagonal{<:Any, <:AbstractGPUArray},
124126
::Val{false} = Val(false); check::Bool = true)
125-
info = 0
126-
if mapreduce(x -> isreal(x) && isposdef(x), &, D.diag)
127+
info = findfirst(!_isrealandpositive, D.diag)
128+
if isnothing(info)
127129
D.diag .= sqrt.(D.diag)
130+
info = 0
131+
elseif check
132+
throw(PosDefException(info))
128133
else
129-
info = findfirst(x -> !isreal(x) || !isposdef(x), collect(D.diag))
130-
check && throw(PosDefException(info))
134+
D.diag[begin:info-1] .= sqrt.(D.diag[begin:info-1])
131135
end
132-
Cholesky(D, 'U', convert(LinearAlgebra.BlasInt, info))
136+
return Cholesky(D, 'U', convert(LinearAlgebra.BlasInt, info))
133137
end
134138
else
135139
function LinearAlgebra.cholesky!(D::Diagonal{<:Any, <:AbstractGPUArray},
136140
::NoPivot = NoPivot(); check::Bool = true)
137-
info = 0
138-
if mapreduce(x -> isreal(x) && isposdef(x), &, D.diag)
141+
info = findfirst(!_isrealandpositive, D.diag)
142+
if isnothing(info)
139143
D.diag .= sqrt.(D.diag)
144+
info = 0
145+
elseif check
146+
throw(PosDefException(info))
140147
else
141-
info = findfirst(x -> !isreal(x) || !isposdef(x), collect(D.diag))
142-
check && throw(PosDefException(info))
148+
D.diag[begin:info-1] .= sqrt.(D.diag[begin:info-1])
143149
end
144-
Cholesky(D, 'U', convert(LinearAlgebra.BlasInt, info))
150+
return Cholesky(D, 'U', convert(LinearAlgebra.BlasInt, info))
145151
end
146152
end
147153

0 commit comments

Comments
 (0)