Skip to content

Switch to the LLVM SPIR-V back-end. #285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,23 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
SPIRV_LLVM_Translator_jll = "4a5d46fc-d8cf-5151-a261-86b458210efb"
SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c"
SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
Adapt = "4"
GPUArrays = "11.2.1"
GPUCompiler = "0.27, 1"
GPUCompiler = "1.2"
KernelAbstractions = "0.9.2"
LLVM = "9.1"
LinearAlgebra = "1"
OpenCL_jll = "=2024.10.24"
Printf = "1"
Random = "1"
Reexport = "1"
SPIRVIntrinsics = "0.2"
SPIRV_LLVM_Translator_jll = "20"
SPIRVIntrinsics = "0.3"
SPIRV_LLVM_Backend_jll = "20"
SPIRV_Tools_jll = "2024.4"
StaticArrays = "1"
julia = "1.10"
2 changes: 1 addition & 1 deletion lib/intrinsics/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SPIRVIntrinsics"
uuid = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
authors = ["Tim Besard <tim.besard@gmail.com>"]
version = "0.2.1"
version = "0.3.0"

[deps]
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expand Down
2 changes: 1 addition & 1 deletion lib/intrinsics/src/atomic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
const atomic_integer_types = [UInt32, Int32]
# TODO: 64-bit atomics with ZE_DEVICE_MODULE_FLAG_INT64_ATOMICS
# TODO: additional floating-point atomics with ZE_extension_float_atomics
const atomic_memory_types = [AS.Local, AS.Global]
const atomic_memory_types = [AS.Workgroup, AS.CrossWorkgroup]


# generically typed
Expand Down
4 changes: 2 additions & 2 deletions lib/intrinsics/src/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ for gentype in generic_types
cosval = Ref{$gentype}()
sinval = GC.@preserve cosval begin
ptr = Base.unsafe_convert(Ptr{$gentype}, cosval)
llvm_ptr = reinterpret(LLVMPtr{$gentype, AS.Private}, ptr)
@builtin_ccall("sincos", $gentype, ($gentype, LLVMPtr{$gentype, AS.Private}), x, llvm_ptr)
llvm_ptr = reinterpret(LLVMPtr{$gentype, AS.Function}, ptr)
@builtin_ccall("sincos", $gentype, ($gentype, LLVMPtr{$gentype, AS.Function}), x, llvm_ptr)
end
return sinval, cosval[]
end
Expand Down
6 changes: 3 additions & 3 deletions lib/intrinsics/src/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
Context() do ctx
# XXX: as long as LLVMPtr is emitted as i8*, it doesn't make sense to type the GV
eltyp = convert(LLVMType, LLVM.Int8Type())
T_ptr = convert(LLVMType, LLVMPtr{T,AS.Local})
T_ptr = convert(LLVMType, LLVMPtr{T,AS.Workgroup})

# create a function
llvm_f, _ = create_function(T_ptr)

# create the global variable
mod = LLVM.parent(llvm_f)
gv_typ = LLVM.ArrayType(eltyp, len * sizeof(T))
gv = GlobalVariable(mod, gv_typ, "local_memory", AS.Local)
gv = GlobalVariable(mod, gv_typ, "local_memory", AS.Workgroup)
if len > 0
linkage!(gv, LLVM.API.LLVMInternalLinkage)
initializer!(gv, null(gv_typ))
Expand All @@ -33,6 +33,6 @@
ret!(builder, untyped_ptr)
end

call_function(llvm_f, LLVMPtr{T,AS.Local})
call_function(llvm_f, LLVMPtr{T,AS.Workgroup})
end
end
19 changes: 11 additions & 8 deletions lib/intrinsics/src/pointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@ export AS

