
Commit 7b2dc94

Merge pull request #16 from JuliaGPU/vc/const
implement Const memory for GPU and CPU
2 parents 4a1f0b2 + ce40c7a commit 7b2dc94

5 files changed: +85 -10 lines changed

docs/src/kernels.md

Lines changed: 7 additions & 2 deletions

@@ -1,11 +1,16 @@
 # Writing kernels
 
 These kernel language constructs are intended to be used as part
-of [`@kernel`](@ref) functions and not outside that context.
+of [`@kernel`](@ref) functions and are not valid outside that context.
 
 ## Constant arguments
 
-[`@Const`](@ref)
+Kernel functions allow for input arguments to be marked with the
+[`@Const`](@ref) macro. It informs the compiler that the memory
+accessed through that marked input argument will not be written
+to as part of the kernel. This has the implication that input arguments
+are **not** allowed to alias each other. If you are used to CUDA C, this
+is similar to `const restrict`.
 
 ## Indexing
 
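Since the documentation added above only names the macro, here is a small usage sketch. It is modeled on the `constarg` test kernel added in this commit and on the package's README-style launch convention; the `copy_const!` name, the array sizes, and the launch/wait lines are illustrative assumptions, not part of this diff.

using KernelAbstractions

# B is marked @Const: the kernel never writes through B, and the compiler
# may assume B does not alias the output argument A.
@kernel function copy_const!(A, @Const(B))
    I = @index(Global)
    @inbounds A[I] = B[I]
end

A = zeros(Float32, 1024)
B = rand(Float32, 1024)
kernel = copy_const!(CPU(), 8)           # CPU backend, workgroup size 8
event  = kernel(A, B, ndrange=size(A))   # assumed launch style; returns an event
wait(event)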

src/KernelAbstractions.jl

Lines changed: 8 additions & 2 deletions

@@ -7,6 +7,7 @@ export Device, GPU, CPU, CUDA
 using StaticArrays
 using Cassette
 using Requires
+using Adapt
 
 """
     @kernel function f(args) end
@@ -147,6 +148,13 @@ function __index_Global_Linear end
 
 function __index_Local_Cartesian end
 function __index_Global_Cartesian end
+
+struct ConstAdaptor end
+
+Adapt.adapt_storage(to::ConstAdaptor, a::Array) = Base.Experimental.Const(a)
+
+constify(arg) = adapt(ConstAdaptor(), arg)
+
 ###
 # Backend hierarchy
 ###
@@ -271,8 +279,6 @@ end
 
 function __validindex end
 
-# TODO: GPU ConstWrapper that forwards loads to `ldg` and forbids stores
-ConstWrapper(A) = A
 include("macros.jl")
 
 ###
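A note on the new `constify` helper: it relies on Adapt.jl to rewrite the storage of an argument while leaving anything without a matching `adapt_storage` method untouched, so on the CPU path only plain `Array`s get wrapped. A minimal sketch of that behavior, assuming nothing beyond the definitions added above (the trailing comments describe expected results, not output captured from the commit):

using Adapt

struct ConstAdaptor end
Adapt.adapt_storage(::ConstAdaptor, a::Array) = Base.Experimental.Const(a)
constify(arg) = adapt(ConstAdaptor(), arg)

A = rand(Float32, 4)
constify(A)          # wrapped as Base.Experimental.Const{Float32,1}
constify((A, 1.0))   # Adapt recurses through tuples; the Float64 passes through unchanged
constify(2)          # non-array arguments are returned as-is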

src/backends/cuda.jl

Lines changed: 30 additions & 1 deletion

@@ -1,5 +1,5 @@
 import CUDAnative, CUDAdrv
-import CUDAnative: cufunction
+import CUDAnative: cufunction, DevicePtr
 import CUDAdrv: CuEvent, CuStream, CuDefaultStream
 
 const FREE_STREAMS = CuStream[]
@@ -218,3 +218,32 @@ end
 @inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__synchronize))
     CUDAnative.sync_threads()
 end
+
+###
+# GPU implementation of `@Const`
+###
+struct ConstCuDeviceArray{T,N,A} <: AbstractArray{T,N}
+    shape::Dims{N}
+    ptr::DevicePtr{T,A}
+
+    # inner constructors, fully parameterized, exact types (ie. Int not <:Integer)
+    ConstCuDeviceArray{T,N,A}(shape::Dims{N}, ptr::DevicePtr{T,A}) where {T,A,N} = new(shape,ptr)
+end
+
+Adapt.adapt_storage(to::ConstAdaptor, a::CUDAnative.CuDeviceArray{T,N,A}) where {T,N,A} = ConstCuDeviceArray{T, N, A}(a.shape, a.ptr)
+
+Base.pointer(a::ConstCuDeviceArray) = a.ptr
+Base.pointer(a::ConstCuDeviceArray, i::Integer) =
+    pointer(a) + (i - 1) * Base.elsize(a)
+
+Base.elsize(::Type{<:ConstCuDeviceArray{T}}) where {T} = sizeof(T)
+Base.size(g::ConstCuDeviceArray) = g.shape
+Base.length(g::ConstCuDeviceArray) = prod(g.shape)
+
+Base.unsafe_convert(::Type{DevicePtr{T,A}}, a::ConstCuDeviceArray{T,N,A}) where {T,A,N} = pointer(a)
+
+@inline function Base.getindex(A::ConstCuDeviceArray{T}, index::Integer) where {T}
+    @boundscheck checkbounds(A, index)
+    align = Base.datatype_alignment(T)
+    CUDAnative.unsafe_cached_load(pointer(A), index, Val(align))::T
+end
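On the GPU path the same `constify` call fires inside the kernel body, after the launch has already converted the `CuArray` argument into a `CuDeviceArray`; the `adapt_storage` method above then swaps that for a `ConstCuDeviceArray`, whose `getindex` goes through `unsafe_cached_load`, the read-only (`ldg`) cache path. A hypothetical host-side walk-through of the type conversion, not code from the commit (indexing the result on the host would fail, since it carries a device pointer):

using CuArrays, CUDAnative, KernelAbstractions

A  = CuArrays.rand(Float32, 16)
dA = CUDAnative.cudaconvert(A)          # CuDeviceArray, the form a kernel sees
cA = KernelAbstractions.constify(dA)    # ConstCuDeviceArray{Float32,1,AS.Global}
typeof(cA)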

src/macros.jl

Lines changed: 4 additions & 5 deletions

@@ -66,7 +66,7 @@ function transform_gpu(expr, args)
     new_stmts = Expr[]
     for (arg, isconst) in args
         if isconst
-            push!(new_stmts, :($arg = $ConstWrapper($arg)))
+            push!(new_stmts, :($arg = $constify($arg)))
         end
     end
     return quote
@@ -148,16 +148,15 @@ function transform_cpu(stmts, args)
     new_stmts = Expr[]
     for (arg, isconst) in args
         if isconst
-            # XXX: Deal with OffsetArrays
-            push!(new_stmts, :($arg = $Base.Experimental.Const($arg)))
+            push!(new_stmts, :($arg = $constify($arg)))
         end
     end
     loops = split(stmts)
     body = generate_cpu_code(loops)
 
-    # push!(new_stmts, Expr(:aliasscope))
+    push!(new_stmts, Expr(:aliasscope))
     push!(new_stmts, body)
-    # push!(new_stmts, Expr(:popaliasscope))
+    push!(new_stmts, Expr(:popaliasscope))
     push!(new_stmts, :(return nothing))
     return Expr(:block, new_stmts...)
 end
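On the CPU, `transform_cpu` now brackets the generated loop body with `Expr(:aliasscope)` and `Expr(:popaliasscope)`, which is what lets LLVM attach the `!alias.scope` and `!noalias` metadata the new test checks for. As a rough hand-written analogue of what the transform produces for one `@Const` argument (a sketch, not actual macro output; `Base.Experimental.@aliasscope` is the surface syntax for the same expression pair):

# wrap the argument in Const, then run the body inside an alias scope so
# the compiler may assume no stores alias the reads through B
function copyto_noalias!(A, B)
    B = Base.Experimental.Const(B)
    Base.Experimental.@aliasscope begin
        @inbounds @simd for i in eachindex(A)
            A[i] = B[i]
        end
    end
    return nothing
end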

test/test.jl

Lines changed: 36 additions & 0 deletions

@@ -1,7 +1,9 @@
 using KernelAbstractions
 using CUDAapi
+using InteractiveUtils
 if has_cuda_gpu()
     using CuArrays
+    using CUDAnative
     CuArrays.allowscalar(false)
 end
 
@@ -131,4 +133,38 @@ end
     if has_cuda_gpu()
         indextest(CUDA(), CuArray)
     end
+end
+
+@kernel function constarg(A, @Const(B))
+    I = @index(Global)
+    @inbounds A[I] = B[I]
+end
+
+@testset "Const" begin
+    let kernel = constarg(CPU(), 8, (1024,))
+        # this is poking at internals
+        ctx = KernelAbstractions.mkcontext(kernel, 1, nothing, nothing)
+        AT = Array{Float32, 2}
+        IR = sprint() do io
+            code_llvm(io, KernelAbstractions.Cassette.overdub,
+                      (typeof(ctx), typeof(kernel.f), AT, AT),
+                      optimize=false, raw=true)
+        end
+        @test occursin("!alias.scope", IR)
+        @test occursin("!noalias", IR)
+    end
+
+    if has_cuda_gpu()
+        let kernel = constarg(CUDA(), 8, (1024,))
+            # this is poking at internals
+            ctx = KernelAbstractions.mkcontext(kernel, nothing)
+            AT = CUDAnative.CuDeviceArray{Float32, 2, CUDAnative.AS.Global}
+            IR = sprint() do io
+                CUDAnative.code_llvm(io, KernelAbstractions.Cassette.overdub,
+                                     (typeof(ctx), typeof(kernel.f), AT, AT),
+                                     kernel=true, optimize=false)
+            end
+            @test occursin("@llvm.nvvm.ldg", IR)
+        end
+    end
 end
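The GPU branch of this testset asserts on unoptimized LLVM IR. A hypothetical follow-up check, not part of this commit, is to look at the generated PTX instead, where `unsafe_cached_load` shows up as a non-coherent global load on devices that support the read-only cache (sm_35 and newer); it reuses the `constarg` kernel defined in the test file above.

using Test, CUDAapi, KernelAbstractions, CUDAnative

if has_cuda_gpu()
    kernel = constarg(CUDA(), 8, (1024,))
    ctx = KernelAbstractions.mkcontext(kernel, nothing)
    AT = CUDAnative.CuDeviceArray{Float32, 2, CUDAnative.AS.Global}
    ptx = sprint() do io
        CUDAnative.code_ptx(io, KernelAbstractions.Cassette.overdub,
                            (typeof(ctx), typeof(kernel.f), AT, AT); kernel=true)
    end
    @test occursin("ld.global.nc", ptx)
end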
