@@ -12,85 +12,83 @@ Base.@propagate_inbounds function fd_operator_shmem(
12
12
# allocate temp output
13
13
RT = return_eltype (op, args... )
14
14
Ju³ = CUDA. CuStaticSharedArray (RT, (Nvt,))
15
- return Ju³
15
+ lJu³ = CUDA. CuStaticSharedArray (RT, (1 ,))
16
+ rJu³ = CUDA. CuStaticSharedArray (RT, (1 ,))
17
+ return (Ju³, lJu³, rJu³)
16
18
end
17
19
18
- Base. @propagate_inbounds function fd_operator_fill_shmem_interior ! (
20
+ Base. @propagate_inbounds function fd_operator_fill_shmem ! (
19
21
op:: Operators.DivergenceF2C ,
20
- Ju³,
21
- loc, # can be any location
22
- space,
23
- idx:: Utilities.PlusHalf ,
24
- hidx,
25
- arg,
26
- )
27
- @inbounds begin
28
- vt = threadIdx (). x
29
- lg = Geometry. LocalGeometry (space, idx, hidx)
30
- u³ = Operators. getidx (space, arg, loc, idx, hidx)
31
- Ju³[vt] = Geometry. Jcontravariant3 (u³, lg)
32
- end
33
- return nothing
34
- end
35
-
36
- Base. @propagate_inbounds function fd_operator_fill_shmem_left_boundary! (
37
- op:: Operators.DivergenceF2C ,
38
- bc:: Operators.SetValue ,
39
- Ju³,
22
+ (Ju³, lJu³, rJu³),
40
23
loc,
24
+ bc_bds,
25
+ arg_space,
41
26
space,
42
27
idx:: Utilities.PlusHalf ,
43
28
hidx,
44
29
arg,
45
30
)
46
- idx == Operators. left_face_boundary_idx (space) ||
47
- error (" Incorrect left idx" )
48
31
@inbounds begin
49
32
vt = threadIdx (). x
50
33
lg = Geometry. LocalGeometry (space, idx, hidx)
51
- u³ = Operators. getidx (space, bc. val, loc, nothing , hidx)
52
- Ju³[vt] = Geometry. Jcontravariant3 (u³, lg)
53
- end
54
- return nothing
55
- end
56
-
57
- Base. @propagate_inbounds function fd_operator_fill_shmem_right_boundary! (
58
- op:: Operators.DivergenceF2C ,
59
- bc:: Operators.SetValue ,
60
- Ju³,
61
- loc,
62
- space,
63
- idx:: Utilities.PlusHalf ,
64
- hidx,
65
- arg,
66
- )
67
- # The right boundary is called at `idx + 1`, so we need to subtract 1 from idx (shmem is loaded at vt+1)
68
- idx == Operators. right_face_boundary_idx (space) ||
69
- error (" Incorrect right idx" )
70
- @inbounds begin
71
- vt = threadIdx (). x
72
- lg = Geometry. LocalGeometry (space, idx, hidx)
73
- u³ = Operators. getidx (space, bc. val, loc, nothing , hidx)
74
- Ju³[vt] = Geometry. Jcontravariant3 (u³, lg)
34
+ if ! on_boundary (space, op, loc, idx)
35
+ u³ = Operators. getidx (space, arg, loc, idx, hidx)
36
+ Ju³[vt] = Geometry. Jcontravariant3 (u³, lg)
37
+ else
38
+ bc = Operators. get_boundary (op, loc)
39
+ ub = Operators. getidx (space, bc. val, loc, nothing , hidx)
40
+ bJu³ = on_left_boundary (idx, space) ? lJu³ : rJu³
41
+ if bc isa Operators. SetValue
42
+ bJu³[1 ] = Geometry. Jcontravariant3 (ub, lg)
43
+ elseif bc isa Operators. SetDivergence
44
+ bJu³[1 ] = ub
45
+ elseif bc isa Operators. Extrapolate # no shmem needed
46
+ end
47
+ end
75
48
end
76
49
return nothing
77
50
end
78
51
79
52
Base. @propagate_inbounds function fd_operator_evaluate (
80
53
op:: Operators.DivergenceF2C ,
81
- Ju³,
54
+ ( Ju³, lJu³, rJu³) ,
82
55
loc,
83
56
space,
84
57
idx:: Integer ,
85
58
hidx,
86
- args ... ,
59
+ arg ,
87
60
)
88
61
@inbounds begin
89
62
vt = threadIdx (). x
90
- local_geometry = Geometry. LocalGeometry (space, idx, hidx)
91
- Ju³₋ = Ju³[vt] # corresponds to idx - half
92
- Ju³₊ = Ju³[vt + 1 ] # corresponds to idx + half
93
- return (Ju³₊ ⊟ Ju³₋) ⊠ local_geometry. invJ
63
+ lg = Geometry. LocalGeometry (space, idx, hidx)
64
+ if ! on_boundary (space, op, loc, idx)
65
+ Ju³₋ = Ju³[vt] # corresponds to idx - half
66
+ Ju³₊ = Ju³[vt + 1 ] # corresponds to idx + half
67
+ return (Ju³₊ ⊟ Ju³₋) ⊠ lg. invJ
68
+ else
69
+ bc = Operators. get_boundary (op, loc)
70
+ @assert bc isa Operators. SetValue || bc isa Operators. SetDivergence
71
+ if on_left_boundary (idx, space)
72
+ if bc isa Operators. SetValue
73
+ Ju³₋ = lJu³[1 ] # corresponds to idx - half
74
+ Ju³₊ = Ju³[vt + 1 ] # corresponds to idx + half
75
+ return (Ju³₊ ⊟ Ju³₋) ⊠ lg. invJ
76
+ else
77
+ # @assert bc isa Operators.SetDivergence
78
+ return lJu³[1 ]
79
+ end
80
+ else
81
+ @assert on_right_boundary (idx, space)
82
+ if bc isa Operators. SetValue
83
+ Ju³₋ = Ju³[vt] # corresponds to idx - half
84
+ Ju³₊ = rJu³[1 ] # corresponds to idx + half
85
+ return (Ju³₊ ⊟ Ju³₋) ⊠ lg. invJ
86
+ else
87
+ @assert bc isa Operators. SetDivergence
88
+ return rJu³[1 ]
89
+ end
90
+ end
91
+ end
94
92
end
95
93
end
96
94
@@ -108,10 +106,12 @@ Base.@propagate_inbounds function fd_operator_shmem(
108
106
return (u, lb, rb)
109
107
end
110
108
111
- Base. @propagate_inbounds function fd_operator_fill_shmem_interior ! (
109
+ Base. @propagate_inbounds function fd_operator_fill_shmem ! (
112
110
op:: Operators.GradientC2F ,
113
111
(u, lb, rb),
114
112
loc, # can be any location
113
+ bc_bds,
114
+ arg_space,
115
115
space,
116
116
idx:: Integer ,
117
117
hidx,
@@ -120,50 +120,33 @@ Base.@propagate_inbounds function fd_operator_fill_shmem_interior!(
120
120
@inbounds begin
121
121
vt = threadIdx (). x
122
122
cov3 = Geometry. Covariant3Vector (1 )
123
- u[vt] = cov3 ⊗ Operators. getidx (space, arg, loc, idx, hidx)
124
- end
125
- return nothing
126
- end
127
-
128
- Base. @propagate_inbounds function fd_operator_fill_shmem_left_boundary! (
129
- op:: Operators.GradientC2F ,
130
- bc:: Operators.SetValue ,
131
- (u, lb, rb),
132
- loc,
133
- space,
134
- idx:: Integer ,
135
- hidx,
136
- arg,
137
- )
138
- idx == Operators. left_center_boundary_idx (space) ||
139
- error (" Incorrect left idx" )
140
- @inbounds begin
141
- vt = threadIdx (). x
142
- cov3 = Geometry. Covariant3Vector (1 )
143
- u[vt] = cov3 ⊗ Operators. getidx (space, arg, loc, idx, hidx)
144
- lb[1 ] = cov3 ⊗ Operators. getidx (space, bc. val, loc, nothing , hidx)
145
- end
146
- return nothing
147
- end
148
-
149
- Base. @propagate_inbounds function fd_operator_fill_shmem_right_boundary! (
150
- op:: Operators.GradientC2F ,
151
- bc:: Operators.SetValue ,
152
- (u, lb, rb),
153
- loc,
154
- space,
155
- idx:: Integer ,
156
- hidx,
157
- arg,
158
- )
159
- # The right boundary is called at `idx + 1`, so we need to subtract 1 from idx (shmem is loaded at vt+1)
160
- idx == Operators. right_center_boundary_idx (space) ||
161
- error (" Incorrect right idx" )
162
- @inbounds begin
163
- vt = threadIdx (). x
164
- cov3 = Geometry. Covariant3Vector (1 )
165
- u[vt] = cov3 ⊗ Operators. getidx (space, arg, loc, idx, hidx)
166
- rb[1 ] = cov3 ⊗ Operators. getidx (space, bc. val, loc, nothing , hidx)
123
+ if in_domain (idx, arg_space)
124
+ u[vt] = cov3 ⊗ Operators. getidx (space, arg, loc, idx, hidx)
125
+ else # idx can be Spaces.nlevels(ᶜspace)+1 because threads must extend to faces
126
+ ᶜspace = Spaces. center_space (arg_space)
127
+ @assert idx == Spaces. nlevels (ᶜspace) + 1
128
+ end
129
+ if on_any_boundary (idx, space, op)
130
+ lloc =
131
+ Operators. LeftBoundaryWindow {Spaces.left_boundary_name(space)} ()
132
+ rloc = Operators. RightBoundaryWindow{
133
+ Spaces. right_boundary_name (space),
134
+ }()
135
+ bloc = on_left_boundary (idx, space, op) ? lloc : rloc
136
+ @assert bloc isa typeof (lloc) && on_left_boundary (idx, space, op) ||
137
+ bloc isa typeof (rloc) && on_right_boundary (idx, space, op)
138
+ bc = Operators. get_boundary (op, bloc)
139
+ @assert bc isa Operators. SetValue || bc isa Operators. SetGradient
140
+ ub = Operators. getidx (space, bc. val, bloc, nothing , hidx)
141
+ bu = on_left_boundary (idx, space) ? lb : rb
142
+ if bc isa Operators. SetValue
143
+ bu[1 ] = cov3 ⊗ ub
144
+ elseif bc isa Operators. SetGradient
145
+ lg = Geometry. LocalGeometry (space, idx, hidx)
146
+ bu[1 ] = Geometry. project (Geometry. Covariant3Axis (), ub, lg)
147
+ elseif bc isa Operators. Extrapolate # no shmem needed
148
+ end
149
+ end
167
150
end
168
151
return nothing
169
152
end
@@ -179,17 +162,28 @@ Base.@propagate_inbounds function fd_operator_evaluate(
179
162
)
180
163
@inbounds begin
181
164
vt = threadIdx (). x
182
- # @assert idx.i == vt-1 # assertion passes, but commented to remove potential thrown exception in llvm output
183
- if idx == Operators. right_face_boundary_idx (space)
184
- u₋ = 2 * u[vt - 1 ] # corresponds to idx - half
185
- u₊ = 2 * rb[1 ] # corresponds to idx + half
186
- elseif idx == Operators. left_face_boundary_idx (space)
187
- u₋ = 2 * lb[1 ] # corresponds to idx - half
188
- u₊ = 2 * u[vt] # corresponds to idx + half
189
- else
165
+ lg = Geometry. LocalGeometry (space, idx, hidx)
166
+ if ! on_boundary (space, op, loc, idx)
190
167
u₋ = u[vt - 1 ] # corresponds to idx - half
191
168
u₊ = u[vt] # corresponds to idx + half
169
+ return u₊ ⊟ u₋
170
+ else
171
+ bc = Operators. get_boundary (op, loc)
172
+ @assert bc isa Operators. SetValue
173
+ if on_left_boundary (idx, space)
174
+ if bc isa Operators. SetValue
175
+ u₋ = 2 * lb[1 ] # corresponds to idx - half
176
+ u₊ = 2 * u[vt] # corresponds to idx + half
177
+ return u₊ ⊟ u₋
178
+ end
179
+ else
180
+ @assert on_right_boundary (idx, space)
181
+ if bc isa Operators. SetValue
182
+ u₋ = 2 * u[vt - 1 ] # corresponds to idx - half
183
+ u₊ = 2 * rb[1 ] # corresponds to idx + half
184
+ return u₊ ⊟ u₋
185
+ end
186
+ end
192
187
end
193
- return u₊ ⊟ u₋
194
188
end
195
189
end
0 commit comments