# Named integer address-space numbers. Elsewhere in this view they are used as the
# address-space parameter of `LLVMPtr{T, AS.X}` and as the `addrspace` of LLVM globals.
#
# NOTE(review): two naming schemes appear back to back here, and `Private`, `Generic`,
# `Input` and `Output` are each assigned twice (`Private` as 0 then 10, `Input` as 5
# then 7, `Output` as 6 then 8). This looks like the removed and added sides of a diff
# shown together rather than one coherent module -- confirm which group is live before
# relying on any of these values.
module AS

# First group: Private/Global/Constant/Local naming.
const Private = 0
const Global = 1
const Constant = 2
const Local = 3
const Generic = 4
const Input = 5
const Output = 6
const Count = 7
# Second group. At the call sites visible in this view, Global is replaced by
# CrossWorkgroup, Local by Workgroup, Constant by UniformConstant and Private by
# Function; each of those pairs shares the same number (1, 3, 2 and 0 respectively).
# Presumably these follow SPIR-V storage-class naming -- TODO confirm.
const Function = 0
const CrossWorkgroup = 1
const UniformConstant = 2
const Workgroup = 3
const Generic = 4
const DeviceOnlyINTEL = 5 # XXX: should be CrossWorkgroup
const HostOnlyINTEL = 6 # when USM is not supported
const Input = 7
const Output = 8
const CodeSectionINTEL = 9
const Private = 10

end
4 changes: 2 additions & 2 deletions lib/intrinsics/src/printf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ end
Context() do ctx
T_void = LLVM.VoidType()
T_int32 = LLVM.Int32Type()
T_pint8 = LLVM.PointerType(LLVM.Int8Type(), AS.Constant)
T_pint8 = LLVM.PointerType(LLVM.Int8Type(), AS.UniformConstant)

# create functions
param_types = LLVMType[convert(LLVMType, typ) for typ in arg_types]
Expand Down Expand Up @@ -80,7 +80,7 @@ end
push!(actual_args, actual_arg)
end

str = globalstring_ptr!(builder, String(fmt); addrspace=AS.Constant)
str = globalstring_ptr!(builder, String(fmt); addrspace=AS.UniformConstant)

# invoke printf and return
printf_typ = LLVM.FunctionType(T_int32, [T_pint8]; vararg=true)
Expand Down
10 changes: 5 additions & 5 deletions lib/intrinsics/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,22 @@ macro builtin_ccall(name, ret, argtypes, args...)
elt, as = T.parameters

# mangle address space
ASstr = if as == AS.Global
ASstr = if as == AS.CrossWorkgroup
"CLglobal"
#elseif as == AS.Global_device
# "CLdevice"
#elseif as == AS.Global_host
# "CLhost"
elseif as == AS.Local
elseif as == AS.Workgroup
"CLlocal"
elseif as == AS.Constant
elseif as == AS.UniformConstant
"CLconstant"
elseif as == AS.Private
elseif as == AS.Function
"CLprivate"
elseif as == AS.Generic
"CLgeneric"
else
error("Unknown address space $AS")
error("Unknown address space $as")
end

# encode as vendor qualifier
Expand Down
80 changes: 58 additions & 22 deletions lib/intrinsics/src/work_item.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,64 @@
# Work-Item Functions

export get_work_dim,
get_global_size, get_global_id,
get_local_size, get_enqueued_local_size, get_local_id,
get_num_groups, get_group_id,
get_global_offset,
get_global_linear_id, get_local_linear_id
#
# https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables

# NOTE: these functions now unsafely truncate to Int to avoid top bit checks.
# we should probably use range metadata instead.

# Work-item query wrappers implemented through `@builtin_ccall`.
#
# Each wrapper calls the builtin named by the string argument and converts the result
# with `% Int` (a wrapping conversion, not a checked one). The per-dimension wrappers
# take a dimension index `dimindx` that defaults to `1u32` and pass `dimindx - 1u32`
# to the builtin, so callers count dimensions from 1. The id and offset wrappers also
# add 1 to the returned value; the size and count wrappers return it unchanged.
#
# NOTE(review): the same function names are generated again by the `for` loops further
# down in this view; these lines look like the removed side of a diff -- confirm.

