Skip to content

Switch to the LLVM SPIR-V back-end. #285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 10, 2025
Merged

Switch to the LLVM SPIR-V back-end. #285

merged 11 commits into from
Jun 10, 2025

Conversation

maleadt
Copy link
Member

@maleadt maleadt commented Feb 13, 2025

Same as JuliaGPU/oneAPI.jl#491, but I figured it's easier to do the experimentation here as the SPIRVIntrinsics submodule is part of this repository.

@maleadt
Copy link
Member Author

maleadt commented Feb 18, 2025

Currently blocked on pocl/pocl#1799.

MWE:

function double_print_kernel()
    OpenCL.@print "foo"
    OpenCL.@print "barbar"
    return
end

function main()
    @opencl double_print_kernel()
end

Reduced from:

function orig()
    x = CLArray(zeros(Float32, (1, 1)))
    y = CLArray(rand(Float32, (1)))
    x[:, 1] = y
end

EDIT: upstream argued that this is an LLVM bug.

@maleadt
Copy link
Member Author

maleadt commented Feb 20, 2025

Also blocked on llvm/llvm-project#127977

MWE: CLArray{Float32}(undef, 5, 5) + 1f0 * I

Reduced from:

function orig()
    T1 = T2 = Float32
    AT = CLArray
    f = identity
    x = ones(T1, 5, 5)
    y = AT(x)

    xw = f(x)
    yw = f(y)

    J = one(T2) * I

    @allowscalar collect(xw + J)  collect(yw + J)
end

@maleadt maleadt force-pushed the tb/llvm_spirv_backend branch from 52bc6cf to fdc1656 Compare May 20, 2025 07:43
@maleadt
Copy link
Member Author

maleadt commented May 20, 2025

