Skip to content

Commit a433fe3

Browse files
authored
Merge pull request #145 from JuliaGPU/vc/cuda2
Use Const from CUDA and support CUDA 2.0
2 parents 74df836 + 765ee73 commit a433fe3

File tree

2 files changed

+3
-33
lines changed

2 files changed

+3
-33
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
1515

1616
[compat]
1717
Adapt = "0.4, 1.0, 2.0"
18-
CUDA = "~1.0, ~1.1, ~1.2, 1.3"
18+
CUDA = "~1.0, ~1.1, ~1.2, 1.3, 2"
1919
Cassette = "0.3.3"
2020
MacroTools = "0.5"
2121
SpecialFunctions = "0.10"

src/backends/cuda.jl

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -318,37 +318,7 @@ end
318318
end
319319

320320
###
321-
# GPU implementation of `@Const`
321+
# GPU implementation of const memory
322322
###
323-
struct ConstCuDeviceArray{T,N,A} <: AbstractArray{T,N}
324-
shape::Dims{N}
325-
ptr::CUDA.DevicePtr{T,A}
326323

327-
# inner constructors, fully parameterized, exact types (i.e., Int, not <:Integer)
328-
ConstCuDeviceArray{T,N,A}(shape::Dims{N}, ptr::CUDA.DevicePtr{T,A}) where {T,A,N} = new(shape,ptr)
329-
end
330-
331-
Adapt.adapt_storage(to::ConstAdaptor, a::CUDA.CuDeviceArray{T,N,A}) where {T,N,A} = ConstCuDeviceArray{T, N, A}(a.shape, a.ptr)
332-
333-
Base.pointer(a::ConstCuDeviceArray) = a.ptr
334-
Base.pointer(a::ConstCuDeviceArray, i::Integer) =
335-
pointer(a) + (i - 1) * Base.elsize(a)
336-
337-
Base.elsize(::Type{<:ConstCuDeviceArray{T}}) where {T} = sizeof(T)
338-
Base.size(g::ConstCuDeviceArray) = g.shape
339-
Base.length(g::ConstCuDeviceArray) = prod(g.shape)
340-
Base.IndexStyle(::Type{<:ConstCuDeviceArray}) = Base.IndexLinear()
341-
342-
Base.unsafe_convert(::Type{CUDA.DevicePtr{T,A}}, a::ConstCuDeviceArray{T,N,A}) where {T,A,N} = pointer(a)
343-
344-
@inline function Base.getindex(A::ConstCuDeviceArray{T}, index::Integer) where {T}
345-
@boundscheck checkbounds(A, index)
346-
align = Base.datatype_alignment(T)
347-
CUDA.unsafe_cached_load(pointer(A), index, Val(align))::T
348-
end
349-
350-
@inline function Base.unsafe_view(arr::ConstCuDeviceArray{T, 1, A}, I::Vararg{Base.ViewIndex,1}) where {T, A}
351-
ptr = pointer(arr) + (I[1].start-1)*sizeof(T)
352-
len = I[1].stop - I[1].start + 1
353-
return ConstCuDeviceArray{T,1,A}(len, ptr)
354-
end
324+
Adapt.adapt_storage(to::ConstAdaptor, a::CUDA.CuDeviceArray) = Base.Experimental.Const(a)

0 commit comments

Comments
 (0)