# Number of dimensions in use; takes no dimension argument and adds no offset.
@device_function get_work_dim() = @builtin_ccall("get_work_dim", UInt32, ()) % Int

# Global size (no offset) and global id (+ 1) along `dimindx`.
@device_function get_global_size(dimindx::Integer = 1u32) = @builtin_ccall("get_global_size", UInt, (UInt32,), dimindx - 1u32) % Int
@device_function get_global_id(dimindx::Integer = 1u32) = @builtin_ccall("get_global_id", UInt, (UInt32,), dimindx - 1u32) % Int + 1

# Local size, enqueued local size (both no offset) and local id (+ 1) along `dimindx`.
@device_function get_local_size(dimindx::Integer = 1u32) = @builtin_ccall("get_local_size", UInt, (UInt32,), dimindx - 1u32) % Int
@device_function get_enqueued_local_size(dimindx::Integer = 1u32) = @builtin_ccall("get_enqueued_local_size", UInt, (UInt32,), dimindx - 1u32) % Int
@device_function get_local_id(dimindx::Integer = 1u32) = @builtin_ccall("get_local_id", UInt, (UInt32,), dimindx - 1u32) % Int + 1

# Number of groups (no offset) and group id (+ 1) along `dimindx`.
@device_function get_num_groups(dimindx::Integer = 1u32) = @builtin_ccall("get_num_groups", UInt, (UInt32,), dimindx - 1u32) % Int
@device_function get_group_id(dimindx::Integer = 1u32) = @builtin_ccall("get_group_id", UInt, (UInt32,), dimindx - 1u32) % Int + 1

# Global offset along `dimindx`; note that 1 is added here as well.
@device_function get_global_offset(dimindx::Integer = 1u32) = @builtin_ccall("get_global_offset", UInt, (UInt32,), dimindx - 1u32) % Int + 1
# 1D values
#
# Generates and exports one zero-argument `@device_function` per table entry. Each
# generated function loads a scalar from an external LLVM global named
# `@__spirv_<spirv_name>` in address space `AS.Input`, converts it with `% Int`
# (wrapping, unchecked) and adds the entry's `offset`.
#
# Table layout: julia_name => (spirv_name, julia_type, offset)
#   julia_name  name of the generated and exported function
#   spirv_name  suffix of the `@__spirv_...` global that is read
#   julia_type  Julia type of the loaded value; its size fixes the LLVM integer width
#   offset      constant added to the result: 1 for the entries listed under
#               "indices", 0 for those under "sizes" (presumably to present indices
#               as 1-based -- TODO confirm)
for (julia_name, (spirv_name, julia_type, offset)) in [
# indices
:get_global_linear_id => (:BuiltInGlobalLinearId, Csize_t, 1),
:get_local_linear_id => (:BuiltInLocalInvocationIndex, Csize_t, 1),
:get_sub_group_id => (:BuiltInSubgroupId, UInt32, 1),
:get_sub_group_local_id => (:BuiltInSubgroupLocalInvocationId, UInt32, 1),
# sizes
:get_work_dim => (:BuiltInWorkDim, UInt32, 0),
:get_sub_group_size => (:BuiltInSubgroupSize, UInt32, 0),
:get_max_sub_group_size => (:BuiltInSubgroupMaxSize, UInt32, 0),
:get_num_sub_groups => (:BuiltInNumSubgroups, UInt32, 0),
:get_enqueued_num_sub_groups => (:BuiltInNumEnqueuedSubgroups, UInt32, 0)]
# Name of the LLVM global, including the leading `@` that LLVM IR requires.
gvar_name = Symbol("@__spirv_$(spirv_name)")
# Bit width of `julia_type`, used to spell the LLVM integer type `i<width>`.
width = sizeof(julia_type) * 8
# The IR text is built once per table entry, when this loop runs, and spliced into the
# generated method as a constant (IR string, entry-function name) pair for llvmcall.
@eval begin
export $julia_name
@device_function $julia_name() =
Base.llvmcall(
$("""$gvar_name = external addrspace($(AS.Input)) global i$(width)
define i$(width) @entry() #0 {
%val = load i$(width), i$(width) addrspace($(AS.Input))* $gvar_name
ret i$(width) %val
}
attributes #0 = { alwaysinline }
""", "entry"), $julia_type, Tuple{}) % Int + $offset
end
end

