Skip to content

Commit ecc3bda

Browse files
[SPIR-V] Fix bitcast legalization/instruction selection in SPIR-V Backend (llvm#83139)
This PR is to fix a way how SPIR-V Backend describes legality of OpBitcast instruction and how it is validated on a step of instruction selection. Instead of checking a size of virtual registers (that makes no sense due to lack of guarantee of direct relations between size of virtual register and bit width associated with the type size), this PR allows to legalize OpBitcast without size check and postpones validation to the instruction selection step. As an example, let's consider the next example that was copied as is from a bigger test suite: ``` %355:id(s16) = G_BITCAST %301:id(s32) %303:id(s16) = ASSIGN_TYPE %355:id(s16), %349:type(s32) %644:fid(s32) = G_FMUL %645:fid, %646:fid %301:id(s32) = ASSIGN_TYPE %644:fid(s32), %40:type(s32) ``` Without the PR this leads to a crash with complains to an illegal bitcast, because %355 is s16 and %301 is s32. However, we must check not virtual registers in this case, but types of %355 and %301, i.e., %349:type(s32) and %40:type(s32), which are perfectly well compatible in a sense of OpBitcast in this case. In a test case that is a part of this PR OpBitcast is legal, being applied for `OpTypeInt 16` and `OpTypeFloat 16`, but would not be legalized without this PR due to virtual registers defined as having size 16 and 32.
1 parent 540d255 commit ecc3bda

File tree

5 files changed

+94
-14
lines changed

5 files changed

+94
-14
lines changed

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -834,15 +834,49 @@ SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
834834
llvm_unreachable("Attempting to get bit width of non-integer/float type.");
835835
}
836836

837-
bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
837+
unsigned SPIRVGlobalRegistry::getNumScalarOrVectorTotalBitWidth(
838+
const SPIRVType *Type) const {
838839
assert(Type && "Invalid Type pointer");
840+
unsigned NumElements = 1;
839841
if (Type->getOpcode() == SPIRV::OpTypeVector) {
840-
auto EleTypeReg = Type->getOperand(1).getReg();
841-
Type = getSPIRVTypeForVReg(EleTypeReg);
842+
NumElements = static_cast<unsigned>(Type->getOperand(2).getImm());
843+
Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
842844
}
843-
if (Type->getOpcode() == SPIRV::OpTypeInt)
844-
return Type->getOperand(2).getImm() != 0;
845-
llvm_unreachable("Attempting to get sign of non-integer type.");
845+
return Type->getOpcode() == SPIRV::OpTypeInt ||
846+
Type->getOpcode() == SPIRV::OpTypeFloat
847+
? NumElements * Type->getOperand(1).getImm()
848+
: 0;
849+
}
850+
851+
const SPIRVType *SPIRVGlobalRegistry::retrieveScalarOrVectorIntType(
852+
const SPIRVType *Type) const {
853+
if (Type && Type->getOpcode() == SPIRV::OpTypeVector)
854+
Type = getSPIRVTypeForVReg(Type->getOperand(1).getReg());
855+
return Type && Type->getOpcode() == SPIRV::OpTypeInt ? Type : nullptr;
856+
}
857+
858+
bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
859+
const SPIRVType *IntType = retrieveScalarOrVectorIntType(Type);
860+
return IntType && IntType->getOperand(2).getImm() != 0;
861+
}
862+
863+
bool SPIRVGlobalRegistry::isBitcastCompatible(const SPIRVType *Type1,
864+
const SPIRVType *Type2) const {
865+
if (!Type1 || !Type2)
866+
return false;
867+
auto Op1 = Type1->getOpcode(), Op2 = Type2->getOpcode();
868+
// Ignore difference between <1.5 and >=1.5 protocol versions:
869+
// it's valid if either Result Type or Operand is a pointer, and the other
870+
// is a pointer, an integer scalar, or an integer vector.
871+
if (Op1 == SPIRV::OpTypePointer &&
872+
(Op2 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type2)))
873+
return true;
874+
if (Op2 == SPIRV::OpTypePointer &&
875+
(Op1 == SPIRV::OpTypePointer || retrieveScalarOrVectorIntType(Type1)))
876+
return true;
877+
unsigned Bits1 = getNumScalarOrVectorTotalBitWidth(Type1),
878+
Bits2 = getNumScalarOrVectorTotalBitWidth(Type2);
879+
return Bits1 > 0 && Bits1 == Bits2;
846880
}
847881

