Skip to content

Commit 8f30b62

Browse files
[SPIR-V] Add support for the SPIR-V extension SPV_INTEL_bfloat16_conversion (llvm#83443)
This PR is to add support for the SPIR-V extension SPV_INTEL_bfloat16_conversion (https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc) and OpenCL extension cl_intel_bfloat16_conversions (https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_bfloat16_conversions.html).
1 parent eaf0d82 commit 8f30b62

13 files changed

+246
-9
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ struct ConvertBuiltin {
134134
bool IsDestinationSigned;
135135
bool IsSaturated;
136136
bool IsRounded;
137+
bool IsBfloat16;
137138
FPRoundingMode::FPRoundingMode RoundingMode;
138139
};
139140

@@ -1986,6 +1987,8 @@ static bool generateConvertInst(const StringRef DemangledCall,
19861987
SPIRV::Decoration::FPRoundingMode,
19871988
{(unsigned)Builtin->RoundingMode});
19881989

1990+
std::string NeedExtMsg; // no errors if empty
1991+
bool IsRightComponentsNumber = true; // check if input/output accepts vectors
19891992
unsigned Opcode = SPIRV::OpNop;
19901993
if (GR->isScalarOrVectorOfType(Call->Arguments[0], SPIRV::OpTypeInt)) {
19911994
// Int -> ...
@@ -2000,23 +2003,61 @@ static bool generateConvertInst(const StringRef DemangledCall,
20002003
} else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
20012004
SPIRV::OpTypeFloat)) {
20022005
// Int -> Float
2003-
bool IsSourceSigned =
2004-
DemangledCall[DemangledCall.find_first_of('(') + 1] != 'u';
2005-
Opcode = IsSourceSigned ? SPIRV::OpConvertSToF : SPIRV::OpConvertUToF;
2006+
if (Builtin->IsBfloat16) {
2007+
const auto *ST = static_cast<const SPIRVSubtarget *>(
2008+
&MIRBuilder.getMF().getSubtarget());
2009+
if (!ST->canUseExtension(
2010+
SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
2011+
NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
2012+
IsRightComponentsNumber =
2013+
GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
2014+
GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
2015+
Opcode = SPIRV::OpConvertBF16ToFINTEL;
2016+
} else {
2017+
bool IsSourceSigned =
2018+
DemangledCall[DemangledCall.find_first_of('(') + 1] != 'u';
2019+
Opcode = IsSourceSigned ? SPIRV::OpConvertSToF : SPIRV::OpConvertUToF;
2020+
}
20062021
}
20072022
} else if (GR->isScalarOrVectorOfType(Call->Arguments[0],
20082023
SPIRV::OpTypeFloat)) {
20092024
// Float -> ...
2010-
if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt))
2025+
if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt)) {
20112026
// Float -> Int
2012-
Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
2013-
: SPIRV::OpConvertFToU;
2014-
else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
2015-
SPIRV::OpTypeFloat))
2027+
if (Builtin->IsBfloat16) {
2028+
const auto *ST = static_cast<const SPIRVSubtarget *>(
2029+
&MIRBuilder.getMF().getSubtarget());
2030+
if (!ST->canUseExtension(
2031+
SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
2032+
NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
2033+
IsRightComponentsNumber =
2034+
GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
2035+
GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
2036+
Opcode = SPIRV::OpConvertFToBF16INTEL;
2037+
} else {
2038+
Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
2039+
: SPIRV::OpConvertFToU;
2040+
}
2041+
} else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
2042+
SPIRV::OpTypeFloat)) {
20162043
// Float -> Float
20172044
Opcode = SPIRV::OpFConvert;
2045+
}
20182046
}
20192047

2048+
if (!NeedExtMsg.empty()) {
2049+
std::string DiagMsg = std::string(Builtin->Name) +
2050+
": the builtin requires the following SPIR-V "
2051+
"extension: " +
2052+
NeedExtMsg;
2053+
report_fatal_error(DiagMsg.c_str(), false);
2054+
}
2055+
if (!IsRightComponentsNumber) {
2056+
std::string DiagMsg =
2057+
std::string(Builtin->Name) +
2058+
": result and argument must have the same number of components";
2059+
report_fatal_error(DiagMsg.c_str(), false);
2060+
}
20202061
assert(Opcode != SPIRV::OpNop &&
20212062
"Conversion between the types not implemented!");
20222063

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1177,6 +1177,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
11771177
bit IsDestinationSigned = !eq(!find(name, "convert_u"), -1);
11781178
bit IsSaturated = !not(!eq(!find(name, "_sat"), -1));
11791179
bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
1180+
bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)),
1181+
!not(!eq(!find(name, "bfloat16"), -1)));
11801182
FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
11811183
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
11821184
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
@@ -1187,7 +1189,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
11871189
// Table gathering all the convert builtins.
11881190
def ConvertBuiltins : GenericTable {
11891191
let FilterClass = "ConvertBuiltin";
1190-
let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated", "IsRounded", "RoundingMode"];
1192+
let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated",
1193+
"IsRounded", "IsBfloat16", "RoundingMode"];
11911194
string TypeOf_Set = "InstructionSet";
11921195
string TypeOf_RoundingMode = "FPRoundingMode";
11931196
}
@@ -1229,6 +1232,25 @@ defm : DemangledConvertBuiltin<"convert_long", OpenCL_std>;
12291232
defm : DemangledConvertBuiltin<"convert_ulong", OpenCL_std>;
12301233
defm : DemangledConvertBuiltin<"convert_float", OpenCL_std>;
12311234

