Skip to content

Commit f560b63

Browse files
authored
Merge pull request #441 from schmrlng/fix438
Cast expm Pade approximant coefficients to result eltype, fix #438
2 parents 7ddb8e4 + 24ca5c1 commit f560b63

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

src/expm.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,48 +68,50 @@ end
6868

6969
# Adapted from implementation in Base; algorithm from
7070
# Higham, "Functions of Matrices: Theory and Computation", SIAM, 2008
71-
function _exp(::Size, A::StaticMatrix{<:Any,<:Any,T}) where T
71+
function _exp(::Size, _A::StaticMatrix{<:Any,<:Any,T}) where T
72+
S = typeof((zero(T)*zero(T) + zero(T)*zero(T))/one(T))
73+
A = S.(_A)
7274
# omitted: matrix balancing, i.e., LAPACK.gebal!
7375
nA = maximum(sum(abs.(A), Val{1})) # marginally more performant than norm(A, 1)
7476
## For sufficiently small nA, use lower order Padé-Approximations
7577
if (nA <= 2.1)
7678
A2 = A*A
7779
if nA > 0.95
78-
U = @evalpoly(A2, T(8821612800)*I, T(302702400)*I, T(2162160)*I, T(3960)*I, T(1)*I)
80+
U = @evalpoly(A2, S(8821612800)*I, S(302702400)*I, S(2162160)*I, S(3960)*I, S(1)*I)
7981
U = A*U
80-
V = @evalpoly(A2, T(17643225600)*I, T(2075673600)*I, T(30270240)*I, T(110880)*I, T(90)*I)
82+
V = @evalpoly(A2, S(17643225600)*I, S(2075673600)*I, S(30270240)*I, S(110880)*I, S(90)*I)
8183
elseif nA > 0.25
82-
U = @evalpoly(A2, T(8648640)*I, T(277200)*I, T(1512)*I, T(1)*I)
84+
U = @evalpoly(A2, S(8648640)*I, S(277200)*I, S(1512)*I, S(1)*I)
8385
U = A*U
84-
V = @evalpoly(A2, T(17297280)*I, T(1995840)*I, T(25200)*I, T(56)*I)
86+
V = @evalpoly(A2, S(17297280)*I, S(1995840)*I, S(25200)*I, S(56)*I)
8587
elseif nA > 0.015
86-
U = @evalpoly(A2, T(15120)*I, T(420)*I, T(1)*I)
88+
U = @evalpoly(A2, S(15120)*I, S(420)*I, S(1)*I)
8789
U = A*U
88-
V = @evalpoly(A2, T(30240)*I, T(3360)*I, T(30)*I)
90+
V = @evalpoly(A2, S(30240)*I, S(3360)*I, S(30)*I)
8991
else
90-
U = @evalpoly(A2, T(60)*I, T(1)*I)
92+
U = @evalpoly(A2, S(60)*I, S(1)*I)
9193
U = A*U
92-
V = @evalpoly(A2, T(120)*I, T(12)*I)
94+
V = @evalpoly(A2, S(120)*I, S(12)*I)
9395
end
9496
expA = (V - U) \ (V + U)
9597
else
9698
s = log2(nA/5.4) # power of 2 later reversed by squaring
9799
if s > 0
98100
si = ceil(Int,s)
99-
A = A / T(2^si)
101+
A = A / S(2^si)
100102
end
101103

102104
A2 = A*A
103105
A4 = A2*A2
104106
A6 = A2*A4
105107

106-
U = A6*(T(1)*A6 + T(16380)*A4 + T(40840800)*A2) +
107-
(T(33522128640)*A6 + T(10559470521600)*A4 + T(1187353796428800)*A2) +
108-
T(32382376266240000)*I
108+
U = A6*(S(1)*A6 + S(16380)*A4 + S(40840800)*A2) +
109+
(S(33522128640)*A6 + S(10559470521600)*A4 + S(1187353796428800)*A2) +
110+
S(32382376266240000)*I
109111
U = A*U
110-
V = A6*(T(182)*A6 + T(960960)*A4 + T(1323241920)*A2) +
111-
(T(670442572800)*A6 + T(129060195264000)*A4 + T(7771770303897600)*A2) +
112-
T(64764752532480000)*I
112+
V = A6*(S(182)*A6 + S(960960)*A4 + S(1323241920)*A2) +
113+
(S(670442572800)*A6 + S(129060195264000)*A4 + S(7771770303897600)*A2) +
114+
S(64764752532480000)*I
113115
expA = (V - U) \ (V + U)
114116

115117
if s > 0 # squaring to reverse dividing by power of 2

0 commit comments

Comments
 (0)