Skip to content

Commit 4e21315

Browse files
authored
[SPIRV] Add FloatControl2 capability (#144371)
Add handling for FPFastMathMode in SPIR-V shaders. This is a first pass that simply does a direct translation when the proper extension is available. This will unblock work for HLSL. However, it is not a full solution. The default math mode for spir-v is determined by the API. When targeting Vulkan many of the fast math options are assumed. We should do something particular when targeting Vulkan. We will also need to handle the hlsl "precise" keyword correctly when FPFastMathMode is not available. Unblockes #140739, but we are keeing it open to track the remaining issues mentioned above.
1 parent 9c0743f commit 4e21315

File tree

5 files changed

+55
-6
lines changed

5 files changed

+55
-6
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
217217
- Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.
218218
* - ``SPV_INTEL_int4``
219219
- Adds support for 4-bit integer type, and allow this type to be used in cooperative matrices.
220+
* - ``SPV_KHR_float_controls2``
221+
- Adds ability to specify the floating-point environment in shaders. It can be used on whole modules and individual instructions.
220222

221223
To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:
222224

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
100100
SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function},
101101
{"SPV_INTEL_2d_block_io",
102102
SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
103-
{"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4}};
103+
{"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4},
104+
{"SPV_KHR_float_controls2",
105+
SPIRV::Extension::Extension::SPV_KHR_float_controls2}};
104106

105107
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
106108
StringRef ArgValue,

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
7070
AvoidCapabilitiesSet AvoidCaps;
7171
if (!ST.isShader())
7272
AvoidCaps.S.insert(SPIRV::Capability::Shader);
73+
else
74+
AvoidCaps.S.insert(SPIRV::Capability::Kernel);
7375

7476
VersionTuple ReqMinVer = getSymbolicOperandMinVersion(Category, i);
7577
VersionTuple ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
@@ -88,8 +90,11 @@ getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
8890
} else if (MinVerOK && MaxVerOK) {
8991
if (ReqCaps.size() == 1) {
9092
auto Cap = ReqCaps[0];
91-
if (Reqs.isCapabilityAvailable(Cap))
93+
if (Reqs.isCapabilityAvailable(Cap)) {
94+
ReqExts.append(getSymbolicOperandExtensions(
95+
SPIRV::OperandCategory::CapabilityOperand, Cap));
9296
return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
97+
}
9398
} else {
9499
// By SPIR-V specification: "If an instruction, enumerant, or other
95100
// feature specifies multiple enabling capabilities, only one such
@@ -103,8 +108,11 @@ getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
103108
UseCaps.push_back(Cap);
104109
for (size_t i = 0, Sz = UseCaps.size(); i < Sz; ++i) {
105110
auto Cap = UseCaps[i];
106-
if (i == Sz - 1 || !AvoidCaps.S.contains(Cap))
111+
if (i == Sz - 1 || !AvoidCaps.S.contains(Cap)) {
112+
ReqExts.append(getSymbolicOperandExtensions(
113+
SPIRV::OperandCategory::CapabilityOperand, Cap));
107114
return {true, {Cap}, ReqExts, ReqMinVer, ReqMaxVer};
115+
}
108116
}
109117
}
110118
}
@@ -1975,6 +1983,14 @@ static unsigned getFastMathFlags(const MachineInstr &I) {
19751983
return Flags;
19761984
}
19771985

1986+
static bool isFastMathMathModeAvailable(const SPIRVSubtarget &ST) {
1987+
if (ST.isKernel())
1988+
return true;
1989+
if (ST.getSPIRVVersion() < VersionTuple(1, 2))
1990+
return false;
1991+
return ST.canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2);
1992+
}
1993+
19781994
static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
19791995
const SPIRVInstrInfo &TII,
19801996
SPIRV::RequirementHandler &Reqs) {
@@ -1998,8 +2014,12 @@ static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
19982014
unsigned FMFlags = getFastMathFlags(I);
19992015
if (FMFlags == SPIRV::FPFastMathMode::None)
20002016
return;
2001-
Register DstReg = I.getOperand(0).getReg();
2002-
buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
2017+
2018+
if (isFastMathMathModeAvailable(ST)) {
2019+
Register DstReg = I.getOperand(0).getReg();
2020+
buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode,
2021+
{FMFlags});
2022+
}
20032023
}
20042024

