
Commit f5ceb33

vchuravy, leios, and jpsamaroo authored
Excise Cassette (#288)
Co-authored-by: James Schloss <jrs.schloss@gmail.com>
Co-authored-by: Julian Samaroo <jpsamaroo@jpsamaroo.me>
1 parent fcc4f57 commit f5ceb33

23 files changed: +82, -466 lines

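In short, this commit removes the Cassette-based compiler pass from KernelAbstractions and its backends: device-side behavior that was previously swapped in by overdubbing calls inside a `CUDACtx` is now registered directly in CUDA.jl's method overlay table via `CUDA.@device_override`, a mechanism that requires Julia 1.7's overlay method tables (hence the version bumps below). A minimal before/after sketch, taken from the `__synchronize` hook changed in this diff:

```julia
# Before (removed here): device behavior was injected by overdubbing
# calls inside a Cassette context when tracing the kernel.
@inline function Cassette.overdub(::CUDACtx, ::typeof(__synchronize))
    CUDA.sync_threads()
end

# After: the override lives in CUDA.jl's overlay method table, so the
# kernel compiles as an ordinary function call with no tracing pass.
@device_override @inline function __synchronize()
    CUDA.sync_threads()
end
```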

.buildkite/pipeline.yml

Lines changed: 2 additions & 22 deletions
@@ -19,26 +19,6 @@ steps:
       KERNELABSTRACTIONS_TEST_BACKEND: "CUDA"
     timeout_in_minutes: 60
 
-  - label: "CUDA Julia 1.6"
-    plugins:
-      - JuliaCI/julia#v1:
-          version: "1.6"
-      - JuliaCI/julia-coverage#v1:
-          codecov: true
-          dirs:
-            - src
-            - lib
-    commands:
-      - julia .ci/develop.jl
-      - julia .ci/test.jl
-    agents:
-      queue: "juliagpu"
-      cuda: "*"
-    env:
-      JULIA_CUDA_USE_BINARYBUILDER: "true"
-      KERNELABSTRACTIONS_TEST_BACKEND: "CUDA"
-    timeout_in_minutes: 60
-
   - label: "CUDA Julia 1.7"
     plugins:
       - JuliaCI/julia#v1:
@@ -79,10 +59,10 @@ steps:
       KERNELABSTRACTIONS_TEST_BACKEND: "CUDA"
     timeout_in_minutes: 60
 
-  - label: "ROCm Julia 1.6"
+  - label: "ROCm Julia 1.7"
     plugins:
       - JuliaCI/julia#v1:
-          version: "1.6"
+          version: "1.7"
       - JuliaCI/julia-coverage#v1:
           codecov: true
           dirs:

.ci/test.jl

Lines changed: 3 additions & 1 deletion
@@ -12,6 +12,8 @@ end
 if !CI || BACKEND == "CUDA"
     push!(pkgs, "CUDAKernels")
 end
-# push!(pkgs, "KernelGradients")
+if !CI || haskey(ENV, "TEST_KERNELGRADIENTS")
+    push!(pkgs, "KernelGradients")
+end
 
 Pkg.test(pkgs; coverage = true)
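The test driver now includes KernelGradients by default when run locally, and makes it opt-in on CI via the `TEST_KERNELGRADIENTS` environment variable. A standalone sketch of the gate (assuming `CI` is derived from the conventional `CI` environment variable, which is not shown in this hunk):

```julia
# Hypothetical standalone version of the gating logic above.
CI = parse(Bool, get(ENV, "CI", "false"))   # assumption: how the script defines CI
pkgs = String["KernelAbstractions"]
if !CI || haskey(ENV, "TEST_KERNELGRADIENTS")
    push!(pkgs, "KernelGradients")          # opt in on CI: TEST_KERNELGRADIENTS=1
end
@show pkgs   # locally (no CI env var): ["KernelAbstractions", "KernelGradients"]
```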

.github/workflows/CompatHelper.yml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ jobs:
     steps:
       - uses: julia-actions/setup-julia@latest
         with:
-          version: 1.3
+          version: 1.6
       - name: Pkg.add("CompatHelper")
         run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
       - name: CompatHelper.main()

.github/workflows/ci-ka-cuda.yml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.6'
+          - '1.7'
          - '1' # automatically expands to the latest stable 1.x release of Julia.
          - 'nightly'
        os:

.github/workflows/ci-ka-rocm.yml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.6'
+          - '1.7'
        os:
          - ubuntu-latest
          - macOS-latest

.github/workflows/ci-ka.yml

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.6'
+          - '1.7'
          - '1' # automatically expands to the latest stable 1.x release of Julia.
          - 'nightly'
        os:
@@ -76,7 +76,7 @@ jobs:
       - uses: actions/checkout@v2
       - uses: julia-actions/setup-julia@v1
         with:
-          version: 'nightly'
+          version: '1'
       - run: julia .ci/add-general-registry.jl
        env:
          JULIA_PKG_SERVER: ""

Project.toml

Lines changed: 2 additions & 6 deletions
@@ -1,21 +1,17 @@
 name = "KernelAbstractions"
 uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 authors = ["Valentin Churavy <v.churavy@gmail.com>"]
-version = "0.7.2"
+version = "0.8.0-dev"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 
 [compat]
 Adapt = "0.4, 1.0, 2.0, 3.0"
-Cassette = "0.3.3"
 MacroTools = "0.5"
-SpecialFunctions = "0.10, 1.0, 2.0"
 StaticArrays = "0.12, 1.0"
-julia = "1.6"
+julia = "1.7"

lib/CUDAKernels/Project.toml

Lines changed: 4 additions & 8 deletions
@@ -1,21 +1,17 @@
 name = "CUDAKernels"
 uuid = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
 authors = ["Valentin Churavy <v.churavy@gmail.com>"]
-version = "0.3.3"
+version = "0.4.0"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
-Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
-SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 
 [compat]
 Adapt = "3.0"
-CUDA = "3.5"
-Cassette = "0.3.3"
-KernelAbstractions = "0.7"
-SpecialFunctions = "0.10, 1.0, 2.0"
+CUDA = "3.8.2"
+KernelAbstractions = "0.8"
 StaticArrays = "0.12, 1.0"
-julia = "1.6"
+julia = "1.7"

lib/CUDAKernels/src/CUDAKernels.jl

Lines changed: 14 additions & 81 deletions
@@ -1,10 +1,8 @@
 module CUDAKernels
 
 import CUDA
-import SpecialFunctions
 import StaticArrays
 import StaticArrays: MArray
-import Cassette
 import Adapt
 import KernelAbstractions
 
@@ -191,7 +189,7 @@ function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=Event(
     ndrange, workgroupsize, iterspace, dynamic = launch_config(obj, ndrange, workgroupsize)
     # this might not be the final context, since we may tune the workgroupsize
     ctx = mkcontext(obj, ndrange, iterspace)
-    kernel = CUDA.@cuda launch=false name=String(nameof(obj.f)) Cassette.overdub(CUDACTX, obj.f, ctx, args...)
+    kernel = CUDA.@cuda launch=false obj.f(ctx, args...)
 
     # figure out the optimal workgroupsize automatically
     if KernelAbstractions.workgroupsize(obj) <: DynamicSize && workgroupsize === nothing
@@ -220,52 +218,49 @@ function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=Event(
 
     # Launch kernel
     event = CUDA.CuEvent(CUDA.EVENT_DISABLE_TIMING)
-    kernel(CUDACTX, obj.f, ctx, args...; threads=threads, blocks=nblocks, stream=stream)
+    kernel(ctx, args...; threads=threads, blocks=nblocks, stream=stream)
 
     CUDA.record(event, stream)
     return CudaEvent(event)
 end
 
-Cassette.@context CUDACtx
+import CUDA: @device_override
 
 import KernelAbstractions: CompilerMetadata, CompilerPass, DynamicCheck, LinearIndices
 import KernelAbstractions: __index_Local_Linear, __index_Group_Linear, __index_Global_Linear, __index_Local_Cartesian, __index_Group_Cartesian, __index_Global_Cartesian, __validindex, __print
 import KernelAbstractions: mkcontext, expand, __iterspace, __ndrange, __dynamic_checkbounds
 
-const CUDACTX = Cassette.disablehooks(CUDACtx(pass = CompilerPass))
-KernelAbstractions.cassette(::Kernel{CUDADevice}) = CUDACTX
-
 function mkcontext(kernel::Kernel{CUDADevice}, _ndrange, iterspace)
     CompilerMetadata{KernelAbstractions.ndrange(kernel), DynamicCheck}(_ndrange, iterspace)
 end
 
-@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Local_Linear), ctx)
+@device_override @inline function __index_Local_Linear(ctx)
     return CUDA.threadIdx().x
 end
 
-@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Group_Linear), ctx)
+@device_override @inline function __index_Group_Linear(ctx)
     return CUDA.blockIdx().x
 end
 
-@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Global_Linear), ctx)
+@device_override @inline function __index_Global_Linear(ctx)
     I = @inbounds expand(__iterspace(ctx), CUDA.blockIdx().x, CUDA.threadIdx().x)
     # TODO: This is unfortunate, can we get the linear index cheaper
     @inbounds LinearIndices(__ndrange(ctx))[I]
 end
 
-@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Local_Cartesian), ctx)
+@device_override @inline function __index_Local_Cartesian(ctx)
     @inbounds workitems(__iterspace(ctx))[CUDA.threadIdx().x]
 end
 
-@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Group_Cartesian), ctx)
+@device_override @inline function __index_Group_Cartesian(ctx)
     @inbounds blocks(__iterspace(ctx))[CUDA.blockIdx().x]
 end
 
-@inline function Cassette.overdub(::CUDACtx, ::typeof(__index_Global_Cartesian), ctx)
+@device_override @inline function __index_Global_Cartesian(ctx)
     return @inbounds expand(__iterspace(ctx), CUDA.blockIdx().x, CUDA.threadIdx().x)
 end
 
-@inline function Cassette.overdub(::CUDACtx, ::typeof(__validindex), ctx)
+@device_override @inline function __validindex(ctx)
     if __dynamic_checkbounds(ctx)
         I = @inbounds expand(__iterspace(ctx), CUDA.blockIdx().x, CUDA.threadIdx().x)
         return I in __ndrange(ctx)
@@ -276,43 +271,6 @@ end
 
 import KernelAbstractions: groupsize, __groupsize, __workitems_iterspace, add_float_contract, sub_float_contract, mul_float_contract
 
-KernelAbstractions.generate_overdubs(@__MODULE__, CUDACtx)
-
-###
-# CUDA specific method rewrites
-###
-
-@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Float64) = ^(x, y)
-@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Float32) = ^(x, y)
-@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float64, y::Int32) = ^(x, y)
-@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Float32, y::Int32) = ^(x, y)
-@inline Cassette.overdub(::CUDACtx, ::typeof(^), x::Union{Float32, Float64}, y::Int64) = ^(x, y)
-
-# libdevice.jl
-const cudafuns = (:cos, :cospi, :sin, :sinpi, :tan,
-                  :acos, :asin, :atan,
-                  :cosh, :sinh, :tanh,
-                  :acosh, :asinh, :atanh,
-                  :log, :log10, :log1p, :log2,
-                  :exp, :exp2, :exp10, :expm1, :ldexp,
-                  # :isfinite, :isinf, :isnan, :signbit,
-                  :abs,
-                  :sqrt, :cbrt,
-                  :ceil, :floor,)
-for f in cudafuns
-    @eval function Cassette.overdub(ctx::CUDACtx, ::typeof(Base.$f), x::Union{Float32, Float64})
-        @Base._inline_meta
-        return Base.$f(x)
-    end
-end
-
-@inline Cassette.overdub(::CUDACtx, ::typeof(sincos), x::Union{Float32, Float64}) = (Base.sin(x), Base.cos(x))
-@inline Cassette.overdub(::CUDACtx, ::typeof(exp), x::Union{ComplexF32, ComplexF64}) = Base.exp(x)
-
-@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.gamma), x::Union{Float32, Float64}) = CUDA.tgamma(x)
-@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64}) = SpecialFunctions.erf(x)
-@inline Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64}) = SpecialFunctions.erfc(x)
-
 @static if Base.isbindingresolved(CUDA, :emit_shmem) && Base.isdefined(CUDA, :emit_shmem)
     const emit_shmem = CUDA.emit_shmem
 else
@@ -325,7 +283,7 @@ import KernelAbstractions: ConstAdaptor, SharedMemory, Scratchpad, __synchronize
 # GPU implementation of shared memory
 ###
 
-@inline function Cassette.overdub(::CUDACtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
+@device_override @inline function SharedMemory(::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
     ptr = emit_shmem(T, Val(prod(Dims)))
     CUDA.CuDeviceArray(Dims, ptr)
 end
@@ -335,15 +293,15 @@ end
 # - private memory for each workitem
 ###
 
-@inline function Cassette.overdub(::CUDACtx, ::typeof(Scratchpad), ctx, ::Type{T}, ::Val{Dims}) where {T, Dims}
+@device_override @inline function Scratchpad(ctx, ::Type{T}, ::Val{Dims}) where {T, Dims}
     MArray{__size(Dims), T}(undef)
 end
 
-@inline function Cassette.overdub(::CUDACtx, ::typeof(__synchronize))
+@device_override @inline function __synchronize()
    CUDA.sync_threads()
 end
 
-@inline function Cassette.overdub(::CUDACtx, ::typeof(__print), args...)
+@device_override @inline function __print(args...)
    CUDA._cuprint(args...)
 end
 
@@ -356,29 +314,4 @@ Adapt.adapt_storage(to::ConstAdaptor, a::CUDA.CuDeviceArray) = Base.Experimental
 
 # Argument conversion
 KernelAbstractions.argconvert(k::Kernel{CUDADevice}, arg) = CUDA.cudaconvert(arg)
 
-# Cassette.jl#195
-# Device intrinsics are inferred in a different World (1.6) or using MethodOverlay tables (1.7)
-# Cassette sees neither of them and thus overdubbing them fails.
-@inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.arrayref), args...)
-    CUDA.arrayref(args...)
-end
-@inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.arrayset), args...)
-    CUDA.arrayset(args...)
-end
-@inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.const_arrayref), args...)
-    CUDA.const_arrayref(args...)
-end
-@inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.logb), args...)
-    CUDA.logb(args...)
-end
-# @inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.tgamma), args...)
-#     CUDA.tgamma(args...)
-# end
-@inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.compute_capability), args...)
-    CUDA.compute_capability(args...)
-end
-@inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.ptx_isa_version), args...)
-    CUDA.ptx_isa_version(args...)
-end
-
 end
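With the overrides above, launching a KernelAbstractions kernel on CUDA no longer involves Cassette at all: `CUDA.@cuda launch=false obj.f(ctx, args...)` compiles the kernel function directly, and the `@device_override` methods are picked up through the overlay table during GPU compilation. A minimal usage sketch, assuming a working CUDA setup and the versions in this commit (KernelAbstractions 0.8, CUDAKernels 0.4); the `scale!` kernel is illustrative, not part of the diff:

```julia
using KernelAbstractions, CUDAKernels, CUDA

# @index(Global, Linear) lowers to __index_Global_Linear(ctx), which now
# resolves via the @device_override overlay methods instead of overdubbing.
@kernel function scale!(a, s)
    i = @index(Global, Linear)
    @inbounds a[i] *= s
end

a = CUDA.ones(Float32, 1024)
kernel = scale!(CUDADevice(), 256)           # instantiate for the CUDA backend
event = kernel(a, 2f0; ndrange = length(a))  # launches asynchronously
wait(event)                                  # block until the kernel completes
```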

lib/KernelGradients/Project.toml

Lines changed: 2 additions & 7 deletions
@@ -4,15 +4,10 @@ authors = ["Valentin Churavy <v.churavy@gmail.com>"]
 version = "0.1.0"
 
 [deps]
-Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
-Requires = "ae029012-a4dd-5104-9daa-d747884805df"
-
 
 [compat]
-Cassette = "0.3"
-KernelAbstractions = "0.7"
-Requires = "1.1"
 Enzyme = "0.7"
-julia = "1.6"
+KernelAbstractions = "0.8"
+julia = "1.7"
