Skip to content

Commit 2557995

Browse files
MrSidimssys-ce-bb
authored andcommitted
Implement SPV_KHR_bfloat16 extension (#3099)
The extension add translation from LLVM's bfloat type to OpTypeFloat %width% 16 %fp encoding% BFloat16KHR Mangling follows LLVM's rules for the type. Spec PR: KhronosGroup/SPIRV-Registry#323 --------- Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com> Co-authored-by: Aziz, Michael <michael.aziz@intel.com> Original commit: KhronosGroup/SPIRV-LLVM-Translator@d3b7a12ee9b8f2b
1 parent ce22ee1 commit 2557995

File tree

19 files changed

+216
-45
lines changed

19 files changed

+216
-45
lines changed

llvm-spirv/include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,4 @@ EXT(SPV_INTEL_maximum_registers)
7575
EXT(SPV_INTEL_bindless_images)
7676
EXT(SPV_INTEL_2d_block_io)
7777
EXT(SPV_INTEL_subgroup_matrix_multiply_accumulate)
78+
EXT(SPV_KHR_bfloat16)

llvm-spirv/lib/SPIRV/Mangler/ManglingUtils.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ static const char *PrimitiveNames[PRIMITIVE_NUM] = {
2828
"half",
2929
"float",
3030
"double",
31+
"__bf16",
3132
"void",
3233
"...",
3334
"image1d_ro_t",
@@ -105,6 +106,7 @@ const char *MangledTypes[PRIMITIVE_NUM] = {
105106
"Dh", // HALF
106107
"f", // FLOAT
107108
"d", // DOUBLE
109+
"u6__bf16", // __BF16
108110
"v", // VOID
109111
"z", // VarArg
110112
"14ocl_image1d_ro", // PRIMITIVE_IMAGE1D_RO_T
@@ -197,6 +199,7 @@ static const SPIRversion PrimitiveSupportedVersions[PRIMITIVE_NUM] = {
197199
SPIR12, // HALF
198200
SPIR12, // FLOAT
199201
SPIR12, // DOUBLE
202+
SPIR12, // __BF16
200203
SPIR12, // VOID
201204
SPIR12, // VarArg
202205
SPIR12, // PRIMITIVE_IMAGE1D_RO_T

llvm-spirv/lib/SPIRV/Mangler/ParameterType.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ enum TypePrimitiveEnum {
4545
PRIMITIVE_HALF,
4646
PRIMITIVE_FLOAT,
4747
PRIMITIVE_DOUBLE,
48+
PRIMITIVE_BFLOAT,
4849
PRIMITIVE_VOID,
4950
PRIMITIVE_VAR_ARG,
5051
PRIMITIVE_STRUCT_FIRST,

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,8 @@ std::optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {
316316
Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
317317
switch (T->getFloatBitWidth()) {
318318
case 16:
319+
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
320+
return Type::getBFloatTy(*Context);
319321
return Type::getHalfTy(*Context);
320322
case 32:
321323
return Type::getFloatTy(*Context);
@@ -1485,7 +1487,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
14851487
const llvm::fltSemantics *FS = nullptr;
14861488
switch (BT->getFloatBitWidth()) {
14871489
case 16:
1488-
FS = &APFloat::IEEEhalf();
1490+
FS =
1491+
(BT->isTypeFloat(16, FPEncodingBFloat16KHR) ? &APFloat::BFloat()
1492+
: &APFloat::IEEEhalf());
14891493
break;
14901494
case 32:
14911495
FS = &APFloat::IEEEsingle();

llvm-spirv/lib/SPIRV/SPIRVUtil.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,6 +1340,8 @@ static SPIR::RefParamType transTypeDesc(Type *Ty,
13401340
return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_FLOAT));
13411341
if (Ty->isDoubleTy())
13421342
return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_DOUBLE));
1343+
if (Ty->isBFloatTy())
1344+
return SPIR::RefParamType(new SPIR::PrimitiveType(SPIR::PRIMITIVE_BFLOAT));
13431345
if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
13441346
return SPIR::RefParamType(new SPIR::VectorType(
13451347
transTypeDesc(VecTy->getElementType(), Info), VecTy->getNumElements()));

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,9 +401,16 @@ SPIRVType *LLVMToSPIRVBase::transType(Type *T) {
401401
}
402402
}
403403

404-
// Emit error if type is bfloat. LLVM native bfloat type is not supported.
405-
BM->getErrorLog().checkError(!T->isBFloatTy(),
406-
SPIRVEC_UnsupportedLLVMBFloatType);
404+
if (T->isBFloatTy()) {
405+
BM->getErrorLog().checkError(
406+
BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_bfloat16),
407+
SPIRVEC_RequiresExtension,
408+
"SPV_KHR_bfloat16\n"
409+
"NOTE: LLVM module contains bfloat type, translation of which "
410+
"requires this extension");
411+
return mapType(T, BM->addFloatType(16, FPEncodingBFloat16KHR));
412+
}
413+
407414
if (T->isFloatingPointTy())
408415
return mapType(T, BM->addFloatType(T->getPrimitiveSizeInBits()));
409416

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,9 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
223223
{CapabilityCooperativeMatrixKHR});
224224
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixOffsetInstructionsINTEL,
225225
{CapabilityCooperativeMatrixKHR});
226+
ADD_VEC_INIT(CapabilityBFloat16DotProductKHR, {CapabilityBFloat16TypeKHR});
227+
ADD_VEC_INIT(CapabilityBFloat16CooperativeMatrixKHR,
228+
{CapabilityBFloat16TypeKHR, CapabilityCooperativeMatrixKHR});
226229
}
227230

