Skip to content

Commit dbc767b

Browse files
authored
Fixes for and tests using JET. (#1577)
1 parent 63e07c0 commit dbc767b

File tree

15 files changed

+80
-84
lines changed

15 files changed

+80
-84
lines changed

.buildkite/pipeline.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ steps:
8383
plugins:
8484
- JuliaCI/julia#v1:
8585
version: 1.6
86-
- JuliaCI/julia-test#v1: ~
86+
- JuliaCI/julia-test#v1:
87+
test_args: "--thorough"
8788
- JuliaCI/julia-coverage#v1:
8889
codecov: true
8990
dirs:

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2424
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2525
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2626
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
27-
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2827

2928
[compat]
3029
AbstractFFTs = "0.4, 0.5, 1.0"
@@ -40,5 +39,4 @@ RandomNumbers = "1.5.3"
4039
Reexport = "0.2, 1.0"
4140
Requires = "0.5, 1.0"
4241
SpecialFunctions = "1.3, 2"
43-
TimerOutputs = "0.5.9"
4442
julia = "1.6"

deps/Deps.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module Deps
22

3-
Base.Experimental.@compiler_options compile=min optimize=0 infer=false
3+
Base.Experimental.@compiler_options compile=min optimize=0
44

55
import ..CUDA
66
import ..LLVM

lib/cudadrv/CUDAdrv.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ include("libcuda_deprecated.jl")
4343

4444
# query the system driver version
4545
function get_version(driver)
46-
library_handle = Libdl.dlopen(driver)
46+
library_handle = Libdl.dlopen(driver; throw_error=true)
47+
library_handle = library_handle::Ptr{Nothing} # doesn't const-propthrow_error
4748
try
4849
function_handle = Libdl.dlsym(library_handle, "cuDriverGetVersion")
4950
version_ref = Ref{Cint}()
@@ -82,7 +83,8 @@ include("libcuda_deprecated.jl")
8283

8384
# if we're using an older driver; consider using forward compatibility
8485
function do_init(driver)
85-
library_handle = Libdl.dlopen(driver)
86+
library_handle = Libdl.dlopen(driver; throw_error=true)
87+
library_handle = library_handle::Ptr{Nothing} # doesn't const-propthrow_error
8688
try
8789
function_handle = Libdl.dlsym(library_handle, "cuInit")
8890
@check ccall(function_handle, CUresult, (UInt32,), 0)

lib/cudadrv/state.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ mutable struct TaskLocalState
5050
math_mode = something(default_math_mode[],
5151
Base.JLOptions().fast_math==1 ? FAST_MATH : DEFAULT_MATH)
5252
math_precision = something(default_math_precision[], :TensorFloat32)
53-
new(dev, ctx, Base.fill(nothing, ndevices()), math_mode, math_precision)
53+
new(dev, ctx, Union{Nothing,CuStream}[nothing for _ in 1:ndevices()],
54+
math_mode, math_precision)
5455
end
5556
end
5657

@@ -159,7 +160,7 @@ end
159160
@inline function context!(f::Function, ctx::CuContext; skip_destroyed::Bool=false)
160161
# @inline so that the kwarg method is inlined too and we can const-prop skip_destroyed
161162
if isvalid(ctx)
162-
old_ctx = context!(ctx)
163+
old_ctx = context!(ctx)::Union{CuContext,Nothing}
163164
try
164165
f()
165166
finally
@@ -187,7 +188,7 @@ end
187188

188189
const __device_contexts = LazyInitialized{Vector{Union{Nothing,CuContext}}}()
189190
device_contexts() = get!(__device_contexts) do
190-
[nothing for _ in 1:ndevices()]
191+
Union{Nothing,CuContext}[nothing for _ in 1:ndevices()]
191192
end
192193
function device_context(i::Int)
193194
contexts = device_contexts()
@@ -419,8 +420,8 @@ function PerDevice{T}() where {T}
419420
PerDevice{T}(ReentrantLock(), values)
420421
end
421422

422-
get_values(x::PerDevice) = get!(x.values) do
423-
Base.fill(nothing, ndevices())
423+
get_values(x::PerDevice{T}) where {T} = get!(x.values) do
424+
Union{Nothing,Tuple{CuContext,T}}[nothing for _ in 1:ndevices()]
424425
end
425426

426427
function Base.get(x::PerDevice, dev::CuDevice, val)
@@ -437,20 +438,20 @@ function Base.get(x::PerDevice, dev::CuDevice, val)
437438
end
438439
end
439440

440-
function Base.get!(constructor::F, x::PerDevice, dev::CuDevice) where {F}
441+
function Base.get!(constructor::F, x::PerDevice{T}, dev::CuDevice) where {F, T}
441442
y = get_values(x)
442443
id = deviceid(dev)+1
443444
ctx = device_context(id) # may be nothing
444445
@inbounds begin
445446
# test-lock-test
446-
if y[id] === nothing || y[id][1] !== ctx
447+
if y[id] === nothing || (y[id]::Tuple)[1] !== ctx
447448
Base.@lock x.lock begin
448-
if y[id] === nothing || y[id][1] !== ctx
449+
if y[id] === nothing || (y[id]::Tuple)[1] !== ctx
449450
y[id] = (context(), constructor())
450451
end
451452
end
452453
end
453-
y[id][2]
454+
(y[id]::Tuple)[2]
454455
end
455456
end
456457

lib/cudadrv/stream.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Create a CUDA stream.
1111
"""
1212
mutable struct CuStream
1313
handle::CUstream
14-
ctx::Union{CuContext,Nothing}
14+
ctx::CuContext
1515

1616
function CuStream(; flags::CUstream_flags=STREAM_DEFAULT,
1717
priority::Union{Nothing,Integer}=nothing)
@@ -29,11 +29,11 @@ mutable struct CuStream
2929
return obj
3030
end
3131

32-
global default_stream() = new(convert(CUstream, C_NULL), nothing)
32+
global default_stream() = new(convert(CUstream, C_NULL))
3333

34-
global legacy_stream() = new(convert(CUstream, 1), nothing)
34+
global legacy_stream() = new(convert(CUstream, 1))
3535

36-
global per_thread_stream() = new(convert(CUstream, 2), nothing)
36+
global per_thread_stream() = new(convert(CUstream, 2))
3737
end
3838

3939
"""

lib/utils/memoization.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ macro memoize(ex...)
4747
global_cache_eltyp = :(Union{Nothing,$rettyp})
4848
ex = quote
4949
cache = get!($(esc(global_cache))) do
50-
[nothing for _ in 1:Threads.nthreads()]
50+
$global_cache_eltyp[nothing for _ in 1:Threads.nthreads()]
5151
end
5252
cached_value = @inbounds cache[Threads.threadid()]
5353
if cached_value !== nothing
@@ -64,7 +64,7 @@ macro memoize(ex...)
6464
global_init = :(Union{Nothing,$rettyp}[nothing for _ in 1:$(esc(options[:maxlen]))])
6565
ex = quote
6666
cache = get!($(esc(global_cache))) do
67-
[$global_init for _ in 1:Threads.nthreads()]
67+
$global_cache_eltyp[$global_init for _ in 1:Threads.nthreads()]
6868
end
6969
local_cache = @inbounds begin
7070
tid = Threads.threadid()
@@ -86,7 +86,7 @@ macro memoize(ex...)
8686
global_init = :(Dict{$(key.typ),$rettyp}())
8787
ex = quote
8888
cache = get!($(esc(global_cache))) do
89-
[$global_init for _ in 1:Threads.nthreads()]
89+
$global_cache_eltyp[$global_init for _ in 1:Threads.nthreads()]
9090
end
9191
local_cache = @inbounds begin
9292
tid = Threads.threadid()

lib/utils/threading.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ end
2929
x.value[]
3030
end
3131

32-
@noinline function initialize!(x::LazyInitialized, constructor::F1, hook::F2) where {F1, F2}
32+
@noinline function initialize!(x::LazyInitialized{T}, constructor::F1, hook::F2) where {T, F1, F2}
3333
status = Threads.atomic_cas!(x.guard, 0, 1)
3434
if status == 0
3535
try
36-
x.value[] = constructor()
36+
x.value[] = constructor()::T
3737
x.guard[] = 2
3838
catch
3939
x.guard[] = 0

src/CUDA.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,6 @@ using ExprTools: splitdef, combinedef
2424
include("../deps/Deps.jl")
2525
using .Deps
2626

27-
# only use TimerOutputs on non latency-critical CI, in part because
28-
# @timeit_debug isn't truely zero-cost (KristofferC/TimerOutputs.jl#120)
29-
if getenv("CI", false) && !getenv("BENCHMARKS", false)
30-
using TimerOutputs
31-
const to = TimerOutput()
32-
33-
macro timeit_ci(args...)
34-
TimerOutputs.timer_expr(CUDA, false, :($CUDA.to), args...)
35-
end
36-
else
37-
macro timeit_ci(args...)
38-
esc(args[end])
39-
end
40-
end
41-
4227

4328
## source code includes
4429

src/array.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,14 @@ function unsafe_free!(xs::CuArray, stream::CuStream=stream())
6969
# this call should only have an effect once, because both the user and the GC can call it
7070
if xs.storage === nothing
7171
return
72-
elseif xs.storage.refcount[] < 0
72+
elseif (xs.storage::ArrayStorage).refcount[] < 0
7373
throw(ArgumentError("Cannot free an unmanaged buffer."))
7474
end
7575

76-
refcount = Threads.atomic_add!(xs.storage.refcount, -1)
76+
refcount = Threads.atomic_add!((xs.storage::ArrayStorage).refcount, -1)
7777
if refcount == 1
7878
context!(context(xs); skip_destroyed=true) do
79-
free(xs.storage.buffer; stream)
79+
free((xs.storage::ArrayStorage).buffer; stream)
8080
end
8181
end
8282

@@ -236,7 +236,7 @@ Base.sizeof(x::CuArray) = Base.elsize(x) * length(x)
236236

237237
function context(A::CuArray)
238238
A.storage === nothing && throw(UndefRefError())
239-
return A.storage.buffer.ctx
239+
return (A.storage::ArrayStorage).buffer.ctx
240240
end
241241

242242
function device(A::CuArray)
@@ -318,8 +318,10 @@ Base.convert(::Type{T}, x::T) where T <: CuArray = x
318318

319319
Base.unsafe_convert(::Type{Ptr{T}}, x::CuArray{T}) where {T} =
320320
throw(ArgumentError("cannot take the CPU address of a $(typeof(x))"))
321-
Base.unsafe_convert(::Type{CuPtr{T}}, x::CuArray{T}) where {T} =
322-
convert(CuPtr{T}, x.storage.buffer) + x.offset*Base.elsize(x)
321+
function Base.unsafe_convert(::Type{CuPtr{T}}, x::CuArray{T}) where {T}
322+
x.storage === nothing && throw(UndefRefError())
323+
convert(CuPtr{T}, (x.storage::ArrayStorage).buffer) + x.offset*Base.elsize(x)
324+
end
323325

324326

325327
## interop with device arrays

0 commit comments

Comments
 (0)