Skip to content

Commit 8c833a8

Browse files
committed
Use stacked method tables
1 parent cacd3c2 commit 8c833a8

File tree

3 files changed

+35
-1
lines changed

3 files changed

+35
-1
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,7 @@ oneAPI_Support_jll = "0.8"
4949

5050
[extras]
5151
libigc_jll = "94295238-5935-5bd7-bb0f-b00942e9bdd5"
52+
53+
[sources]
54+
GPUCompiler = {url="https://github.com/JuliaGPU/GPUCompiler.jl", rev="vc/mtv"}
55+
SPIRVIntrinsics = {url="https://github/JuliaGPU/OpenCL.jl", rev="vc/mtv", subdir="lib/intrinsics"}

src/compiler/compilation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ const oneAPICompilerJob = CompilerJob{SPIRVCompilerTarget,oneAPICompilerParams}
66

77
GPUCompiler.runtime_module(::oneAPICompilerJob) = oneAPI
88

9-
GPUCompiler.method_table(::oneAPICompilerJob) = method_table
9+
GPUCompiler.method_table_view(job::oneAPICompilerJob) = GPUCompiler.StackedMethodTable(job.world, method_table, spirv_method_table)
1010

1111
# filter out OpenCL built-ins
1212
# TODO: eagerly lower these using the translator API

src/oneAPI.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,40 @@ include("../lib/level-zero/oneL0.jl")
2525
using .oneL0
2626
functional() = oneL0.functional[]
2727

28+
29+
## device overrides
30+
31+
# local method table for device functions
32+
Base.Experimental.@MethodTable(method_table)
33+
34+
macro device_override(ex)
35+
esc(quote
36+
Base.Experimental.@overlay($method_table, $ex)
37+
end)
38+
end
39+
40+
macro device_function(ex)
41+
ex = macroexpand(__module__, ex)
42+
def = ExprTools.splitdef(ex)
43+
44+
# generate a function that errors
45+
def[:body] = quote
46+
error("This function is not intended for use on the CPU")
47+
end
48+
49+
esc(quote
50+
$(ExprTools.combinedef(def))
51+
@device_override $ex
52+
end)
53+
end
54+
2855
# device functionality
2956
import SPIRVIntrinsics
3057
SPIRVIntrinsics.@import_all
3158
SPIRVIntrinsics.@reexport_public
59+
60+
const spirv_method_table = SPIRVIntrinsics.method_table
61+
3262
include("device/runtime.jl")
3363
include("device/array.jl")
3464
include("device/quirks.jl")

0 commit comments

Comments
 (0)