1235+
// cl_intel_bfloat16_conversions / SPV_INTEL_bfloat16_conversion
1236+
// Multiclass used to define at the same time both a demangled builtin records
1237+
// and a corresponding convert builtin records.
1238+
multiclass DemangledBF16ConvertBuiltin<string name1, string name2> {
1239+
// Create records for scalar and vector conversions.
1240+
foreach i = ["", "2", "3", "4", "8", "16"] in {
1241+
def : DemangledBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
1242+
def : ConvertBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std>;
1243+
}
1244+
}
1245+
1246+
defm : DemangledBF16ConvertBuiltin<"bfloat16", "_as_ushort">;
1247+
defm : DemangledBF16ConvertBuiltin<"as_bfloat16", "_float">;
1248+
1249+
foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in {
1250+
def : DemangledBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std, Convert, 1, 1>;
1251+
def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
1252+
}
1253+
12321254
//===----------------------------------------------------------------------===//
12331255
// Class defining a vector data load/store builtin record used for lowering
12341256
// into OpExtInst instruction.

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,15 @@ bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
819819
return false;
820820
}
821821

822+
unsigned
823+
SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
824+
if (SPIRVType *Type = getSPIRVTypeForVReg(VReg))
825+
return Type->getOpcode() == SPIRV::OpTypeVector
826+
? static_cast<unsigned>(Type->getOperand(2).getImm())
827+
: 1;
828+
return 0;
829+
}
830+
822831
unsigned
823832
SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
824833
assert(Type && "Invalid Type pointer");

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ class SPIRVGlobalRegistry {
197197
// opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
198198
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
199199

200+
// Return number of elements in a vector if the given VReg is associated with
201+
// a vector type. Return 1 for a scalar type, and 0 for a missing type.
202+
unsigned getScalarOrVectorComponentCount(Register VReg) const;
203+
200204
// For vectors or scalars of booleans, integers and floats, return the scalar
201205
// type's bitwidth. Otherwise calls llvm_unreachable().
202206
unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const;

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,10 @@ def OpBitcast : UnOp<"OpBitcast", 124>;
443443
def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
444444
def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
445445

446+
// SPV_INTEL_bfloat16_conversion
447+
def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>;
448+
def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>;
449+
446450
// 3.42.12 Composite Instructions
447451

448452
def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,13 @@ void addInstrRequirements(const MachineInstr &MI,
11101110
case SPIRV::OpAtomicFMaxEXT:
11111111
AddAtomicFloatRequirements(MI, Reqs, ST);
11121112
break;
1113+
case SPIRV::OpConvertBF16ToFINTEL:
1114+
case SPIRV::OpConvertFToBF16INTEL:
1115+
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
1116+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
1117+
Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
1118+
}
1119+
break;
11131120
case SPIRV::OpVariableLengthArrayINTEL:
11141121
case SPIRV::OpSaveMemoryINTEL:
11151122
case SPIRV::OpRestoreMemoryINTEL:

llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ cl::list<SPIRV::Extension::Extension> Extensions(
8686
"Allows to use the LinkOnceODR linkage type that is to let "
8787
"a function or global variable to be merged with other functions "
8888
"or global variables of the same name when linkage occurs."),
89+
clEnumValN(SPIRV::Extension::SPV_INTEL_bfloat16_conversion,
90+
"SPV_INTEL_bfloat16_conversion",
91+
"Adds instructions to convert between single-precision "
92+
"32-bit floating-point values and 16-bit bfloat16 values."),
8993
clEnumValN(SPIRV::Extension::SPV_KHR_subgroup_rotate,
9094
"SPV_KHR_subgroup_rotate",
9195
"Adds a new instruction that enables rotating values across "

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ defm SPV_INTEL_fpga_argument_interfaces : ExtensionOperand<102>;
297297
defm SPV_INTEL_optnone : ExtensionOperand<103>;
298298
defm SPV_INTEL_function_pointers : ExtensionOperand<104>;
299299
defm SPV_INTEL_variable_length_array : ExtensionOperand<105>;
300+
defm SPV_INTEL_bfloat16_conversion : ExtensionOperand<106>;
300301

301302
//===----------------------------------------------------------------------===//
302303
// Multiclass used to define Capabilities enum values and at the same time
@@ -466,6 +467,7 @@ defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atom
466467
defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
467468
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
468469
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
470+
defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
469471

470472
//===----------------------------------------------------------------------===//
471473
// Multiclass used to define SourceLanguage enum values and at the same time
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
; CHECK-ERROR: result and argument must have the same number of components
3+
4+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
5+
target triple = "spir64-unknown-unknown"
6+
7+
define spir_func void @test(<8 x float> %in) {
8+
%res = tail call spir_func i16 @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float> %in)
9+
ret void
10+
}
11+
12+
declare spir_func i16 @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float>)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
; CHECK-ERROR: result and argument must have the same number of components
3+
4+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
5+
target triple = "spir64-unknown-unknown"
6+
7+
define spir_func void @test(<8 x float> %in) {
8+
%res = tail call spir_func <4 x i16> @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float> %in)
9+
ret void
10+
}
11+
12+
declare spir_func <4 x i16> @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float>)

0 commit comments

Comments
 (0)