# Linear (flattened) id wrappers implemented through `@builtin_ccall`: each calls the
# builtin named by the string with no arguments, converts the `UInt` result with
# `% Int` (wrapping, unchecked) and adds 1.
# NOTE(review): both names are also generated by the "1D values" loop earlier in this
# view; these lines look like the removed side of a diff -- confirm.
@device_function get_global_linear_id() = @builtin_ccall("get_global_linear_id", UInt, ()) % Int + 1
@device_function get_local_linear_id() = @builtin_ccall("get_local_linear_id", UInt, ()) % Int + 1
# 3D values
#
# Generates and exports one `@device_function` per table entry. Each generated function
# takes a dimension index `dimindx` (default `1u32`), loads a 3-element integer vector
# from an external LLVM global named `@__spirv_<spirv_name>` in address space
# `AS.Input`, extracts element `dimindx - 1`, converts it with `% Int` (wrapping,
# unchecked) and adds the entry's `offset`.
#
# Table layout: julia_name => (spirv_name, offset)
#   julia_name  name of the generated and exported function
#   spirv_name  suffix of the `@__spirv_...` global that is read
#   offset      constant added to the result: 1 for the entries listed under
#               "indices", 0 for those under "sizes"
#
# NOTE(review): no range check on `dimindx` is visible in this block; callers are
# presumably expected to pass 1, 2 or 3 -- confirm.
for (julia_name, (spirv_name, offset)) in [
# indices
:get_global_id => (:BuiltInGlobalInvocationId, 1),
:get_global_offset => (:BuiltInGlobalOffset, 1),
:get_local_id => (:BuiltInLocalInvocationId, 1),
:get_group_id => (:BuiltInWorkgroupId, 1),
# sizes
:get_global_size => (:BuiltInGlobalSize, 0),
:get_local_size => (:BuiltInWorkgroupSize, 0),
:get_enqueued_local_size => (:BuiltInEnqueuedWorkgroupSize, 0),
:get_num_groups => (:BuiltInNumWorkgroups, 0)]
# Name of the LLVM global, including the leading `@` that LLVM IR requires.
gvar_name = Symbol("@__spirv_$(spirv_name)")
# Element width follows the host `Int` size so that the IR's `i<width>` matches the
# `UInt` return and argument types declared in the llvmcall below.
width = Int === Int64 ? 64 : 32
# The IR text is built once per table entry, when this loop runs, and spliced into the
# generated method as a constant (IR string, entry-function name) pair for llvmcall.
# The single llvmcall argument is the dimension index shifted down by one.
@eval begin
export $julia_name
@device_function $julia_name(dimindx::Integer=1u32) =
Base.llvmcall(
$("""$gvar_name = external addrspace($(AS.Input)) global <3 x i$(width)>
define i$(width) @entry(i$(width) %idx) #0 {
%val = load <3 x i$(width)>, <3 x i$(width)> addrspace($(AS.Input))* $gvar_name
%element = extractelement <3 x i$(width)> %val, i$(width) %idx
ret i$(width) %element
}
attributes #0 = { alwaysinline }
""", "entry"), UInt, Tuple{UInt}, UInt(dimindx - 1u32)) % Int + $offset
end
end
2 changes: 1 addition & 1 deletion src/OpenCL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module OpenCL

using GPUCompiler
using LLVM, LLVM.Interop
using SPIRV_LLVM_Translator_jll
using SPIRV_LLVM_Backend_jll, SPIRV_Tools_jll
using Adapt
using Reexport
using GPUArrays
Expand Down
7 changes: 4 additions & 3 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,10 @@ end

