@@ -287,6 +287,7 @@ mutable struct RadauIIA9ConstantCache{F, Tab, Tol, Dt, U, JType} <:
287
287
cont2:: U
288
288
cont3:: U
289
289
cont4:: U
290
+ cont5:: U
290
291
dtprev:: Dt
291
292
W_γdt:: Dt
292
293
status:: NLStatus
@@ -304,7 +305,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
304
305
κ = alg. κ != = nothing ? convert (uToltype, alg. κ) : convert (uToltype, 1 // 100 )
305
306
J = false .* _vec (rate_prototype) .* _vec (rate_prototype)'
306
307
307
- RadauIIA9ConstantCache (uf, tab, κ, one (uToltype), 10000 , u, u, u, u, dt, dt,
308
+ RadauIIA9ConstantCache (uf, tab, κ, one (uToltype), 10000 , u, u, u, u, u, dt, dt,
308
309
Convergence, J)
309
310
end
310
311
@@ -333,6 +334,7 @@ mutable struct RadauIIA9Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Ty
333
334
cont2:: uType
334
335
cont3:: uType
335
336
cont4:: uType
337
+ cont5:: uType
336
338
du1:: rateType
337
339
fsalfirst:: rateType
338
340
k:: rateType
@@ -407,6 +409,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
407
409
cont2 = zero (u)
408
410
cont3 = zero (u)
409
411
cont4 = zero (u)
412
+ cont5 = zero (u)
410
413
411
414
fsalfirst = zero (rate_prototype)
412
415
k = zero (rate_prototype)
@@ -462,11 +465,193 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
462
465
463
466
RadauIIA9Cache (u, uprev,
464
467
z1, z2, z3, z4, z5, w1, w2, w3, w4, w5,
465
- dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4,
468
+ dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4, cont5,
466
469
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
467
470
J, W1, W2, W3,
468
471
uf, tab, κ, one (uToltype), 10000 ,
469
472
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config,
470
473
linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt,
471
474
Convergence, alg. step_limiter!)
472
475
end
476
+
477
+ mutable struct AdaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType} < :
478
+ OrdinaryDiffEqConstantCache
479
+ uf:: F
480
+ tab:: Tab
481
+ κ:: Tol
482
+ ηold:: Tol
483
+ iter:: Int
484
+ cont:: Vector{U}
485
+ dtprev:: Dt
486
+ W_γdt:: Dt
487
+ status:: NLStatus
488
+ J:: JType
489
+ end
490
+
491
+ function alg_cache (alg:: AdaptiveRadau , u, rate_prototype, :: Type{uEltypeNoUnits} ,
492
+ :: Type{uBottomEltypeNoUnits} ,
493
+ :: Type{tTypeNoUnits} , uprev, uprev2, f, t, dt, reltol, p, calck,
494
+ :: Val{false} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
495
+ uf = UDerivativeWrapper (f, t, p)
496
+ uToltype = constvalue (uBottomEltypeNoUnits)
497
+ num_stages = alg. num_stages
498
+
499
+ if (num_stages == 3 )
500
+ tab = BigRadauIIA5Tableau (uToltype, constvalue (tTypeNoUnits))
501
+ elseif (num_stages == 5 )
502
+ tab = BigRadauIIA9Tableau (uToltype, constvalue (tTypeNoUnits))
503
+ elseif (num_stages == 7 )
504
+ tab = BigRadauIIA13Tableau (uToltype, constvalue (tTypeNoUnits))
505
+ elseif iseven (num_stages) || num_stages < 3
506
+ error (" num_stages must be odd and 3 or greater" )
507
+ else
508
+ tab = adaptiveRadauTableau (uToltype, constvalue (tTypeNoUnits), num_stages)
509
+ end
510
+
511
+ cont = Vector {typeof(u)} (undef, num_stages)
512
+ for i in 1 : num_stages
513
+ cont[i] = zero (u)
514
+ end
515
+
516
+ κ = alg. κ != = nothing ? convert (uToltype, alg. κ) : convert (uToltype, 1 // 100 )
517
+ J = false .* _vec (rate_prototype) .* _vec (rate_prototype)'
518
+
519
+ AdaptiveRadauConstantCache (uf, tab, κ, one (uToltype), 10000 , cont, dt, dt,
520
+ Convergence, J)
521
+ end
522
+
523
+ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType, JType, W1Type, W2Type,
524
+ UF, JC, F1, F2, Tab, Tol, Dt, rTol, aTol, StepLimiter} < :
525
+ FIRKMutableCache
526
+ u:: uType
527
+ uprev:: uType
528
+ z:: Vector{uType}
529
+ w:: Vector{uType}
530
+ c_prime:: Vector{tType}
531
+ dw1:: uType
532
+ ubuff:: uType
533
+ dw2:: Vector{cuType}
534
+ cubuff:: Vector{cuType}
535
+ dw:: Vector{uType}
536
+ cont:: Vector{uType}
537
+ derivatives:: Matrix{uType}
538
+ du1:: rateType
539
+ fsalfirst:: rateType
540
+ ks:: Vector{rateType}
541
+ k:: rateType
542
+ fw:: Vector{rateType}
543
+ J:: JType
544
+ W1:: W1Type # real
545
+ W2:: Vector{W2Type} # complex
546
+ uf:: UF
547
+ tab:: Tab
548
+ κ:: Tol
549
+ ηold:: Tol
550
+ iter:: Int
551
+ tmp:: uType
552
+ atmp:: uNoUnitsType
553
+ jac_config:: JC
554
+ linsolve1:: F1 # real
555
+ linsolve2:: Vector{F2} # complex
556
+ rtol:: rTol
557
+ atol:: aTol
558
+ dtprev:: Dt
559
+ W_γdt:: Dt
560
+ status:: NLStatus
561
+ step_limiter!:: StepLimiter
562
+ end
563
+
564
+ function alg_cache (alg:: AdaptiveRadau , u, rate_prototype, :: Type{uEltypeNoUnits} ,
565
+ :: Type{uBottomEltypeNoUnits} ,
566
+ :: Type{tTypeNoUnits} , uprev, uprev2, f, t, dt, reltol, p, calck,
567
+ :: Val{true} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
568
+ uf = UJacobianWrapper (f, t, p)
569
+ uToltype = constvalue (uBottomEltypeNoUnits)
570
+ num_stages = alg. num_stages
571
+
572
+ if (num_stages == 3 )
573
+ tab = BigRadauIIA5Tableau (uToltype, constvalue (tTypeNoUnits))
574
+ elseif (num_stages == 5 )
575
+ tab = BigRadauIIA9Tableau (uToltype, constvalue (tTypeNoUnits))
576
+ elseif (num_stages == 7 )
577
+ tab = BigRadauIIA13Tableau (uToltype, constvalue (tTypeNoUnits))
578
+ elseif iseven (num_stages) || num_stages < 3
579
+ error (" num_stages must be odd and 3 or greater" )
580
+ else
581
+ tab = adaptiveRadauTableau (uToltype, constvalue (tTypeNoUnits), num_stages)
582
+ end
583
+
584
+ κ = alg. κ != = nothing ? convert (uToltype, alg. κ) : convert (uToltype, 1 // 100 )
585
+
586
+ z = Vector {typeof(u)} (undef, num_stages)
587
+ w = Vector {typeof(u)} (undef, num_stages)
588
+ for i in 1 : num_stages
589
+ z[i] = w[i] = zero (u)
590
+ end
591
+
592
+ c_prime = Vector {typeof(t)} (undef, num_stages) # time stepping
593
+
594
+ dw1 = zero (u)
595
+ ubuff = zero (u)
596
+ dw2 = [similar (u, Complex{eltype (u)}) for _ in 1 : (num_stages - 1 ) ÷ 2 ]
597
+ recursivefill! .(dw2, false )
598
+ cubuff = [similar (u, Complex{eltype (u)}) for _ in 1 : (num_stages - 1 ) ÷ 2 ]
599
+ recursivefill! .(cubuff, false )
600
+ dw = Vector {typeof(u)} (undef, num_stages - 1 )
601
+
602
+ cont = Vector {typeof(u)} (undef, num_stages)
603
+ for i in 1 : num_stages
604
+ cont[i] = zero (u)
605
+ end
606
+
607
+ derivatives = Matrix {typeof(u)} (undef, num_stages, num_stages)
608
+ for i in 1 : num_stages, j in 1 : num_stages
609
+ derivatives[i, j] = zero (u)
610
+ end
611
+
612
+ fsalfirst = zero (rate_prototype)
613
+ fw = Vector {typeof(rate_prototype)} (undef, num_stages)
614
+ ks = Vector {typeof(rate_prototype)} (undef, num_stages)
615
+ for i in 1 : num_stages
616
+ ks[i] = fw[i] = zero (rate_prototype)
617
+ end
618
+ k = ks[1 ]
619
+
620
+ J, W1 = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (true ))
621
+ if J isa AbstractSciMLOperator
622
+ error (" Non-concrete Jacobian not yet supported by AdaptiveRadau." )
623
+ end
624
+
625
+ W2 = [similar (J, Complex{eltype (W1)}) for _ in 1 : (num_stages - 1 ) ÷ 2 ]
626
+ recursivefill! .(W2, false )
627
+
628
+ du1 = zero (rate_prototype)
629
+
630
+ tmp = zero (u)
631
+
632
+ atmp = similar (u, uEltypeNoUnits)
633
+ recursivefill! (atmp, false )
634
+
635
+ jac_config = build_jac_config (alg, f, uf, du1, uprev, u, zero (u), dw1)
636
+
637
+ linprob = LinearProblem (W1, _vec (ubuff); u0 = _vec (dw1))
638
+ linsolve1 = init (linprob, alg. linsolve, alias_A = true , alias_b = true ,
639
+ assumptions = LinearSolve. OperatorAssumptions (true ))
640
+
641
+ linsolve2 = [
642
+ init (LinearProblem (W2[i], _vec (cubuff[i]); u0 = _vec (dw2[i])), alg. linsolve, alias_A = true , alias_b = true ,
643
+ assumptions = LinearSolve. OperatorAssumptions (true )) for i in 1 : (num_stages - 1 ) ÷ 2 ]
644
+
645
+ rtol = reltol isa Number ? reltol : zero (reltol)
646
+ atol = reltol isa Number ? reltol : zero (reltol)
647
+
648
+ AdaptiveRadauCache (u, uprev,
649
+ z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
650
+ du1, fsalfirst, ks, k, fw,
651
+ J, W1, W2,
652
+ uf, tab, κ, one (uToltype), 10000 , tmp,
653
+ atmp, jac_config,
654
+ linsolve1, linsolve2, rtol, atol, dt, dt,
655
+ Convergence, alg. step_limiter!)
656
+ end
657
+
0 commit comments