16
16
import . TestUtilities as TU
17
17
18
18
import ClimaCore: Spaces, Geometry, Operators, Fields, MatrixFields
19
+ using LinearAlgebra: Adjoint
20
+ import StaticArrays: SArray
21
+ import ClimaCore. Geometry: AxisTensor, CovariantAxis, ContravariantAxis
19
22
using ClimaCore. MatrixFields:
23
+ BandMatrixRow,
24
+ DiagonalMatrixRow,
20
25
BidiagonalMatrixRow,
21
26
TridiagonalMatrixRow,
22
27
MultiplyColumnwiseBandMatrixField,
23
28
⋅
24
29
const C3 = Geometry. Covariant3Vector
25
- FT = Float64
30
+ const CT3 = Geometry. Contravariant3Vector
31
+ GFT = Float64
26
32
const ᶠgradᵥ = Operators. GradientC2F (
27
33
bottom = Operators. SetGradient (C3 (0 )),
28
34
top = Operators. SetGradient (C3 (0 )),
@@ -32,27 +38,77 @@ const ᶠgradᵥ_matrix = MatrixFields.operator_matrix(ᶠgradᵥ)
32
38
device = ClimaComms. device ()
33
39
context = ClimaComms. context (device)
34
40
cspace =
35
- TU. CenterExtrudedFiniteDifferenceSpace (FT ; zelem = 25 , helem = 10 , context)
41
+ TU. CenterExtrudedFiniteDifferenceSpace (GFT ; zelem = 25 , helem = 10 , context)
36
42
fspace = Spaces. FaceExtrudedFiniteDifferenceSpace (cspace)
37
43
@info " device = $device "
38
44
45
+ ∂ᶠu₃ʲ_err_∂ᶠu₃ʲ_type = BandMatrixRow{
46
+ - 1 ,
47
+ 3 ,
48
+ AxisTensor{
49
+ GFT,
50
+ 2 ,
51
+ Tuple{CovariantAxis{(3 ,)}, ContravariantAxis{(3 ,)}},
52
+ SArray{Tuple{1 , 1 }, GFT, 2 , 1 },
53
+ },
54
+ }
55
+
39
56
f = (;
40
- ᶠtridiagonal_matrix_c3 = Fields. Field (TridiagonalMatrixRow{C3{FT}}, fspace),
57
+ ∂ᶠu₃ʲ_err_∂ᶠu₃ʲ = Fields. Field (∂ᶠu₃ʲ_err_∂ᶠu₃ʲ_type, fspace),
58
+ ᶠtridiagonal_matrix_c3 = Fields. Field (
59
+ TridiagonalMatrixRow{C3{GFT}},
60
+ fspace,
61
+ ),
62
+ ᶠu₃ = Fields. Field (C3{GFT}, fspace),
63
+ adj_u₃ = Fields. Field (DiagonalMatrixRow{Adjoint{GFT, CT3{GFT}}}, fspace),
64
+ )
65
+ c = (;
66
+ ᶜu₃ʲ = Fields. Field (C3{GFT}, cspace),
67
+ bdmr_l = Fields. Field (BidiagonalMatrixRow{GFT}, cspace),
68
+ bdmr_r = Fields. Field (BidiagonalMatrixRow{GFT}, cspace),
69
+ bdmr = Fields. Field (BidiagonalMatrixRow{GFT}, cspace),
41
70
)
42
71
43
72
const ᶜleft_bias = Operators. LeftBiasedF2C ()
44
73
const ᶜright_bias = Operators. RightBiasedF2C ()
45
74
const ᶜleft_bias_matrix = MatrixFields. operator_matrix (ᶜleft_bias)
46
75
const ᶜright_bias_matrix = MatrixFields. operator_matrix (ᶜright_bias)
47
76
77
+ one_C3xACT3 (:: Type{_FT} ) where {_FT} = C3 (_FT (1 )) * CT3 (_FT (1 ))'
78
+ get_I_u₃ (:: Type{_FT} ) where {_FT} = DiagonalMatrixRow (one_C3xACT3 (_FT))
79
+
48
80
conv (:: Type{_FT} , ᶜbias_matrix) where {_FT} =
49
81
convert (BidiagonalMatrixRow{_FT}, ᶜbias_matrix)
50
- function foo (f)
51
- (; ᶠtridiagonal_matrix_c3) = f
82
+ function foo (c, f)
83
+ (; ᶠtridiagonal_matrix_c3, ᶠu₃, ∂ᶠu₃ʲ_err_∂ᶠu₃ʲ, adj_u₃) = f
84
+ (; ᶜu₃ʲ, bdmr_l, bdmr_r, bdmr) = c
52
85
space = axes (ᶠtridiagonal_matrix_c3)
53
86
FT = Spaces. undertype (space)
54
- @. ᶠtridiagonal_matrix_c3 = ᶠgradᵥ_matrix () ⋅ conv (FT, ᶜleft_bias_matrix ())
87
+ I_u₃ = get_I_u₃ (FT)
88
+ dtγ = FT (1 )
89
+
90
+ @. ∂ᶠu₃ʲ_err_∂ᶠu₃ʲ =
91
+ dtγ * ᶠtridiagonal_matrix_c3 ⋅ DiagonalMatrixRow (adjoint (CT3 (ᶠu₃))) -
92
+ (I_u₃,)
93
+
94
+ @. ∂ᶠu₃ʲ_err_∂ᶠu₃ʲ = dtγ * ᶠtridiagonal_matrix_c3 ⋅ adj_u₃ - (I_u₃,)
95
+
96
+ # Fails on gpu
97
+ @. ᶠtridiagonal_matrix_c3 =
98
+ - (ᶠgradᵥ_matrix ()) ⋅ ifelse (
99
+ ᶜu₃ʲ. components. data.:1 > 0 ,
100
+ convert (BidiagonalMatrixRow{FT}, ᶜleft_bias_matrix ()),
101
+ convert (BidiagonalMatrixRow{FT}, ᶜright_bias_matrix ()),
102
+ )
103
+
104
+ # However, this can be decomposed into simpler broadcast
105
+ # expressions that will run on gpus:
106
+ @. bdmr_l = convert (BidiagonalMatrixRow{FT}, ᶜleft_bias_matrix ())
107
+ @. bdmr_r = convert (BidiagonalMatrixRow{FT}, ᶜright_bias_matrix ())
108
+ @. bdmr = ifelse (ᶜu₃ʲ. components. data.:1 > 0 , bdmr_l, bdmr_r)
109
+ @. ᶠtridiagonal_matrix_c3 = - (ᶠgradᵥ_matrix ()) ⋅ bdmr
110
+
55
111
return nothing
56
112
end
57
113
58
- foo (f)
114
+ foo (c, f)
0 commit comments