## interop with GPU arrays

# Converts a `CLArray{T, N}` into a `CLDeviceArray{T, N, <address space>}` built from
# three values: `size(a)`, `pointer(a)` reinterpreted as an `LLVMPtr` in that address
# space, and `a.maxsize - a.offset * Base.elsize(a)` (presumably the number of bytes
# remaining after the offset -- TODO confirm).
#
# NOTE(review): the `function` header and the opening lines of the `return` expression
# appear twice below, first with `AS.Global` and then with `AS.CrossWorkgroup`, sharing
# one tail and one `end`. This looks like the removed and added sides of a diff shown
# together, not a single definition -- confirm which variant is live.
function Base.unsafe_convert(::Type{CLDeviceArray{T, N, AS.Global}}, a::CLArray{T, N}) where {T, N}
return CLDeviceArray{T, N, AS.Global}(
size(a), reinterpret(LLVMPtr{T, AS.Global}, pointer(a)),
function Base.unsafe_convert(::Type{CLDeviceArray{T, N, AS.CrossWorkgroup}},
a::CLArray{T, N}) where {T, N}
return CLDeviceArray{T, N, AS.CrossWorkgroup}(
size(a), reinterpret(LLVMPtr{T, AS.CrossWorkgroup}, pointer(a)),
a.maxsize - a.offset * Base.elsize(a)
)
end
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/compilation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ end
supports_fp64 = "cl_khr_fp64" in dev.extensions

# create GPUCompiler objects
target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, kwargs...)
target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, validate=true, kwargs...)
params = OpenCLCompilerParams()
CompilerConfig(target, params; kernel, name, always_inline)
end
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function Adapt.adapt_storage(to::KernelAdaptor, buf::cl.AbstractMemory)
end
# Adapts a `CLArray` under a `KernelAdaptor`: appends the array's underlying memory
# object (`arr.data[].mem`) to `to.indirect_memory`, then returns the array converted
# to a `CLDeviceArray` via `Base.unsafe_convert`.
#
# NOTE(review): there are two consecutive `return` statements that differ only in the
# address-space name (`AS.Global` vs `AS.CrossWorkgroup`); as written only the first
# would ever run. This looks like the removed and added sides of a diff shown
# together -- confirm which one is live.
function Adapt.adapt_storage(to::KernelAdaptor, arr::CLArray{T, N}) where {T, N}
# Side effect on the adaptor: record the backing memory before converting.
push!(to.indirect_memory, arr.data[].mem)
return Base.unsafe_convert(CLDeviceArray{T, N, AS.Global}, arr)
return Base.unsafe_convert(CLDeviceArray{T, N, AS.CrossWorkgroup}, arr)
end

# Base.RefValue isn't GPU compatible, so provide a compatible alternative
Expand Down
2 changes: 1 addition & 1 deletion src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ function versioninfo(io::IO=stdout)

println(io, "Toolchain:")
println(io, " - Julia v$(VERSION)")
for jll in [cl.OpenCL_jll, SPIRV_LLVM_Translator_jll]
for jll in [cl.OpenCL_jll, SPIRV_LLVM_Backend_jll]
name = string(jll)
println(io, " - $(name[1:end-4]): $(pkgversion(jll))")
end
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ IOCapture = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand Down
2 changes: 1 addition & 1 deletion test/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ end

@test OpenCL.return_type(identity, Tuple{Int}) === Int
@test OpenCL.return_type(sin, Tuple{Float32}) === Float32
@test OpenCL.return_type(getindex, Tuple{CLDeviceArray{Float32,1,AS.Global},Int32}) === Float32
@test OpenCL.return_type(getindex, Tuple{CLDeviceArray{Float32,1,AS.CrossWorkgroup},Int32}) === Float32
@test OpenCL.return_type(getindex, Tuple{Base.RefValue{Integer}}) === Integer
end

Expand Down
Loading