228231
template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVErrorEnum.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ _SPIRV_OP(UnspecifiedMemoryModel, "Unspecified Memory Model.")
2828
_SPIRV_OP(RepeatedMemoryModel, "Expects a single OpMemoryModel instruction.")
2929
_SPIRV_OP(UnsupportedVarArgFunction,
3030
"Variadic functions other than 'printf' are not supported in SPIR-V.")
31-
_SPIRV_OP(UnsupportedLLVMBFloatType,
32-
"LLVM bfloat type is not supported in SPIR-V.")
3331

3432
/* This is the last error code to have a maximum valid value to compare to */
3533
_SPIRV_OP(InternalMaxErrorCode, "Unknown error code")

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,18 @@ class SPIRVBinary : public SPIRVInstTemplateBase {
778778
return VersionNumber::SPIRV_1_4;
779779
return VersionNumber::SPIRV_1_0;
780780
}
781+
SPIRVCapVec getRequiredCapability() const override {
782+
if (OpCode == OpDot) {
783+
const SPIRVType *OpTy = getValueType(Ops[0]);
784+
if (OpTy && OpTy->isTypeVector()) {
785+
OpTy = OpTy->getVectorComponentType();
786+
if (OpTy && OpTy->isTypeFloat(16, FPEncodingBFloat16KHR)) {
787+
return getVec(CapabilityBFloat16DotProductKHR);
788+
}
789+
}
790+
}
791+
return SPIRVInstruction::getRequiredCapability();
792+
}
781793
};
782794

783795
template <Op OC>

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,8 @@ class SPIRVModuleImpl : public SPIRVModule {
253253
template <class T> T *addType(T *Ty);
254254
SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVValue *) override;
255255
SPIRVTypeBool *addBoolType() override;
256-
SPIRVTypeFloat *addFloatType(unsigned BitWidth) override;
256+
SPIRVTypeFloat *addFloatType(unsigned BitWidth,
257+
unsigned FloatingPointEncoding) override;
257258
SPIRVTypeFunction *addFunctionType(SPIRVType *,
258259
const std::vector<SPIRVType *> &) override;
259260
SPIRVTypeInt *addIntegerType(unsigned BitWidth) override;
@@ -577,7 +578,8 @@ class SPIRVModuleImpl : public SPIRVModule {
577578
SmallDenseMap<SPIRVStorageClassKind, SPIRVTypeUntypedPointerKHR *>
578579
UntypedPtrTyMap;
579580
SmallDenseMap<unsigned, SPIRVTypeInt *, 4> IntTypeMap;
580-
SmallDenseMap<unsigned, SPIRVTypeFloat *, 4> FloatTypeMap;
581+
SmallDenseMap<std::pair<unsigned, unsigned>, SPIRVTypeFloat *, 4>
582+
FloatTypeMap;
581583
SmallDenseMap<std::pair<unsigned, SPIRVType *>, SPIRVTypePointer *, 4>
582584
PointerTypeMap;
583585
std::unordered_map<unsigned, SPIRVConstant *> LiteralMap;
@@ -1007,12 +1009,14 @@ SPIRVTypeInt *SPIRVModuleImpl::addIntegerType(unsigned BitWidth) {
10071009
return addType(Ty);
10081010
}
10091011

1010-
SPIRVTypeFloat *SPIRVModuleImpl::addFloatType(unsigned BitWidth) {
1011-
auto Loc = FloatTypeMap.find(BitWidth);
1012+
SPIRVTypeFloat *SPIRVModuleImpl::addFloatType(unsigned BitWidth,
1013+
unsigned FloatingPointEncoding) {
1014+
auto Desc = std::make_pair(BitWidth, FloatingPointEncoding);
1015+
auto Loc = FloatTypeMap.find(Desc);
10121016
if (Loc != FloatTypeMap.end())
10131017
return Loc->second;
1014-
auto *Ty = new SPIRVTypeFloat(this, getId(), BitWidth);
1015-
FloatTypeMap[BitWidth] = Ty;
1018+
auto *Ty = new SPIRVTypeFloat(this, getId(), BitWidth, FloatingPointEncoding);
1019+
FloatTypeMap[Desc] = Ty;
10161020
return addType(Ty);
10171021
}
10181022

0 commit comments

Comments
 (0)