
Commit 184b36f

Minimize diff.
1 parent 472db83 commit 184b36f

9 files changed
+117 -118 lines changed

lib/JLArrays/src/JLArrays.jl

Lines changed: 39 additions & 35 deletions
@@ -1,17 +1,20 @@
 # reference implementation on the CPU
 # This acts as a wrapper around KernelAbstractions's parallel CPU
-# functionality. It is useful for testing GPUArrays (and other packages)
+# functionality. It is useful for testing GPUArrays (and other packages)
 # when no GPU is present.
 # This file follows conventions from AMDGPU.jl
 
 module JLArrays
 
+export JLArray, JLVector, JLMatrix, jl, JLBackend
+
 using GPUArrays
+
 using Adapt
+
 import KernelAbstractions
 import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
 
-export JLArray, JLVector, JLMatrix, jl, JLBackend
 
 #
 # Device functionality
@@ -24,7 +27,6 @@ struct JLBackend <: KernelAbstractions.GPU
     JLBackend(;static::Bool=false) = new(static)
 end
 
-
 struct Adaptor end
 jlconvert(arg) = adapt(Adaptor(), arg)
 
@@ -35,37 +37,7 @@
 Base.getindex(r::JlRefValue) = r.x
 Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))
 
-mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
-  data::DataRef{Vector{UInt8}}
-
-  offset::Int   # offset of the data in the buffer, in number of elements
-
-  dims::Dims{N}
-
-  # allocating constructor
-  function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
-    check_eltype(T)
-    maxsize = prod(dims) * sizeof(T)
-    data = Vector{UInt8}(undef, maxsize)
-    ref = DataRef(data) do data
-      resize!(data, 0)
-    end
-    obj = new{T,N}(ref, 0, dims)
-    finalizer(unsafe_free!, obj)
-  end
-
-  # low-level constructor for wrapping existing data
-  function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
-                        offset::Int=0) where {T,N}
-    check_eltype(T)
-    obj = new{T,N}(ref, offset, dims)
-    finalizer(unsafe_free!, obj)
-  end
-end
-
-Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a)
-Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a
-Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a)
+## executed on-device
 
 # array type
 
@@ -91,6 +63,7 @@ end
 @inline Base.getindex(A::JLDeviceArray, index::Integer) = getindex(typed_data(A), index)
 @inline Base.setindex!(A::JLDeviceArray, x, index::Integer) = setindex!(typed_data(A), x, index)
 
+
 #
 # Host abstractions
 #
@@ -104,6 +77,34 @@ function check_eltype(T)
   end
 end
 
+mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
+  data::DataRef{Vector{UInt8}}
+
+  offset::Int   # offset of the data in the buffer, in number of elements
+
+  dims::Dims{N}
+
+  # allocating constructor
+  function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
+    check_eltype(T)
+    maxsize = prod(dims) * sizeof(T)
+    data = Vector{UInt8}(undef, maxsize)
+    ref = DataRef(data) do data
+      resize!(data, 0)
+    end
+    obj = new{T,N}(ref, 0, dims)
+    finalizer(unsafe_free!, obj)
+  end
+
+  # low-level constructor for wrapping existing data
+  function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
+                        offset::Int=0) where {T,N}
+    check_eltype(T)
+    obj = new{T,N}(ref, offset, dims)
+    finalizer(unsafe_free!, obj)
+  end
+end
+
 unsafe_free!(a::JLArray) = GPUArrays.unsafe_free!(a.data)
 
 # conversion of untyped data to a typed Array
@@ -380,7 +381,10 @@ function (obj::Kernel{JLBackend})(args...; ndrange=nothing, workgroupsize=nothing)
     device_args = jlconvert.(args)
     new_obj = convert_to_cpu(obj)
     new_obj(device_args...; ndrange, workgroupsize)
-
 end
 
+Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a)
+Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a
+Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a)
+
 end
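The Adapt.adapt_storage rules relocated to the end of the module are what let plain Arrays flow into JLBackend kernels and back out again. A minimal sketch of that round trip, not part of this commit; the kernel name `scale_kernel!` and the usage flow are illustrative assumptions:

```julia
# Illustrative sketch only: exercises the adapt_storage rules shown above.
using Adapt, KernelAbstractions, JLArrays

@kernel function scale_kernel!(a, s)
    i = @index(Global, Linear)
    @inbounds a[i] *= s
end

a = rand(Float32, 16)
backend = JLBackend()
da = adapt(backend, a)                    # Array -> JLArray (adapt_storage)
scale_kernel!(backend)(da, 2f0; ndrange = length(da))
b = adapt(KernelAbstractions.CPU(), da)   # JLArray -> Array (convert)
@assert b ≈ 2 .* a
```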

src/GPUArrays.jl

Lines changed: 3 additions & 2 deletions
@@ -15,10 +15,11 @@ using LLVM.Interop
 using Reexport
 @reexport using GPUArraysCore
 
-## executed on-device
+using KernelAbstractions
+
+# device functionality
 include("device/abstractarray.jl")
 
-using KernelAbstractions
 # host abstractions
 include("host/abstractarray.jl")
 include("host/construction.jl")

src/host/abstractarray.jl

Lines changed: 0 additions & 1 deletion
@@ -159,7 +159,6 @@ for (D, S) in ((AnyGPUArray, Array),
 end
 
 # kernel-based variant for copying between wrapped GPU arrays
-# TODO: Add `@Const` to `src`
 @kernel function linear_copy_kernel!(dest, dstart, src, sstart, n)
     i = @index(Global, Linear)
     if i <= n
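The dropped TODO referred to KernelAbstractions' `@Const` argument annotation, which declares that the kernel never writes through `src`. Purely as an illustration of what that would look like; the body below is a reconstruction, since the diff only shows the first two lines of the kernel:

```julia
# Hypothetical sketch: applying @Const to the copy kernel's source argument.
# The indexing arithmetic is a guess at the elided body, shown only for context.
@kernel function linear_copy_kernel!(dest, dstart, @Const(src), sstart, n)
    i = @index(Global, Linear)
    if i <= n
        @inbounds dest[dstart + i - 1] = src[sstart + i - 1]
    end
end
```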

src/host/broadcast.jl

Lines changed: 6 additions & 6 deletions
@@ -59,16 +59,16 @@
         @inbounds dest[I] = bc[I]
     end
 
-    # grid-stride kernel, ndrange set for possible 0D evaluation
-    if ndims(dest) == 1 || (isa(IndexStyle(dest), IndexLinear) &&
+    broadcast_kernel = if ndims(dest) == 1 ||
+                          (isa(IndexStyle(dest), IndexLinear) &&
                            isa(IndexStyle(bc), IndexLinear))
-        broadcast_kernel_linear(get_backend(dest))(dest, bc;
-                                                   ndrange = length(size(dest)) > 0 ? length(dest) : 1)
+        broadcast_kernel_linear(get_backend(dest))
     else
-        broadcast_kernel_cartesian(get_backend(dest))(dest, bc;
-                                                      ndrange = sz = length(size(dest)) > 0 ? size(dest) : (1,))
+        broadcast_kernel_cartesian(get_backend(dest))
     end
 
+    # ndims check for 0D support
+    broadcast_kernel(dest, bc; ndrange = ndims(dest) > 0 ? size(dest) : (1,))
     if eltype(dest) <: BrokenBroadcast
         throw(ArgumentError("Broadcast operation resulting in $(eltype(eltype(dest))) is not GPU compatible"))
     end
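The refactor first selects the kernel (linear vs. cartesian indexing) and then launches it once, with the `ndrange` falling back to `(1,)` so zero-dimensional destinations still get a single work-item. A small usage sketch of that 0D path, using JLArray as the reference backend; illustrative, not part of the commit:

```julia
using JLArrays

dest = JLArray{Float32,0}(undef, ())  # zero-dimensional destination, size(dest) == ()
dest .= 42f0                          # ndims(dest) == 0, so the kernel runs with ndrange = (1,)
@assert Array(dest)[] == 42f0
```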

src/host/construction.jl

Lines changed: 4 additions & 3 deletions
@@ -11,14 +11,15 @@ Base.convert(::Type{T}, a::AbstractArray) where {T<:AbstractGPUArray} = a isa T
 
 function Base.fill!(A::AnyGPUArray{T}, x) where T
     isempty(A) && return A
+
     @kernel function fill_kernel!(a, val)
         idx = @index(Global, Linear)
         @inbounds a[idx] = val
     end
 
-    # ndrange set for a possible 0D evaluation
-    fill_kernel!(get_backend(A))(A, x,
-                                 ndrange = length(size(A)) > 0 ? size(A) : (1,))
+    # ndims check for 0D support
+    kernel = fill_kernel!(get_backend(A))
+    kernel(A, x; ndrange = ndims(A) > 0 ? size(A) : (1,))
     A
 end
 
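`fill!` now follows the same two-step pattern as broadcast: instantiate the kernel for the destination's backend, then launch it with an `ndrange` that degrades to `(1,)` for zero-dimensional arrays. A hedged usage sketch, with JLArray used only as an example backend:

```julia
using JLArrays

A = JLArray{Float32}(undef, 4, 4)
fill!(A, 1f0)              # kernel launched with ndrange = size(A) == (4, 4)

s = JLArray{Int,0}(undef, ())
fill!(s, 7)                # ndims(s) == 0, so ndrange falls back to (1,)
@assert Array(s)[] == 7
```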
src/host/indexing.jl

Lines changed: 4 additions & 9 deletions
@@ -82,14 +82,11 @@
     return vectorized_getindex!(dest, src, Is...)
 end
 
-@kernel function getindex_kernel(dest, src, idims,
-                                 Is::Vararg{Any,N}) where {N}
+@kernel function getindex_kernel(dest, src, idims, Is...)
     i = @index(Global, Linear)
     getindex_generated(dest, src, idims, i, Is...)
 end
-
-@generated function getindex_generated(dest, src, idims, i,
-                                       Is::Vararg{Any,N}) where {N}
+@generated function getindex_generated(dest, src, idims, i, Is::Vararg{Any,N}) where {N}
     quote
         is = @inbounds CartesianIndices(idims)[i]
         @nexprs $N i -> I_i = @inbounds(Is[i][is[i]])

@@ -120,13 +117,11 @@ end
     return dest
 end
 
-@kernel function setindex_kernel(dest, src, idims, len,
-                                 Is::Vararg{Any,N}) where {N}
+@kernel function setindex_kernel(dest, src, idims, len, Is...)
     i = @index(Global, Linear)
     setindex_generated(dest, src, idims, len, i, Is...)
 end
-@generated function setindex_generated(dest, src, idims, len, i,
-                                       Is::Vararg{Any,N}) where {N}
+@generated function setindex_generated(dest, src, idims, len, i, Is::Vararg{Any,N}) where {N}
     quote
         i > len && return
         is = @inbounds CartesianIndices(idims)[i]