Skip to content

Commit 7429b87

Browse files
committed
Fallback to Float64 for MatMulMode{:fast} with other AbstractFloat
1 parent fa2d6dd commit 7429b87

File tree

1 file changed

+9
-14
lines changed

1 file changed

+9
-14
lines changed

src/matmul.jl

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -438,13 +438,18 @@ function __mul(A::AbstractMatrix{Interval{T}}, B::AbstractVecOrMat{Interval{T}})
438438
U = mA; U .= _add_round.(abs.(mA), rA, RoundUp)
439439
V = mB; V .= _add_round.(abs.(mB), rB, RoundUp)
440440

441-
cache_3 = zeros(T, size(A, 1), size(B, 2))
442-
rC = _call_gem_openblas_upward!(cache_3, U, V)
441+
cache_3 = zeros(Float64, size(A, 1), size(B, 2))
442+
rC = T.(_call_gem_openblas_upward!(cache_3, _to_stride_64(U), _to_stride_64(V)), RoundUp)
443443
rC .= _add_round.(_sub_round.(rC, μ, RoundUp), 2 .* γ, RoundUp)
444444

445445
return mC, rC
446446
end
447447

448+
_to_stride_64(A::StridedArray{Float64}) = A
449+
_to_stride_64(A::StridedArray{<:AbstractFloat}) = Float64.(A, RoundUp)
450+
_to_stride_64(A::AbstractVector) = _to_stride_64(Vector(A))
451+
_to_stride_64(A::AbstractMatrix) = _to_stride_64(Matrix(A))
452+
448453
function _vec_or_mat_midradius(A::AbstractVecOrMat{Interval{T}}) where {T<:AbstractFloat}
449454
mA = _div_round.(_add_round.(inf.(A), sup.(A), RoundUp), convert(T, 2), RoundUp)
450455
rA = _sub_round.(mA, inf.(A), RoundUp)
@@ -492,14 +497,7 @@ else
492497
_getrounding() = ccall(:fegetround, Cint, ())
493498
end
494499

495-
_2stride(A::StridedArray) = A
496-
_2stride(A::AbstractVector) = Vector(A)
497-
_2stride(A::AbstractMatrix) = Matrix(A)
498-
499-
function _call_gem_openblas_upward!(C, A_::AbstractMatrix{Float64}, B_::AbstractMatrix{Float64})
500-
A = _2stride(A_)
501-
B = _2stride(B_)
502-
500+
function _call_gem_openblas_upward!(C, A::AbstractMatrix, B::AbstractMatrix)
503501
m, k = size(A)
504502
n = size(B, 2)
505503

@@ -528,10 +526,7 @@ function _call_gem_openblas_upward!(C, A_::AbstractMatrix{Float64}, B_::Abstract
528526
end
529527
end
530528

531-
function _call_gem_openblas_upward!(C, A_::AbstractMatrix{Float64}, B_::AbstractVector{Float64})
532-
A = _2stride(A_)
533-
B = _2stride(B_)
534-
529+
function _call_gem_openblas_upward!(C, A::AbstractMatrix, B::AbstractVector)
535530
m, k = size(A)
536531

537532
α = 1.0

0 commit comments

Comments
 (0)