Skip to content

Commit 13a3557

Browse files
committed
Make a LinearAlgebra extension and use the improved fast matrix multiplication
1 parent 992d2f2 commit 13a3557

File tree

4 files changed

+338
-285
lines changed

4 files changed

+338
-285
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,22 @@ version = "0.22.23"
77
CRlibm = "96374032-68de-5a5b-8d9e-752f78720389"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
10+
OpenBLASConsistentFPCSR_jll = "6cdc7f73-28fd-5e50-80fb-958a8875b1af"
1011
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1112
RoundingEmulator = "5eaf0fd0-dfba-4ccb-bf02-d820a40db705"
1213

1314
[weakdeps]
1415
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1516
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1617
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
18+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1719
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1820

1921
[extensions]
2022
IntervalArithmeticDiffRulesExt = "DiffRules"
2123
IntervalArithmeticForwardDiffExt = "ForwardDiff"
2224
IntervalArithmeticIntervalSetsExt = "IntervalSets"
25+
IntervalArithmeticLinearAlgebraExt = "LinearAlgebra"
2326
IntervalArithmeticRecipesBaseExt = "RecipesBase"
2427

2528
[compat]
@@ -29,6 +32,7 @@ ForwardDiff = "0.10"
2932
IntervalSets = "0.7"
3033
LinearAlgebra = "1.10"
3134
MacroTools = "0.5"
35+
OpenBLASConsistentFPCSR_jll = "0.3.29"
3236
Printf = "1.10"
3337
RecipesBase = "1"
3438
RoundingEmulator = "0.2"
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
module IntervalArithmeticLinearAlgebraExt
2+
3+
using IntervalArithmetic
4+
import LinearAlgebra
5+
6+
# contructor for `UniformScaling`
7+
8+
IntervalArithmetic.interval(::Type{T}, J::LinearAlgebra.UniformScaling, d::IntervalArithmetic.Decoration = com; format::Symbol = :infsup) where {T} =
9+
LinearAlgebra.UniformScaling(interval(T, J.λ, d; format = format))
10+
IntervalArithmetic.interval(J::LinearAlgebra.UniformScaling, d::IntervalArithmetic.Decoration = com; format::Symbol = :infsup) =
11+
LinearAlgebra.UniformScaling(interval(J.λ, d; format = format))
12+
13+
# by-pass generic `opnorm` from LinearAlgebra to prevent NG flag
14+
15+
function LinearAlgebra.opnorm1(A::AbstractMatrix{T}) where {T<:RealOrComplexI}
16+
LinearAlgebra.require_one_based_indexing(A)
17+
m, n = size(A)
18+
Tnorm = typeof(float(real(zero(T))))
19+
Tsum = promote_type(Float64, Tnorm)
20+
nrm = zero(Tsum)
21+
@inbounds begin
22+
for j = 1:n
23+
nrmj = zero(Tsum)
24+
for i = 1:m
25+
nrmj += LinearAlgebra.norm(A[i,j])
26+
end
27+
nrm = max(nrm, nrmj)
28+
end
29+
end
30+
return convert(Tnorm, nrm)
31+
end
32+
33+
function LinearAlgebra.opnormInf(A::AbstractMatrix{T}) where {T<:RealOrComplexI}
34+
LinearAlgebra.require_one_based_indexing(A)
35+
m, n = size(A)
36+
Tnorm = typeof(float(real(zero(T))))
37+
Tsum = promote_type(Float64, Tnorm)
38+
nrm = zero(Tsum)
39+
@inbounds begin
40+
for i = 1:m
41+
nrmi = zero(Tsum)
42+
for j = 1:n
43+
nrmi += LinearAlgebra.norm(A[i,j])
44+
end
45+
nrm = max(nrm, nrmi)
46+
end
47+
end
48+
return convert(Tnorm, nrm)
49+
end
50+
51+
# matrix eigenvalues
52+
53+
function LinearAlgebra.eigvals!(A::AbstractMatrix{<:Interval}; permute::Bool=true, scale::Bool=true, sortby::Union{Function,Nothing}=LinearAlgebra.eigsortby)
54+
# note: this function does not overwrite `A`
55+
v = _eigvals(A, permute, scale, sortby)
56+
isreal(v) && return v
57+
_fold_conjugate!(v)
58+
isreal(v) && return real(v)
59+
return v
60+
end
61+
62+
LinearAlgebra.eigvals!(A::AbstractMatrix{<:Complex{<:Interval}}; permute::Bool=true, scale::Bool=true, sortby::Union{Function,Nothing}=LinearAlgebra.eigsortby) =
63+
# note: this function does not overwrite `A`
64+
_eigvals(A, permute, scale, sortby)
65+
66+
function _eigvals(A, permute, scale, sortby)
67+
# Gershgorin circle theorem
68+
B = _similarity_transform(A, permute, scale, sortby)
69+
v = LinearAlgebra.diag(B)
70+
T = eltype(v)
71+
for j axes(B, 1)
72+
r = zero(T)
73+
for i axes(B, 2)
74+
if i j
75+
r += abs(B[i,j])
76+
end
77+
end
78+
v[j] = interval(v[j], r; format = :midpoint)
79+
end
80+
return v
81+
end
82+
83+
function _similarity_transform(A, permute, scale, sortby)
84+
mA = mid.(A)
85+
mλ, mV = LinearAlgebra.eigen(mA; permute = permute, scale = scale, sortby = sortby)
86+
.+= LinearAlgebra.diag(mV \ (mA * mV - mV * LinearAlgebra.Diagonal(mλ)))
87+
Λ = LinearAlgebra.Diagonal(interval(mλ))
88+
V = interval(mV)
89+
V .= Λ .+ inv(V) * (A * V - V * Λ)
90+
return V
91+
end
92+
93+
function _fold_conjugate!(v)
94+
for i eachindex(v)
95+
vᵢ = v[i]
96+
idxs = findall(j -> (j i) & !isdisjoint_interval(conj(vᵢ), v[j]), eachindex(v))
97+
if isempty(idxs)
98+
v[i] = real(vᵢ)
99+
else
100+
w = view(v, idxs)
101+
z = conj(intersect_interval(conj(vᵢ), reduce(intersect_interval, w)))
102+
z = complex(IntervalArithmetic.setdecoration(real(z), min(decoration(real(vᵢ)), minimum(decoration real, w))), IntervalArithmetic.setdecoration(imag(z), min(decoration(imag(vᵢ)), minimum(decoration imag, w))))
103+
v[i] = z
104+
end
105+
end
106+
return v
107+
end
108+
109+
# matrix determinant
110+
111+
LinearAlgebra.det(A::AbstractMatrix{<:Interval}) = real(reduce(*, LinearAlgebra.eigvals(A)))
112+
LinearAlgebra.det(A::AbstractMatrix{<:Complex{<:Interval}}) = reduce(*, LinearAlgebra.eigvals(A))
113+
114+
end

src/IntervalArithmetic.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ include("display.jl")
8484
#
8585

8686
import LinearAlgebra
87+
import OpenBLASConsistentFPCSR_jll # 32-bit systems are not supported
88+
89+
if Int != Int32
90+
# use the same number of threads as the default BLAS library
91+
ccall((:openblas_set_num_threads64_, OpenBLASConsistentFPCSR_jll.libopenblas),
92+
Cint, (Cint,),
93+
LinearAlgebra.BLAS.get_num_threads())
94+
end
8795

8896
include("matmul.jl")
8997

0 commit comments

Comments
 (0)