Skip to content

Commit a254113

Browse files
Fixes
1 parent 778b91f commit a254113

File tree

6 files changed

+89
-28
lines changed

6 files changed

+89
-28
lines changed

ext/cuda/data_layouts.jl

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,38 @@ function CUDA.CuArray(fa::DL.FieldArray{FD}) where {FD}
8888
return DL.FieldArray{FD}(arrays)
8989
end
9090

91-
DL.field_array(
92-
array::CUDA.CuArray,
93-
as::ArraySize
94-
) = CUDA.CuArray(DL.field_array(Array(array), as))
91+
DL.field_array(array::CUDA.CuArray, as::ArraySize) =
92+
CUDA.CuArray(DL.field_array(Array(array), as))
93+
94+
95+
# TODO: this could be improved, but it's not typically used at runtime
96+
function copyto_field_array_knl!(x::DL.FieldArray{FD}, y) where {FD}
97+
gidx =
98+
CUDA.threadIdx().x + (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x
99+
I = cart_ind(size(y), gidx)
100+
x[I] = y[I]
101+
return nothing
102+
end
103+
104+
@inline function Base.copyto!(
105+
x::DL.FieldArray{FD, NT},
106+
y::CUDA.CuArray,
107+
) where {FD, NT <: NTuple}
108+
if ndims(eltype(NT)) == ndims(y)
109+
@inbounds for i in 1:DL.tuple_length(NT)
110+
Base.copyto!(x.arrays[i], y)
111+
end
112+
elseif ndims(eltype(NT)) + 1 == ndims(y)
113+
n = prod(size(y))
114+
kernel =
115+
CUDA.@cuda always_inline = true launch = false copyto_field_array_knl!(
116+
x,
117+
y,
118+
)
119+
config = CUDA.launch_configuration(kernel.fun)
120+
threads = min(n, config.threads)
121+
blocks = cld(n, threads)
122+
kernel(x, y; threads, blocks)
123+
end
124+
x
125+
end

ext/cuda/data_layouts_copyto.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,13 @@ function knl_copyto_linear!(dest::AbstractData, bc, us)
2929
return nothing
3030
end
3131

32-
function knl_copyto_linear!(dest::DataF, bc, us)
33-
tidx = thread_index()
34-
@inbounds dest[] = bc[tidx]
32+
function knl_copyto_linear!(dest::DataF{S},bc,us) where {S}
33+
@inbounds begin
34+
tidx = thread_index()
35+
if tidx get_N(us)
36+
dest[] = bc[tidx]
37+
end
38+
end
3539
return nothing
3640
end
3741

@@ -47,13 +51,17 @@ function knl_copyto_flat!(dest::AbstractData, bc, us)
4751
return nothing
4852
end
4953

50-
function knl_copyto_flat!(dest::DataF, bc, us)
54+
function knl_copyto_flat!(
55+
dest::DataF{S},
56+
bc::DataLayouts.BroadcastedUnionDataF{S},
57+
us,
58+
) where {S}
5159
@inbounds begin
5260
tidx = thread_index()
5361
if tidx get_N(us)
5462
n = size(dest)
55-
I = kernel_indexes(tidx, n)
56-
dest[] = bc[I]
63+
# I = kernel_indexes(tidx, n)
64+
dest[] = bc[]
5765
end
5866
end
5967
return nothing
@@ -66,12 +74,7 @@ function cuda_copyto!(dest::AbstractData, bc)
6674
if Nv > 0 && Nh > 0
6775
if has_uniform_datalayouts(bc)
6876
bc′ = to_non_extruded_broadcasted(bc)
69-
auto_launch!(
70-
knl_copyto_linear!,
71-
(dest, bc′, us),
72-
n;
73-
auto = true,
74-
)
77+
auto_launch!(knl_copyto_linear!, (dest, bc′, us), n; auto = true)
7578
else
7679
auto_launch!(knl_copyto_flat!, (dest, bc, us), n; auto = true)
7780
end

src/DataLayouts/DataLayouts.jl

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,10 @@ end
564564
@inbounds col[]
565565
end
566566

567+
@propagate_inbounds function Base.getindex(col::Data0D, I::Integer)
568+
@inbounds col[]
569+
end
570+
567571
Base.@propagate_inbounds function Base.setindex!(data::DataF{S}, val) where {S}
568572
@inbounds set_struct!(
569573
field_array(data),
@@ -581,6 +585,14 @@ end
581585
@inbounds col[] = val
582586
end
583587

588+
@propagate_inbounds function Base.setindex!(
589+
col::Data0D,
590+
val,
591+
I::Integer,
592+
)
593+
@inbounds col[] = val
594+
end
595+
584596
# ======================
585597
# DataSlab2D DataLayout
586598
# ======================
@@ -1305,15 +1317,30 @@ type parameters.
13051317

13061318
#! format: on
13071319

1308-
# Skip DataF here, since we want that to MethodError.
1309-
for DL in (:IJKFVH, :IJFH, :IFH, :IJF, :IF, :VF, :VIJFH, :VIFH)
1310-
@eval @propagate_inbounds Base.getindex(data::$(DL), I::Integer) =
1311-
linear_getindex(data, I)
1312-
@eval @propagate_inbounds Base.setindex!(data::$(DL), val, I::Integer) =
1313-
linear_setindex!(data, val, I)
1314-
end
1320+
@propagate_inbounds Base.setindex!(data::IJKFVH, val, I::Integer) = linear_setindex!(data, val, I)
1321+
@propagate_inbounds Base.setindex!(data::IJFH, val, I::Integer) = linear_setindex!(data, val, I)
1322+
@propagate_inbounds Base.setindex!(data::IFH, val, I::Integer) = linear_setindex!(data, val, I)
1323+
@propagate_inbounds Base.setindex!(data::DataF, val, I::Integer) = linear_setindex!(data, val, I)
1324+
@propagate_inbounds Base.setindex!(data::IJF, val, I::Integer) = linear_setindex!(data, val, I)
1325+
@propagate_inbounds Base.setindex!(data::IF, val, I::Integer) = linear_setindex!(data, val, I)
1326+
@propagate_inbounds Base.setindex!(data::VF, val, I::Integer) = linear_setindex!(data, val, I)
1327+
@propagate_inbounds Base.setindex!(data::VIJFH, val, I::Integer) = linear_setindex!(data, val, I)
1328+
@propagate_inbounds Base.setindex!(data::VIFH, val, I::Integer) = linear_setindex!(data, val, I)
1329+
@propagate_inbounds Base.setindex!(data::IH1JH2, val, I::Integer) = linear_setindex!(data, val, I)
1330+
@propagate_inbounds Base.setindex!(data::IV1JH2, val, I::Integer) = linear_setindex!(data, val, I)
1331+
1332+
@propagate_inbounds Base.getindex(data::IJKFVH, I::Integer) = linear_getindex(data, I)
1333+
@propagate_inbounds Base.getindex(data::IJFH, I::Integer) = linear_getindex(data, I)
1334+
@propagate_inbounds Base.getindex(data::IFH, I::Integer) = linear_getindex(data, I)
1335+
@propagate_inbounds Base.getindex(data::DataF, I::Integer) = linear_getindex(data, I)
1336+
@propagate_inbounds Base.getindex(data::IJF, I::Integer) = linear_getindex(data, I)
1337+
@propagate_inbounds Base.getindex(data::IF, I::Integer) = linear_getindex(data, I)
1338+
@propagate_inbounds Base.getindex(data::VF, I::Integer) = linear_getindex(data, I)
1339+
@propagate_inbounds Base.getindex(data::VIJFH, I::Integer) = linear_getindex(data, I)
1340+
@propagate_inbounds Base.getindex(data::VIFH, I::Integer) = linear_getindex(data, I)
1341+
@propagate_inbounds Base.getindex(data::IH1JH2, I::Integer) = linear_getindex(data, I)
1342+
@propagate_inbounds Base.getindex(data::IV1JH2, I::Integer) = linear_getindex(data, I)
13151343

1316-
# Datalayouts
13171344
@propagate_inbounds function linear_getindex(
13181345
data::AbstractData{S},
13191346
I::Integer,
@@ -1344,7 +1371,6 @@ end
13441371
)
13451372
end
13461373

1347-
13481374
Base.ndims(data::AbstractData) = Base.ndims(typeof(data))
13491375
Base.ndims(::Type{T}) where {T <: AbstractData} =
13501376
Base.ndims(field_array_type(T))

src/DataLayouts/copyto.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ end
4040

4141
# broadcasting scalar assignment
4242
# Performance optimization for the common identity scalar case: dest .= val
43-
# And this is valid for the CPU or GPU, since the broadcasted object
44-
# is a scalar type.
4543
function Base.copyto!(
4644
dest::AbstractData,
4745
bc::Base.Broadcast.Broadcasted{Style},

src/DataLayouts/field_array.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,9 @@ function field_array(
395395
FieldArray{FD}(EmptyArray(scalar_array))
396396
end
397397

398+
# Base.show(io::IO, fa::FieldArray{FD}) where {FD} = print(io, "$(arrays_type(typeof(fa)))(", Array(fa), ")")
399+
Base.show(io::IO, fa::FieldArray{FD}) where {FD} = print(io, "FieldArray{$FD}(", Array(fa), ")")
400+
398401
# Warning: this method is type-unstable.
399402
function Base.view(fa::FieldArray{FD}, inds...) where {FD}
400403
AI = dropat(inds, Val(FD))

test/DataLayouts/data0d.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ end
4848
array = zeros(Float64, 3)
4949
data = DataF{S}(array)
5050
@test data[][2] == zero(Float64)
51-
@test_throws MethodError data[1]
51+
# @test_throws MethodError data[1] # this no longer can throw an error
5252
end
5353

5454
@testset "DataF type safety" begin

0 commit comments

Comments
 (0)