848882
SPIRV::StorageClass::StorageClass

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,19 @@ class SPIRVGlobalRegistry {
197197
// opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
198198
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
199199

200-
// For vectors or scalars of ints/floats, return the scalar type's bitwidth.
200+
// For vectors or scalars of booleans, integers and floats, return the scalar
201+
// type's bitwidth. Otherwise calls llvm_unreachable().
201202
unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const;
202203

204+
// For vectors or scalars of integers and floats, return total bitwidth of the
205+
// argument. Otherwise returns 0.
206+
unsigned getNumScalarOrVectorTotalBitWidth(const SPIRVType *Type) const;
207+
208+
// Returns either pointer to integer type, that may be a type of vector
209+
// elements or an original type, or nullptr if the argument is niether
210+
// an integer scalar, nor an integer vector
211+
const SPIRVType *retrieveScalarOrVectorIntType(const SPIRVType *Type) const;
212+
203213
// For integer vectors or scalars, return whether the integers are signed.
204214
bool isScalarOrVectorSigned(const SPIRVType *Type) const;
205215

@@ -209,6 +219,11 @@ class SPIRVGlobalRegistry {
209219
// Return the number of bits SPIR-V pointers and size_t variables require.
210220
unsigned getPointerSize() const { return PointerSize; }
211221

222+
// Returns true if two types are defined and are compatible in a sense of
223+
// OpBitcast instruction
224+
bool isBitcastCompatible(const SPIRVType *Type1,
225+
const SPIRVType *Type2) const;
226+
212227
private:
213228
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
214229

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
9595
bool selectUnOp(Register ResVReg, const SPIRVType *ResType, MachineInstr &I,
9696
unsigned Opcode) const;
9797

98+
bool selectBitcast(Register ResVReg, const SPIRVType *ResType,
99+
MachineInstr &I) const;
100+
98101
bool selectLoad(Register ResVReg, const SPIRVType *ResType,
99102
MachineInstr &I) const;
100103
bool selectStore(MachineInstr &I) const;
@@ -452,7 +455,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
452455
case TargetOpcode::G_INTTOPTR:
453456
return selectUnOp(ResVReg, ResType, I, SPIRV::OpConvertUToPtr);
454457
case TargetOpcode::G_BITCAST:
455-
return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
458+
return selectBitcast(ResVReg, ResType, I);
456459
case TargetOpcode::G_ADDRSPACE_CAST:
457460
return selectAddrSpaceCast(ResVReg, ResType, I);
458461
case TargetOpcode::G_PTR_ADD: {
@@ -592,6 +595,16 @@ bool SPIRVInstructionSelector::selectUnOp(Register ResVReg,
592595
Opcode);
593596
}
594597

598+
bool SPIRVInstructionSelector::selectBitcast(Register ResVReg,
599+
const SPIRVType *ResType,
600+
MachineInstr &I) const {
601+
Register OpReg = I.getOperand(1).getReg();
602+
SPIRVType *OpType = OpReg.isValid() ? GR.getSPIRVTypeForVReg(OpReg) : nullptr;
603+
if (!GR.isBitcastCompatible(ResType, OpType))
604+
report_fatal_error("incompatible result and operand types in a bitcast");
605+
return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
606+
}
607+
595608
static SPIRV::Scope::Scope getScope(SyncScope::ID Ord) {
596609
switch (Ord) {
597610
case SyncScope::SingleThread:

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
200200

201201
getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
202202

203-
getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
204-
typeInSet(0, allPtrsScalarsAndVectors),
205-
typeInSet(1, allPtrsScalarsAndVectors),
206-
LegalityPredicate(([=](const LegalityQuery &Query) {
207-
return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
208-
}))));
203+
getActionDefinitionsBuilder(G_BITCAST).legalIf(
204+
all(typeInSet(0, allPtrsScalarsAndVectors),
205+
typeInSet(1, allPtrsScalarsAndVectors)));
209206

210207
getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();
211208

llvm/test/CodeGen/SPIRV/bitcast.ll

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-SPIRV-DAG: %[[#TyInt32:]] = OpTypeInt 32 0
5+
; CHECK-SPIRV-DAG: %[[#TyInt16:]] = OpTypeInt 16 0
6+
; CHECK-SPIRV-DAG: %[[#TyHalf:]] = OpTypeFloat 16
7+
; CHECK-SPIRV-DAG: %[[#Arg32:]] = OpFunctionParameter %[[#TyInt32]]
8+
; CHECK-SPIRV-DAG: %[[#Arg16:]] = OpUConvert %[[#TyInt16]] %[[#Arg32]]
9+
; CHECK-SPIRV-DAG: %[[#ValHalf:]] = OpBitcast %[[#TyHalf]] %8
10+
; CHECK-SPIRV-DAG: %[[#ValHalf2:]] = OpFMul %[[#TyHalf]] %[[#ValHalf]] %[[#ValHalf]]
11+
; CHECK-SPIRV-DAG: %[[#Res16:]] = OpBitcast %[[#TyInt16]] %[[#ValHalf2]]
12+
; CHECK-SPIRV-DAG: OpReturnValue %[[#Res16]]
13+
14+
define i16 @foo(i32 %arg) {
15+
entry:
16+
%op16 = trunc i32 %arg to i16
17+
%val = bitcast i16 %op16 to half
18+
%val2 = fmul half %val, %val
19+
%res = bitcast half %val2 to i16
20+
ret i16 %res
21+
}

0 commit comments

Comments
 (0)