20052025
// Walk all functions and add decorations related to MI flags.

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ defm SPV_INTEL_ternary_bitwise_function : ExtensionOperand<120>;
319319
defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
320320
defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
321321
defm SPV_INTEL_int4 : ExtensionOperand<123>;
322+
defm SPV_KHR_float_controls2 : ExtensionOperand<124>;
322323

323324
//===----------------------------------------------------------------------===//
324325
// Multiclass used to define Capabilities enum values and at the same time
@@ -489,6 +490,8 @@ defm DotProductInput4x8Bit : CapabilityOperand<6017, 0x10600, 0, [SPV_KHR_intege
489490
defm DotProductInput4x8BitPacked : CapabilityOperand<6018, 0x10600, 0, [SPV_KHR_integer_dot_product], []>;
490491
defm DotProduct : CapabilityOperand<6019, 0x10600, 0, [SPV_KHR_integer_dot_product], []>;
491492
defm GroupNonUniformRotateKHR : CapabilityOperand<6026, 0, 0, [SPV_KHR_subgroup_rotate], [GroupNonUniform]>;
493+
defm FloatControls2
494+
: CapabilityOperand<6029, 0x10200, 0, [SPV_KHR_float_controls2], []>;
492495
defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
493496
defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
494497
defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_float16_add], []>;
@@ -1239,7 +1242,7 @@ defm XfbBuffer : DecorationOperand<36, 0, 0, [], [TransformFeedback]>;
12391242
defm XfbStride : DecorationOperand<37, 0, 0, [], [TransformFeedback]>;
12401243
defm FuncParamAttr : DecorationOperand<38, 0, 0, [], [Kernel]>;
12411244
defm FPRoundingMode : DecorationOperand<39, 0, 0, [], []>;
1242-
defm FPFastMathMode : DecorationOperand<40, 0, 0, [], [Kernel]>;
1245+
defm FPFastMathMode : DecorationOperand<40, 0, 0, [], [Kernel, FloatControls2]>;
12431246
defm LinkageAttributes : DecorationOperand<41, 0, 0, [], [Linkage]>;
12441247
defm NoContraction : DecorationOperand<42, 0, 0, [], [Shader]>;
12451248
defm InputAttachmentIndex : DecorationOperand<43, 0, 0, [], [InputAttachment]>;
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
; RUN: llc -O0 -mtriple=spirv1.6-vulkan1.3-compute %s -o - | FileCheck %s --check-prefix=CHECK-NOEXT
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-vulkan1.3-compute %s -o - -filetype=obj | spirv-val --target-env vulkan1.3 %}
3+
4+
; RUN: llc -O0 -mtriple=spirv1.6-vulkan1.3-compute -spirv-ext=+SPV_KHR_float_controls2 %s -o - | FileCheck %s --check-prefix=CHECK-EXT
5+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-vulkan1.3-compute -spirv-ext=+SPV_KHR_float_controls2 %s -o - -filetype=obj | spirv-val --target-env vulkan1.3 %}
6+
7+
; CHECK-NOEXT-NOT: OpDecorate FPFastMathMode
8+
9+
; CHECK-EXT: OpCapability FloatControls2
10+
; CHECK-EXT: OpExtension "SPV_KHR_float_controls2"
11+
; CHECK-EXT: OpDecorate {{%[0-9]+}} FPFastMathMode NotNaN|NotInf|NSZ|AllowRecip|Fast
12+
13+
define hidden spir_func float @foo(float %0) local_unnamed_addr {
14+
%2 = fmul reassoc nnan ninf nsz arcp afn float %0, 2.000000e+00
15+
ret float %2
16+
}
17+
18+
define void @main() local_unnamed_addr #1 {
19+
ret void
20+
}
21+
22+
attributes #1 = { "hlsl.numthreads"="8,1,1" "hlsl.shader"="compute" }

0 commit comments

Comments
 (0)