Skip to content

Commit dd7218d

Browse files
Fixes
1 parent b54dad5 commit dd7218d

File tree

6 files changed

+89
-28
lines changed

6 files changed

+89
-28
lines changed

ext/cuda/data_layouts.jl

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,38 @@ function CUDA.CuArray(fa::DL.FieldArray{FD}) where {FD}
8787
return DL.FieldArray{FD}(arrays)
8888
end
8989

90-
DL.field_array(
91-
array::CUDA.CuArray,
92-
as::ArraySize
93-
) = CUDA.CuArray(DL.field_array(Array(array), as))
90+
DL.field_array(array::CUDA.CuArray, as::ArraySize) =
91+
CUDA.CuArray(DL.field_array(Array(array), as))
92+
93+
94+
# TODO: this could be improved, but it's not typically used at runtime
95+
# CUDA kernel: copy `y` into the field array `x`, one element per thread.
#
# Each thread derives its global linear index from its thread/block position,
# converts it with `cart_ind(size(y), gidx)`, and performs `x[I] = y[I]`.
# Returns `nothing`.
function copyto_field_array_knl!(x::DL.FieldArray{FD}, y) where {FD}
    gidx =
        CUDA.threadIdx().x + (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x
    # A launch with blocks * threads > length(y) has surplus threads; they must
    # not index past the end of `y`.
    if gidx <= length(y)
        I = cart_ind(size(y), gidx)
        x[I] = y[I]
    end
    return nothing
end
102+
103+
# Copy the device array `y` into the field array `x` and return `x`.
#
# - If `y` has the same number of dimensions as the component arrays
#   (`eltype(NT)`), `y` is copied into every component array of `x`.
# - If `y` has exactly one more dimension, the copy is performed by launching
#   `copyto_field_array_knl!` with one thread per element of `y`.
# - For any other dimensionality nothing is copied.
@inline function Base.copyto!(
    x::DL.FieldArray{FD, NT},
    y::CUDA.CuArray,
) where {FD, NT <: NTuple}
    if ndims(eltype(NT)) == ndims(y)
        @inbounds for i in 1:DL.tuple_length(NT)
            Base.copyto!(x.arrays[i], y)
        end
    elseif ndims(eltype(NT)) + 1 == ndims(y)
        n = length(y)
        # Nothing to copy. This also avoids `cld(0, 0)` below, which would
        # throw a DivideError for an empty `y`.
        n == 0 && return x
        kernel =
            CUDA.@cuda always_inline = true launch = false copyto_field_array_knl!(
                x,
                y,
            )
        # Let CUDA suggest a thread count, capped at the number of elements.
        config = CUDA.launch_configuration(kernel.fun)
        threads = min(n, config.threads)
        blocks = cld(n, threads)
        kernel(x, y; threads, blocks)
    end
    return x
end

ext/cuda/data_layouts_copyto.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,13 @@ function knl_copyto_linear!(dest::AbstractData, bc, us)
3030
return nothing
3131
end
3232

33-
function knl_copyto_linear!(dest::DataF, bc, us)
34-
tidx = thread_index()
35-
@inbounds dest[] = bc[tidx]
33+
function knl_copyto_linear!(dest::DataF{S},bc,us) where {S}
34+
@inbounds begin
35+
tidx = thread_index()
36+
if tidx ≤ get_N(us)
37+
dest[] = bc[tidx]
38+
end
39+
end
3640
return nothing
3741
end
3842

@@ -48,13 +52,17 @@ function knl_copyto_flat!(dest::AbstractData, bc, us)
4852
return nothing
4953
end
5054

51-
function knl_copyto_flat!(dest::DataF, bc, us)
55+
function knl_copyto_flat!(
56+
dest::DataF{S},
57+
bc::DataLayouts.BroadcastedUnionDataF{S},
58+
us,
59+
) where {S}
5260
@inbounds begin
5361
tidx = thread_index()
5462
if tidx ≤ get_N(us)
5563
n = size(dest)
56-
I = kernel_indexes(tidx, n)
57-
dest[] = bc[I]
64+
# I = kernel_indexes(tidx, n)
65+
dest[] = bc[]
5866
end
5967
end
6068
return nothing
@@ -67,12 +75,7 @@ function cuda_copyto!(dest::AbstractData, bc)
6775
if Nv > 0 && Nh > 0
6876
if has_uniform_datalayouts(bc)
6977
bc′ = to_non_extruded_broadcasted(bc)
70-
auto_launch!(
71-
knl_copyto_linear!,
72-
(dest, bc′, us),
73-
n;
74-
auto = true,
75-
)
78+
auto_launch!(knl_copyto_linear!, (dest, bc′, us), n; auto = true)
7679
else
7780
auto_launch!(knl_copyto_flat!, (dest, bc, us), n; auto = true)
7881
end

src/DataLayouts/DataLayouts.jl

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,10 @@ end
564564
@inbounds col[]
565565
end
566566

567+
@propagate_inbounds function Base.getindex(col::Data0D, I::Integer)
568+
@inbounds col[]
569+
end
570+
567571
Base.@propagate_inbounds function Base.setindex!(data::DataF{S}, val) where {S}
568572
@inbounds set_struct!(
569573
field_array(data),
@@ -581,6 +585,14 @@ end
581585
@inbounds col[] = val
582586
end
583587

588+
@propagate_inbounds function Base.setindex!(
589+
col::Data0D,
590+
val,
591+
I::Integer,
592+
)
593+
@inbounds col[] = val
594+
end
595+
584596
# ======================
585597
# DataSlab2D DataLayout
586598
# ======================
@@ -1305,15 +1317,30 @@ type parameters.
13051317

13061318
#! format: on
13071319

1308-
# Skip DataF here, since we want that to MethodError.
1309-
for DL in (:IJKFVH, :IJFH, :IFH, :IJF, :IF, :VF, :VIJFH, :VIFH)
1310-
@eval @propagate_inbounds Base.getindex(data::$(DL), I::Integer) =
1311-
linear_getindex(data, I)
1312-
@eval @propagate_inbounds Base.setindex!(data::$(DL), val, I::Integer) =
1313-
linear_setindex!(data, val, I)
1314-
end
1320+
@propagate_inbounds Base.setindex!(data::IJKFVH, val, I::Integer) = linear_setindex!(data, val, I)
1321+
@propagate_inbounds Base.setindex!(data::IJFH, val, I::Integer) = linear_setindex!(data, val, I)
1322+
@propagate_inbounds Base.setindex!(data::IFH, val, I::Integer) = linear_setindex!(data, val, I)
1323+
@propagate_inbounds Base.setindex!(data::DataF, val, I::Integer) = linear_setindex!(data, val, I)
1324+
@propagate_inbounds Base.setindex!(data::IJF, val, I::Integer) = linear_setindex!(data, val, I)
1325+
@propagate_inbounds Base.setindex!(data::IF, val, I::Integer) = linear_setindex!(data, val, I)
1326+
@propagate_inbounds Base.setindex!(data::VF, val, I::Integer) = linear_setindex!(data, val, I)
1327+
@propagate_inbounds Base.setindex!(data::VIJFH, val, I::Integer) = linear_setindex!(data, val, I)
1328+
@propagate_inbounds Base.setindex!(data::VIFH, val, I::Integer) = linear_setindex!(data, val, I)
1329+
@propagate_inbounds Base.setindex!(data::IH1JH2, val, I::Integer) = linear_setindex!(data, val, I)
1330+
@propagate_inbounds Base.setindex!(data::IV1JH2, val, I::Integer) = linear_setindex!(data, val, I)
1331+
1332+
@propagate_inbounds Base.getindex(data::IJKFVH, I::Integer) = linear_getindex(data, I)
1333+
@propagate_inbounds Base.getindex(data::IJFH, I::Integer) = linear_getindex(data, I)
1334+
@propagate_inbounds Base.getindex(data::IFH, I::Integer) = linear_getindex(data, I)
1335+
@propagate_inbounds Base.getindex(data::DataF, I::Integer) = linear_getindex(data, I)
1336+
@propagate_inbounds Base.getindex(data::IJF, I::Integer) = linear_getindex(data, I)
1337+
@propagate_inbounds Base.getindex(data::IF, I::Integer) = linear_getindex(data, I)
1338+
@propagate_inbounds Base.getindex(data::VF, I::Integer) = linear_getindex(data, I)
1339+
@propagate_inbounds Base.getindex(data::VIJFH, I::Integer) = linear_getindex(data, I)
1340+
@propagate_inbounds Base.getindex(data::VIFH, I::Integer) = linear_getindex(data, I)
1341+
@propagate_inbounds Base.getindex(data::IH1JH2, I::Integer) = linear_getindex(data, I)
1342+
@propagate_inbounds Base.getindex(data::IV1JH2, I::Integer) = linear_getindex(data, I)
13151343

1316-
# Datalayouts
13171344
@propagate_inbounds function linear_getindex(
13181345
data::AbstractData{S},
13191346
I::Integer,
@@ -1344,7 +1371,6 @@ end
13441371
)
13451372
end
13461373

1347-
13481374
Base.ndims(data::AbstractData) = Base.ndims(typeof(data))
13491375
Base.ndims(::Type{T}) where {T <: AbstractData} =
13501376
Base.ndims(field_array_type(T))

src/DataLayouts/copyto.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ end
4040

4141
# broadcasting scalar assignment
4242
# Performance optimization for the common identity scalar case: dest .= val
43-
# And this is valid for the CPU or GPU, since the broadcasted object
44-
# is a scalar type.
4543
function Base.copyto!(
4644
dest::AbstractData,
4745
bc::Base.Broadcast.Broadcasted{Style},

src/DataLayouts/field_array.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,9 @@ function field_array(
395395
FieldArray{FD}(EmptyArray(scalar_array))
396396
end
397397

398+
# Base.show(io::IO, fa::FieldArray{FD}) where {FD} = print(io, "$(arrays_type(typeof(fa)))(", Array(fa), ")")
399+
Base.show(io::IO, fa::FieldArray{FD}) where {FD} = print(io, "FieldArray{$FD}(", Array(fa), ")")
400+
398401
# Warning: this method is type-unstable.
399402
function Base.view(fa::FieldArray{FD}, inds...) where {FD}
400403
AI = dropat(inds, Val(FD))

test/DataLayouts/data0d.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ end
4848
array = zeros(Float64, 3)
4949
data = DataF{S}(array)
5050
@test data[][2] == zero(Float64)
51-
@test_throws MethodError data[1]
51+
# @test_throws MethodError data[1] # data[1] no longer throws: DataF now supports integer indexing
5252
end
5353

5454
@testset "DataF type safety" begin

0 commit comments

Comments
 (0)