Skip to content

Commit 435d8b1

Browse files
authored
Reland [SPIR-V] Support SPV_INTEL_int4 extension (llvm#141279)
This relands llvm#141031 This change ensures generated SPIR-V is valid and passes machine verification: ``` *** Bad machine code: inconsistent constant size *** - function: foo - basic block: %bb.1 entry (0x9ec9298) - instruction: %12:iid(s8) = G_CONSTANT i4 1 ``` That is done by promoting `G_CONSTANT` instructions with small integer types (e.g., `i4`) to `i8` if no extensions for "special" integer types are enabled.
1 parent d03f30f commit 435d8b1

File tree

9 files changed

+130
-17
lines changed

9 files changed

+130
-17
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
215215
- Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform.
216216
* - ``SPV_INTEL_subgroup_matrix_multiply_accumulate``
217217
- Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.
218+
* - ``SPV_INTEL_int4``
219+
- Adds support for 4-bit integer type, and allow this type to be used in cooperative matrices.
218220

219221
To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:
220222

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
9999
{"SPV_INTEL_ternary_bitwise_function",
100100
SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function},
101101
{"SPV_INTEL_2d_block_io",
102-
SPIRV::Extension::Extension::SPV_INTEL_2d_block_io}};
102+
SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
103+
{"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4}};
103104

104105
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
105106
StringRef ArgValue,

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
154154
report_fatal_error("Unsupported integer width!");
155155
const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
156156
if (ST.canUseExtension(
157-
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
157+
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
158+
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4))
158159
return Width;
159160
if (Width <= 8)
160161
Width = 8;
@@ -174,9 +175,14 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
174175
const SPIRVSubtarget &ST =
175176
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
176177
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
177-
if ((!isPowerOf2_32(Width) || Width < 8) &&
178-
ST.canUseExtension(
179-
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
178+
if (Width == 4 && ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
179+
MIRBuilder.buildInstr(SPIRV::OpExtension)
180+
.addImm(SPIRV::Extension::SPV_INTEL_int4);
181+
MIRBuilder.buildInstr(SPIRV::OpCapability)
182+
.addImm(SPIRV::Capability::Int4TypeINTEL);
183+
} else if ((!isPowerOf2_32(Width) || Width < 8) &&
184+
ST.canUseExtension(
185+
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
180186
MIRBuilder.buildInstr(SPIRV::OpExtension)
181187
.addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
182188
MIRBuilder.buildInstr(SPIRV::OpCapability)
@@ -1563,6 +1569,13 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
15631569
const MachineInstr *NewMI =
15641570
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
15651571
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
1572+
const Type *ET = getTypeForSPIRVType(ElemType);
1573+
if (ET->isIntegerTy() && ET->getIntegerBitWidth() == 4 &&
1574+
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget())
1575+
.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
1576+
MIRBuilder.buildInstr(SPIRV::OpCapability)
1577+
.addImm(SPIRV::Capability::Int4CooperativeMatrixINTEL);
1578+
}
15661579
return MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
15671580
.addDef(createTypeVReg(MIRBuilder))
15681581
.addUse(getSPIRVTypeID(ElemType))

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
128128
bool IsExtendedInts =
129129
ST.canUseExtension(
130130
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
131-
ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
131+
ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
132+
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
132133
auto extendedScalarsAndVectors =
133134
[IsExtendedInts](const LegalityQuery &Query) {
134135
const LLT Ty = Query.Types[0];

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -380,16 +380,31 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
380380
// To support current approach and limitations wrt. bit width here we widen a
381381
// scalar register with a bit width greater than 1 to valid sizes and cap it to
382382
// 64 width.
383-
static void widenScalarLLTNextPow2(Register Reg, MachineRegisterInfo &MRI) {
383+
static unsigned widenBitWidthToNextPow2(unsigned BitWidth) {
384+
if (BitWidth == 1)
385+
return 1; // No need to widen 1-bit values
386+
return std::min(std::max(1u << Log2_32_Ceil(BitWidth), 8u), 64u);
387+
}
388+
389+
static void widenScalarType(Register Reg, MachineRegisterInfo &MRI) {
384390
LLT RegType = MRI.getType(Reg);
385391
if (!RegType.isScalar())
386392
return;
387-
unsigned Sz = RegType.getScalarSizeInBits();
388-
if (Sz == 1)
389-
return;
390-
unsigned NewSz = std::min(std::max(1u << Log2_32_Ceil(Sz), 8u), 64u);
391-
if (NewSz != Sz)
392-
MRI.setType(Reg, LLT::scalar(NewSz));
393+
unsigned CurrentWidth = RegType.getScalarSizeInBits();
394+
unsigned NewWidth = widenBitWidthToNextPow2(CurrentWidth);
395+
if (NewWidth != CurrentWidth)
396+
MRI.setType(Reg, LLT::scalar(NewWidth));
397+
}
398+
399+
static void widenCImmType(MachineOperand &MOP) {
400+
const ConstantInt *CImmVal = MOP.getCImm();
401+
unsigned CurrentWidth = CImmVal->getBitWidth();
402+
unsigned NewWidth = widenBitWidthToNextPow2(CurrentWidth);
403+
if (NewWidth != CurrentWidth) {
404+
// Replace the immediate value with the widened version
405+
MOP.setCImm(ConstantInt::get(CImmVal->getType()->getContext(),
406+
CImmVal->getValue().zextOrTrunc(NewWidth)));
407+
}
393408
}
394409

395410
static void setInsertPtAfterDef(MachineIRBuilder &MIB, MachineInstr *Def) {
@@ -492,7 +507,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
492507
bool IsExtendedInts =
493508
ST->canUseExtension(
494509
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
495-
ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
510+
ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
511+
ST->canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
496512

497513
for (MachineBasicBlock *MBB : post_order(&MF)) {
498514
if (MBB->empty())
@@ -505,10 +521,13 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
505521
unsigned MIOp = MI.getOpcode();
506522

507523
if (!IsExtendedInts) {
508-
// validate bit width of scalar registers
509-
for (const auto &MOP : MI.operands())
524+
// validate bit width of scalar registers and constant immediates
525+
for (auto &MOP : MI.operands()) {
510526
if (MOP.isReg())
511-
widenScalarLLTNextPow2(MOP.getReg(), MRI);
527+
widenScalarType(MOP.getReg(), MRI);
528+
else if (MOP.isCImm())
529+
widenCImmType(MOP);
530+
}
512531
}
513532

514533
if (isSpvIntrinsic(MI, Intrinsic::spv_assign_ptr_type)) {

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
317317
defm SPV_INTEL_ternary_bitwise_function : ExtensionOperand<120>;
318318
defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
319319
defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
320+
defm SPV_INTEL_int4 : ExtensionOperand<123>;
320321

321322
//===----------------------------------------------------------------------===//
322323
// Multiclass used to define Capabilities enum values and at the same time
@@ -522,6 +523,8 @@ defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_
522523
defm Subgroup2DBlockIOINTEL : CapabilityOperand<6228, 0, 0, [SPV_INTEL_2d_block_io], []>;
523524
defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
524525
defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
526+
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
527+
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
525528

526529
//===----------------------------------------------------------------------===//
527530
// Multiclass used to define SourceLanguage enum values and at the same time
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
2+
; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: Capability Int4TypeINTEL
5+
; CHECK-DAG: Capability CooperativeMatrixKHR
6+
; CHECK-DAG: Extension "SPV_INTEL_int4"
7+
; CHECK-DAG: Capability Int4CooperativeMatrixINTEL
8+
; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix"
9+
10+
; CHECK: %[[#Int4Ty:]] = OpTypeInt 4 0
11+
; CHECK: %[[#CoopMatTy:]] = OpTypeCooperativeMatrixKHR %[[#Int4Ty]]
12+
; CHECK: CompositeConstruct %[[#CoopMatTy]]
13+
14+
define spir_kernel void @foo() {
15+
entry:
16+
%call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef 0)
17+
ret void
18+
}
19+
20+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_arbitrary_precision_integers %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-4
2+
3+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-8
4+
; No error would be reported in comparison to Khronos llvm-spirv, because type adjustments to integer size are made
5+
; in case no appropriate extension is enabled. Here we expect that the type is adjusted to 8 bits.
6+
7+
; CHECK-SPIRV: Capability ArbitraryPrecisionIntegersINTEL
8+
; CHECK-SPIRV: Extension "SPV_INTEL_arbitrary_precision_integers"
9+
; CHECK-INT-4: %[[#Int4:]] = OpTypeInt 4 0
10+
; CHECK-INT-8: %[[#Int4:]] = OpTypeInt 8 0
11+
; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
12+
; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
13+
; CHECK: %[[#Const:]] = OpConstant %[[#Int4]] 1
14+
15+
; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
16+
; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
17+
; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
18+
; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
19+
20+
define spir_kernel void @foo() {
21+
entry:
22+
%0 = alloca i4
23+
store i4 1, ptr %0
24+
%1 = load i4, ptr %0
25+
call spir_func void @boo(i4 %1)
26+
ret void
27+
}
28+
29+
declare spir_func void @boo(i4)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - | FileCheck %s
2+
; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK: Capability Int4TypeINTEL
5+
; CHECK: Extension "SPV_INTEL_int4"
6+
; CHECK: %[[#Int4:]] = OpTypeInt 4 0
7+
; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
8+
; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
9+
; CHECK: %[[#Const:]] = OpConstant %[[#Int4]] 1
10+
11+
; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
12+
; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
13+
; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
14+
; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
15+
16+
define spir_kernel void @foo() {
17+
entry:
18+
%0 = alloca i4
19+
store i4 1, ptr %0
20+
%1 = load i4, ptr %0
21+
call spir_func void @boo(i4 %1)
22+
ret void
23+
}
24+
25+
declare spir_func void @boo(i4)

0 commit comments

Comments
 (0)