Rebased and updated the JLL. On 7.0 the printf failure has disappeared, but another issue (which I didn't reduce yet) remains:

Call parameter type does not match function signature!
  %192 = call i32 @llvm.spv.track.constant.i32.i32(i32 undef, metadata i32 undef)
 ptr  %194 = call i32 (i32, ptr, ...) @llvm.spv.insertv.p0(i32 %191, i32 %192, i32 %193)
in function julia_ArgumentError_91762

Copy link

codecov bot commented May 20, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 73.86%. Comparing base (9247b80) to head (a85c235).
Report is 1 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master     #285   +/-   ##
=======================================
  Coverage   73.86%   73.86%           
=======================================
  Files          12       12           
  Lines         616      616           
=======================================
  Hits          455      455           
  Misses        161      161           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@maleadt maleadt force-pushed the tb/llvm_spirv_backend branch from fdc1656 to c488e6e Compare June 10, 2025 10:27
@maleadt maleadt marked this pull request as ready for review June 10, 2025 11:26
Copy link
Contributor

github-actions bot commented Jun 10, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/lib/intrinsics/src/math.jl b/lib/intrinsics/src/math.jl
index d51d603..457983f 100644
--- a/lib/intrinsics/src/math.jl
+++ b/lib/intrinsics/src/math.jl
@@ -94,8 +94,8 @@ for gentype in generic_types
     cosval = Ref{$gentype}()
     sinval = GC.@preserve cosval begin
         ptr = Base.unsafe_convert(Ptr{$gentype}, cosval)
-        llvm_ptr = reinterpret(LLVMPtr{$gentype, AS.Function}, ptr)
-        @builtin_ccall("sincos", $gentype, ($gentype, LLVMPtr{$gentype, AS.Function}), x, llvm_ptr)
+                llvm_ptr = reinterpret(LLVMPtr{$gentype, AS.Function}, ptr)
+                @builtin_ccall("sincos", $gentype, ($gentype, LLVMPtr{$gentype, AS.Function}), x, llvm_ptr)
     end
     return sinval, cosval[]
 end
diff --git a/lib/intrinsics/src/memory.jl b/lib/intrinsics/src/memory.jl
index da5c5b2..40b398b 100644
--- a/lib/intrinsics/src/memory.jl
+++ b/lib/intrinsics/src/memory.jl
@@ -5,7 +5,7 @@
     Context() do ctx
         # XXX: as long as LLVMPtr is emitted as i8*, it doesn't make sense to type the GV
         eltyp = convert(LLVMType, LLVM.Int8Type())
-        T_ptr = convert(LLVMType, LLVMPtr{T,AS.Workgroup})
+        T_ptr = convert(LLVMType, LLVMPtr{T, AS.Workgroup})
 
         # create a function
         llvm_f, _ = create_function(T_ptr)
@@ -33,6 +33,6 @@
             ret!(builder, untyped_ptr)
         end
 
-        call_function(llvm_f, LLVMPtr{T,AS.Workgroup})
+        call_function(llvm_f, LLVMPtr{T, AS.Workgroup})
     end
 end
diff --git a/lib/intrinsics/src/pointer.jl b/lib/intrinsics/src/pointer.jl
index b84d2d1..e2afcfe 100644
--- a/lib/intrinsics/src/pointer.jl
+++ b/lib/intrinsics/src/pointer.jl
@@ -4,16 +4,16 @@ export AS
 
 module AS
 
-const Function          = 0
-const CrossWorkgroup    = 1
-const UniformConstant   = 2
-const Workgroup         = 3
-const Generic           = 4
-const DeviceOnlyINTEL   = 5 # XXX: should be CrossWorkgroup
-const HostOnlyINTEL     = 6 #      when USM is not supported
-const Input             = 7
-const Output            = 8
-const CodeSectionINTEL  = 9
-const Private           = 10
+    const Function = 0
+    const CrossWorkgroup = 1
+    const UniformConstant = 2
+    const Workgroup = 3
+    const Generic = 4
+    const DeviceOnlyINTEL = 5 # XXX: should be CrossWorkgroup
+    const HostOnlyINTEL = 6 #      when USM is not supported
+    const Input = 7
+    const Output = 8
+    const CodeSectionINTEL = 9
+    const Private = 10
 
 end
diff --git a/lib/intrinsics/src/printf.jl b/lib/intrinsics/src/printf.jl
index 203a966..aaf5969 100644
--- a/lib/intrinsics/src/printf.jl
+++ b/lib/intrinsics/src/printf.jl
@@ -80,7 +80,7 @@ end
                 push!(actual_args, actual_arg)
             end
 
-            str = globalstring_ptr!(builder, String(fmt); addrspace=AS.UniformConstant)
+            str = globalstring_ptr!(builder, String(fmt); addrspace = AS.UniformConstant)
 
             # invoke printf and return
             printf_typ = LLVM.FunctionType(T_int32, [T_pint8]; vararg=true)
diff --git a/lib/intrinsics/src/work_item.jl b/lib/intrinsics/src/work_item.jl
index bbe85ad..3fc17fc 100644
--- a/lib/intrinsics/src/work_item.jl
+++ b/lib/intrinsics/src/work_item.jl
@@ -8,57 +8,65 @@
 # 1D values
 for (julia_name, (spirv_name, julia_type, offset)) in [
         # indices
-        :get_global_linear_id           => (:BuiltInGlobalLinearId, Csize_t, 1),
-        :get_local_linear_id            => (:BuiltInLocalInvocationIndex, Csize_t, 1),
-        :get_sub_group_id               => (:BuiltInSubgroupId, UInt32, 1),
-        :get_sub_group_local_id         => (:BuiltInSubgroupLocalInvocationId, UInt32, 1),
+        :get_global_linear_id => (:BuiltInGlobalLinearId, Csize_t, 1),
+        :get_local_linear_id => (:BuiltInLocalInvocationIndex, Csize_t, 1),
+        :get_sub_group_id => (:BuiltInSubgroupId, UInt32, 1),
+        :get_sub_group_local_id => (:BuiltInSubgroupLocalInvocationId, UInt32, 1),
         # sizes
-        :get_work_dim                   => (:BuiltInWorkDim, UInt32, 0),
-        :get_sub_group_size             => (:BuiltInSubgroupSize, UInt32, 0),
-        :get_max_sub_group_size         => (:BuiltInSubgroupMaxSize, UInt32, 0),
-        :get_num_sub_groups             => (:BuiltInNumSubgroups, UInt32, 0),
-        :get_enqueued_num_sub_groups    => (:BuiltInNumEnqueuedSubgroups, UInt32, 0)]
+        :get_work_dim => (:BuiltInWorkDim, UInt32, 0),
+        :get_sub_group_size => (:BuiltInSubgroupSize, UInt32, 0),
+        :get_max_sub_group_size => (:BuiltInSubgroupMaxSize, UInt32, 0),
+        :get_num_sub_groups => (:BuiltInNumSubgroups, UInt32, 0),
+        :get_enqueued_num_sub_groups => (:BuiltInNumEnqueuedSubgroups, UInt32, 0),
+    ]
     gvar_name = Symbol("@__spirv_$(spirv_name)")
     width = sizeof(julia_type) * 8
     @eval begin
         export $julia_name
         @device_function $julia_name() =
             Base.llvmcall(
-                $("""$gvar_name = external addrspace($(AS.Input)) global i$(width)
+            $(
+                """$gvar_name = external addrspace($(AS.Input)) global i$(width)
                      define i$(width) @entry() #0 {
                          %val = load i$(width), i$(width) addrspace($(AS.Input))* $gvar_name
                          ret i$(width) %val
                      }
                      attributes #0 = { alwaysinline }
-                """, "entry"), $julia_type, Tuple{}) % Int + $offset
+                """, "entry",
+            ), $julia_type, Tuple{}
+        ) % Int + $offset
     end
 end
 
 # 3D values
 for (julia_name, (spirv_name, offset)) in [
         # indices
-        :get_global_id              => (:BuiltInGlobalInvocationId, 1),
-        :get_global_offset          => (:BuiltInGlobalOffset, 1),
-        :get_local_id               => (:BuiltInLocalInvocationId, 1),
-        :get_group_id               => (:BuiltInWorkgroupId, 1),
+        :get_global_id => (:BuiltInGlobalInvocationId, 1),
+        :get_global_offset => (:BuiltInGlobalOffset, 1),
+        :get_local_id => (:BuiltInLocalInvocationId, 1),
+        :get_group_id => (:BuiltInWorkgroupId, 1),
         # sizes
-        :get_global_size            => (:BuiltInGlobalSize, 0),
-        :get_local_size             => (:BuiltInWorkgroupSize, 0),
-        :get_enqueued_local_size    => (:BuiltInEnqueuedWorkgroupSize, 0),
-        :get_num_groups             => (:BuiltInNumWorkgroups, 0)]
+        :get_global_size => (:BuiltInGlobalSize, 0),
+        :get_local_size => (:BuiltInWorkgroupSize, 0),
+        :get_enqueued_local_size => (:BuiltInEnqueuedWorkgroupSize, 0),
+        :get_num_groups => (:BuiltInNumWorkgroups, 0),
+    ]
     gvar_name = Symbol("@__spirv_$(spirv_name)")
     width = Int === Int64 ? 64 : 32
     @eval begin
         export $julia_name
