Skip to content

Commit c9fb0be

Browse files
authored
SPIR-V: Disallow Float16/Float64 if unsupported. (#429)
1 parent 341e2e8 commit c9fb0be

File tree

5 files changed

+67
-14
lines changed

5 files changed

+67
-14
lines changed

src/metal.jl

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,8 @@ function process_entry!(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module,
6464
end
6565

6666
function validate_module(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module)
67-
errors = IRError[]
68-
69-
T_double = LLVM.DoubleType(context(mod))
70-
71-
for fun in functions(mod), bb in blocks(fun), inst in instructions(bb)
72-
if value_type(inst) == T_double || any(param->value_type(param) == T_double, operands(inst))
73-
bt = backtrace(inst)
74-
push!(errors, ("use of double floating-point value", bt, inst))
75-
end
76-
end
77-
78-
return errors
67+
# Metal never supports double precision
68+
check_ir_values(mod, LLVM.DoubleType(context(mod)))
7969
end
8070

8171
# TODO: why is this done in finish_module? maybe just in process_entry?

src/spirv.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ const SPIRV_Tools_jll = LazyModule("SPIRV_Tools_jll", UUID("6ac6d60f-d740-5983-9
1313
export SPIRVCompilerTarget
1414

1515
Base.@kwdef struct SPIRVCompilerTarget <: AbstractCompilerTarget
16+
supports_fp16::Bool = true
17+
supports_fp64::Bool = true
1618
end
1719

1820
llvm_triple(::SPIRVCompilerTarget) = Int===Int64 ? "spir64-unknown-unknown" : "spirv-unknown-unknown"
@@ -73,6 +75,20 @@ function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
7375
return entry
7476
end
7577

78+
function validate_module(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module)
79+
errors = IRError[]
80+
81+
# support for half and double depends on the target
82+
if !job.config.target.supports_fp16
83+
append!(errors, check_ir_values(mod, LLVM.HalfType(context(mod))))
84+
end
85+
if !job.config.target.supports_fp64
86+
append!(errors, check_ir_values(mod, LLVM.DoubleType(context(mod))))
87+
end
88+
89+
return errors
90+
end
91+
7692
@unlocked function mcgen(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
7793
format=LLVM.API.LLVMAssemblyFile)
7894
# The SPIRV Tools don't handle Julia's debug info, rejecting DW_LANG_Julia...

src/validation.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,17 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst)
261261

262262
return errors
263263
end
264+
265+
# helper function to check if a LLVM module uses values of a certain type
266+
function check_ir_values(mod::LLVM.Module, T_bad::LLVMType)
267+
errors = IRError[]
268+
269+
for fun in functions(mod), bb in blocks(fun), inst in instructions(bb)
270+
if value_type(inst) == T_bad || any(param->value_type(param) == T_bad, operands(inst))
271+
bt = backtrace(inst)
272+
push!(errors, ("unsupported use of $(T_bad) value", bt, inst))
273+
end
274+
end
275+
276+
return errors
277+
end

test/definitions/spirv.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ end
88
# create a SPIRV-based test compiler, and generate reflection methods for it
99

1010
function spirv_job(@nospecialize(func), @nospecialize(types);
11-
kernel::Bool=false, always_inline=false, kwargs...)
11+
kernel::Bool=false, always_inline=false,
12+
supports_fp16=true, supports_fp64=true, kwargs...)
1213
source = methodinstance(typeof(func), Base.to_tuple_type(types))
13-
target = SPIRVCompilerTarget()
14+
target = SPIRVCompilerTarget(; supports_fp16, supports_fp64)
1415
params = TestCompilerParams()
1516
config = CompilerConfig(target, params; kernel, always_inline)
1617
CompilerJob(source, config), kwargs

test/spirv.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,38 @@ end
3737
end
3838
end
3939

40+
@testset "unsupported type detection" begin
41+
function kernel(ptr, val)
42+
unsafe_store!(ptr, val)
43+
return
44+
end
45+
46+
ir = sprint(io->spirv_code_llvm(io, kernel, Tuple{Ptr{Float16}, Float16}; validate=true))
47+
@test occursin("store half", ir)
48+
49+
ir = sprint(io->spirv_code_llvm(io, kernel, Tuple{Ptr{Float32}, Float32}; validate=true))
50+
@test occursin("store float", ir)
51+
52+
ir = sprint(io->spirv_code_llvm(io, kernel, Tuple{Ptr{Float64}, Float64}; validate=true))
53+
@test occursin("store double", ir)
54+
55+
@test_throws_message(InvalidIRError,
56+
spirv_code_llvm(devnull, kernel, Tuple{Ptr{Float16}, Float16};
57+
supports_fp16=false, validate=true)) do msg
58+
occursin("unsupported unsupported use of half value", msg) &&
59+
occursin("[1] unsafe_store!", msg) &&
60+
occursin(r"\[2\] .*kernel", msg)
61+
end
62+
63+
@test_throws_message(InvalidIRError,
64+
spirv_code_llvm(devnull, kernel, Tuple{Ptr{Float64}, Float64};
65+
supports_fp64=false, validate=true)) do msg
66+
occursin("unsupported unsupported use of double value", msg) &&
67+
occursin("[1] unsafe_store!", msg) &&
68+
occursin(r"\[2\] .*kernel", msg)
69+
end
70+
end
71+
4072
end
4173

4274
############################################################################################

0 commit comments

Comments
 (0)