Skip to content

Commit 6adeb1d

Browse files
wip
1 parent 3e89d94 commit 6adeb1d

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

ext/cuda/data_layouts_threadblock.jl

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,26 @@ end
213213
@inline columnwise_is_valid_index(I::CI5, us::UniversalSize) =
214214
1 I[5] DataLayouts.get_Nh(us)
215215

216+
@inline function columnwise_linear_partition(
217+
us::DataLayouts.UniversalSize,
218+
n_max_threads::Integer,
219+
)
220+
(Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
221+
nitems = prod((Nij, Nij, Nh))
222+
threads = min(nitems, n_max_threads)
223+
blocks = cld(nitems, threads)
224+
return (; threads, blocks)
225+
end
226+
@inline function columnwise_linear_universal_index(us)
227+
i = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
228+
(Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
229+
n = (Nij, Nij, Nh)
230+
CI = CartesianIndices(map(x -> Base.OneTo(x), n))
231+
return (CI, i)
232+
end
233+
@inline columnwise_linear_is_valid_index(i_linear::Integer, N::Integer) =
234+
1 i_linear N
235+
216236
##### Element-wise (e.g., limiters)
217237
# TODO
218238

@@ -223,16 +243,27 @@ end
223243
Nnames,
224244
)
225245
(Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
226-
@assert prod((Nij, Nij, Nnames)) n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nnames))),$n_max_threads)"
227-
return (; threads = (Nij, Nij, Nnames), blocks = (Nh,))
246+
# @assert prod((Nij, Nij, Nnames)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nnames))),$n_max_threads)"
247+
# return (; threads = (Nij, Nij, Nnames), blocks = (Nh,))
248+
nitems = prod((Nh, Nij, Nij, Nnames))
249+
threads = min(nitems, n_max_threads)
250+
blocks = cld(nitems, threads)
251+
return (; threads, blocks)
228252
end
229-
@inline function multiple_field_solve_universal_index(us::UniversalSize)
230-
(i, j, iname) = CUDA.threadIdx()
231-
(h,) = CUDA.blockIdx()
232-
return (CartesianIndex((i, j, 1, 1, h)), iname)
253+
@inline function multiple_field_solve_universal_index(us::DataLayouts.UniversalSize, ::Val{Nnames}) where {Nnames}
254+
# (i, j, iname) = CUDA.threadIdx()
255+
# (h,) = CUDA.blockIdx()
256+
# return (CartesianIndex((i, j, 1, 1, h)), iname)
257+
i = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
258+
(Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
259+
n = (Nij, Nij, Nh, Nnames)
260+
CI = CartesianIndices(n)
261+
return (CI, i)
233262
end
234-
@inline multiple_field_solve_is_valid_index(I::CI5, us::UniversalSize) =
235-
1 I[5] DataLayouts.get_Nh(us)
263+
# @inline multiple_field_solve_is_valid_index(I::CI5, us::UniversalSize) =
264+
# 1 ≤ I[5] ≤ DataLayouts.get_Nh(us)
265+
@inline multiple_field_solve_is_valid_index(i_linear::Integer, N::Integer) =
266+
1 i_linear N
236267

237268
##### spectral kernel partition
238269
@inline function spectral_partition(

ext/cuda/matrix_fields_multiple_field_solve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ function multiple_field_solve_kernel!(
8989
::Val{Nnames},
9090
) where {Nnames}
9191
@inbounds begin
92-
(I, iname) = multiple_field_solve_universal_index(us)
93-
if multiple_field_solve_is_valid_index(I, us)
94-
(i, j, _, _, h) = I.I
92+
(CI, i_linear) = multiple_field_solve_universal_index(us, Val(Nnames))
93+
if multiple_field_solve_is_valid_index(i_linear, prod(CI.I))
94+
(i, j, _, _, h, iname) = CI.I
9595
generated_single_field_solve!(
9696
device,
9797
caches,

0 commit comments

Comments
 (0)