-        @device_function $julia_name(dimindx::Integer=1u32) =
+        @device_function $julia_name(dimindx::Integer = 1u32) =
             Base.llvmcall(
-                $("""$gvar_name = external addrspace($(AS.Input)) global <3 x i$(width)>
+            $(
+                """$gvar_name = external addrspace($(AS.Input)) global <3 x i$(width)>
                      define i$(width) @entry(i$(width) %idx) #0 {
                          %val = load <3 x i$(width)>, <3 x i$(width)> addrspace($(AS.Input))* $gvar_name
                          %element = extractelement <3 x i$(width)> %val, i$(width) %idx
                          ret i$(width) %element
                      }
                      attributes #0 = { alwaysinline }
-                """, "entry"), UInt, Tuple{UInt}, UInt(dimindx - 1u32)) % Int + $offset
+                """, "entry",
+            ), UInt, Tuple{UInt}, UInt(dimindx - 1u32)
+        ) % Int + $offset
     end
 end
diff --git a/src/array.jl b/src/array.jl
index fa8d157..11fad2f 100644
--- a/src/array.jl
+++ b/src/array.jl
@@ -299,8 +299,10 @@ end
 
 ## interop with GPU arrays
 
-function Base.unsafe_convert(::Type{CLDeviceArray{T, N, AS.CrossWorkgroup}},
-                             a::CLArray{T, N}) where {T, N}
+function Base.unsafe_convert(
+        ::Type{CLDeviceArray{T, N, AS.CrossWorkgroup}},
+        a::CLArray{T, N}
+    ) where {T, N}
     return CLDeviceArray{T, N, AS.CrossWorkgroup}(
         size(a), reinterpret(LLVMPtr{T, AS.CrossWorkgroup}, pointer(a)),
         a.maxsize - a.offset * Base.elsize(a)
diff --git a/src/compiler/compilation.jl b/src/compiler/compilation.jl
index 49c50e0..cefeb29 100644
--- a/src/compiler/compilation.jl
+++ b/src/compiler/compilation.jl
@@ -47,7 +47,7 @@ end
     supports_fp64 = "cl_khr_fp64" in dev.extensions
 
     # create GPUCompiler objects
-    target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, validate=true, kwargs...)
+    target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, validate = true, kwargs...)
     params = OpenCLCompilerParams()
     CompilerConfig(target, params; kernel, name, always_inline)
 end
diff --git a/test/execution.jl b/test/execution.jl
index 0936d0d..8b31edb 100644
--- a/test/execution.jl
+++ b/test/execution.jl
@@ -93,7 +93,7 @@ end
 
     @test OpenCL.return_type(identity, Tuple{Int}) === Int
     @test OpenCL.return_type(sin, Tuple{Float32}) === Float32
-    @test OpenCL.return_type(getindex, Tuple{CLDeviceArray{Float32,1,AS.CrossWorkgroup},Int32}) === Float32
+            @test OpenCL.return_type(getindex, Tuple{CLDeviceArray{Float32, 1, AS.CrossWorkgroup}, Int32}) === Float32
     @test OpenCL.return_type(getindex, Tuple{Base.RefValue{Integer}}) === Integer
 end
 

@maleadt
Copy link
Member Author

maleadt commented Jun 10, 2025

CI failures unrelated.

@maleadt maleadt merged commit ad00c10 into master Jun 10, 2025
14 of 16 checks passed
@maleadt maleadt deleted the tb/llvm_spirv_backend branch June 10, 2025 12:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant