Skip to content

Commit 9ba577b

Browse files
Fixes
1 parent b54dad5 commit 9ba577b

File tree

6 files changed

+67
-20
lines changed

6 files changed

+67
-20
lines changed

ext/cuda/data_layouts.jl

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,38 @@ function CUDA.CuArray(fa::DL.FieldArray{FD}) where {FD}
8787
return DL.FieldArray{FD}(arrays)
8888
end
8989

90-
DL.field_array(
91-
array::CUDA.CuArray,
92-
as::ArraySize
93-
) = CUDA.CuArray(DL.field_array(Array(array), as))
90+
DL.field_array(array::CUDA.CuArray, as::ArraySize) =
91+
CUDA.CuArray(DL.field_array(Array(array), as))
92+
93+
94+
# TODO: this could be improved, but it's not typically used at runtime
95+
function copyto_field_array_knl!(x::DL.FieldArray{FD}, y) where {FD}
96+
gidx =
97+
CUDA.threadIdx().x + (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x
98+
I = cart_ind(size(y), gidx)
99+
x[I] = y[I]
100+
return nothing
101+
end
102+
103+
@inline function Base.copyto!(
104+
x::DL.FieldArray{FD, NT},
105+
y::CUDA.CuArray,
106+
) where {FD, NT <: NTuple}
107+
if ndims(eltype(NT)) == ndims(y)
108+
@inbounds for i in 1:DL.tuple_length(NT)
109+
Base.copyto!(x.arrays[i], y)
110+
end
111+
elseif ndims(eltype(NT)) + 1 == ndims(y)
112+
n = prod(size(y))
113+
kernel =
114+
CUDA.@cuda always_inline = true launch = false copyto_field_array_knl!(
115+
x,
116+
y,
117+
)
118+
config = CUDA.launch_configuration(kernel.fun)
119+
threads = min(n, config.threads)
120+
blocks = cld(n, threads)
121+
kernel(x, y; threads, blocks)
122+
end
123+
x
124+
end

ext/cuda/data_layouts_copyto.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,13 @@ function knl_copyto_linear!(dest::AbstractData, bc, us)
3030
return nothing
3131
end
3232

33-
function knl_copyto_linear!(dest::DataF, bc, us)
34-
tidx = thread_index()
35-
@inbounds dest[] = bc[tidx]
33+
function knl_copyto_linear!(dest::DataF{S},bc,us) where {S}
34+
@inbounds begin
35+
tidx = thread_index()
36+
if tidx get_N(us)
37+
dest[] = bc[tidx]
38+
end
39+
end
3640
return nothing
3741
end
3842

@@ -48,13 +52,17 @@ function knl_copyto_flat!(dest::AbstractData, bc, us)
4852
return nothing
4953
end
5054

51-
function knl_copyto_flat!(dest::DataF, bc, us)
55+
function knl_copyto_flat!(
56+
dest::DataF{S},
57+
bc::DataLayouts.BroadcastedUnionDataF{S},
58+
us,
59+
) where {S}
5260
@inbounds begin
5361
tidx = thread_index()
5462
if tidx get_N(us)
5563
n = size(dest)
56-
I = kernel_indexes(tidx, n)
57-
dest[] = bc[I]
64+
# I = kernel_indexes(tidx, n)
65+
dest[] = bc[]
5866
end
5967
end
6068
return nothing
@@ -67,12 +75,7 @@ function cuda_copyto!(dest::AbstractData, bc)
6775
if Nv > 0 && Nh > 0
6876
if has_uniform_datalayouts(bc)
6977
bc′ = to_non_extruded_broadcasted(bc)
70-
auto_launch!(
71-
knl_copyto_linear!,
72-
(dest, bc′, us),
73-
n;
74-
auto = true,
75-
)
78+
auto_launch!(knl_copyto_linear!, (dest, bc′, us), n; auto = true)
7679
else
7780
auto_launch!(knl_copyto_flat!, (dest, bc, us), n; auto = true)
7881
end

src/DataLayouts/DataLayouts.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,10 @@ end
564564
@inbounds col[]
565565
end
566566

567+
@propagate_inbounds function Base.getindex(col::Data0D, I::Integer)
568+
@inbounds col[]
569+
end
570+
567571
Base.@propagate_inbounds function Base.setindex!(data::DataF{S}, val) where {S}
568572
@inbounds set_struct!(
569573
field_array(data),
@@ -581,6 +585,14 @@ end
581585
@inbounds col[] = val
582586
end
583587

588+
@propagate_inbounds function Base.setindex!(
589+
col::Data0D,
590+
val,
591+
I::Integer,
592+
)
593+
@inbounds col[] = val
594+
end
595+
584596
# ======================
585597
# DataSlab2D DataLayout
586598
# ======================
@@ -1306,7 +1318,7 @@ type parameters.
13061318
#! format: on
13071319

13081320
# Skip DataF here, since we want that to MethodError.
1309-
for DL in (:IJKFVH, :IJFH, :IFH, :IJF, :IF, :VF, :VIJFH, :VIFH)
1321+
for DL in (:IJKFVH, :IJFH, :IFH, :IJF, :IF, :DataF, :VF, :VIJFH, :VIFH)
13101322
@eval @propagate_inbounds Base.getindex(data::$(DL), I::Integer) =
13111323
linear_getindex(data, I)
13121324
@eval @propagate_inbounds Base.setindex!(data::$(DL), val, I::Integer) =

src/DataLayouts/copyto.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ end
4040

4141
# broadcasting scalar assignment
4242
# Performance optimization for the common identity scalar case: dest .= val
43-
# And this is valid for the CPU or GPU, since the broadcasted object
44-
# is a scalar type.
4543
function Base.copyto!(
4644
dest::AbstractData,
4745
bc::Base.Broadcast.Broadcasted{Style},

src/DataLayouts/field_array.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,9 @@ function field_array(
395395
FieldArray{FD}(EmptyArray(scalar_array))
396396
end
397397

398+
# Base.show(io::IO, fa::FieldArray{FD}) where {FD} = print(io, "$(arrays_type(typeof(fa)))(", Array(fa), ")")
399+
Base.show(io::IO, fa::FieldArray{FD}) where {FD} = print(io, "FieldArray{$FD}(", Array(fa), ")")
400+
398401
# Warning: this method is type-unstable.
399402
function Base.view(fa::FieldArray{FD}, inds...) where {FD}
400403
AI = dropat(inds, Val(FD))

test/DataLayouts/data0d.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ end
4848
array = zeros(Float64, 3)
4949
data = DataF{S}(array)
5050
@test data[][2] == zero(Float64)
51-
@test_throws MethodError data[1]
51+
# @test_throws MethodError data[1] # this no longer can throw an error
5252
end
5353

5454
@testset "DataF type safety" begin

0 commit comments

Comments
 (0)