Skip to content

Commit c7b8a80

Browse files
Merge pull request #2542 from mcarmesin/master
Add tests for LinearExponential() on GPU
2 parents b37a5cc + ee8ea06 commit c7b8a80

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

lib/OrdinaryDiffEqLinear/src/linear_caches.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -571,9 +571,9 @@ function _phiv_timestep_caches(u_prototype, maxiter::Int, p::Int)
571571
n = length(u_prototype)
572572
T = eltype(u_prototype)
573573
u = zero(u_prototype) # stores the current state
574-
W = Matrix{T}(undef, n, p + 1) # stores the w vectors
575-
P = Matrix{T}(undef, n, p + 2) # stores output from phiv!
576-
Ks = KrylovSubspace{T}(n, maxiter) # stores output from arnoldi!
574+
W = similar(u_prototype, n, p+1) # stores the w vectors
575+
P = similar(u_prototype, n, p+2) # stores output from phiv!
576+
Ks = KrylovSubspace{T,T,typeof(similar(u_prototype,size(u_prototype,1),2))}(n, maxiter) # stores output from arnoldi!
577577
phiv_cache = PhivCache(u_prototype, maxiter, p + 1) # cache used by phiv! (need +1 for error estimation)
578578
return u, W, P, Ks, phiv_cache
579579
end

test/gpu/linear_exp.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using LinearAlgebra
2+
using SparseArrays
3+
using CUDA
4+
using CUDA.CUSPARSE
5+
using OrdinaryDiffEq
6+
7+
# Linear exponential solvers
8+
A = MatrixOperator([2.0 -1.0; -1.0 2.0])
9+
u0 = ones(2)
10+
11+
A_gpu = MatrixOperator(cu([2.0 -1.0; -1.0 2.0]))
12+
u0_gpu = cu(ones(2))
13+
prob_gpu = ODEProblem(A_gpu, u0_gpu, (0.0, 1.0))
14+
15+
sol_analytic = exp(1.0 * Matrix(A)) * u0
16+
17+
@test_broken sol1_gpu = solve(prob_gpu, LinearExponential(krylov = :off))(1.0) |> Vector
18+
sol2_gpu = solve(prob_gpu, LinearExponential(krylov = :simple))(1.0) |> Vector
19+
sol3_gpu = solve(prob_gpu, LinearExponential(krylov = :adaptive))(1.0) |> Vector
20+
21+
@test_broken isapprox(sol1_gpu, sol_analytic, rtol = 1e-6)
22+
@test isapprox(sol2_gpu, sol_analytic, rtol = 1e-6)
23+
@test isapprox(sol3_gpu, sol_analytic, rtol = 1e-6)
24+
25+
A2_gpu = MatrixOperator(cu(sparse([2.0 -1.0; -1.0 2.0])))
26+
prob2_gpu = ODEProblem(A2_gpu, u0_gpu, (0.0, 1.0))
27+
28+
@test_broken sol2_1_gpu = solve(prob2_gpu, LinearExponential(krylov = :off))(1.0) |> Vector
29+
sol2_2_gpu = solve(prob2_gpu, LinearExponential(krylov = :simple))(1.0) |> Vector
30+
sol2_3_gpu = solve(prob2_gpu, LinearExponential(krylov = :adaptive))(1.0) |> Vector
31+
32+
@test_broken isapprox(sol2_1_gpu, sol_analytic, rtol = 1e-6)
33+
@test isapprox(sol2_2_gpu, sol_analytic, rtol = 1e-6)
34+
@test isapprox(sol2_3_gpu, sol_analytic, rtol = 1e-6)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ end
172172
end
173173
@time @safetestset "Autoswitch GPU" include("gpu/autoswitch.jl")
174174
@time @safetestset "Linear LSRK GPU" include("gpu/linear_lsrk.jl")
175+
@time @safetestset "Linear Exponential GPU" include("gpu/linear_exp.jl")
175176
@time @safetestset "Reaction-Diffusion Stiff Solver GPU" include("gpu/reaction_diffusion_stiff.jl")
176177
@time @safetestset "Scalar indexing bug bypass" include("gpu/hermite_test.jl")
177178
end

0 commit comments

Comments
 (0)