Skip to content

Commit 36c00ce

Browse files
jcranmer-intelPavel V Chupin
authored andcommitted
Replace uses of Type::getPointerElementType in SPIRVReader.
Original commit: KhronosGroup/SPIRV-LLVM-Translator@07a7cca
1 parent ae2218e commit 36c00ce

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,7 +1283,8 @@ void SPIRVToLLVM::addMemAliasMetadata(Instruction *I, SPIRVId AliasListId,
12831283
I->setMetadata(AliasMDKind, MDAliasListMap[AliasListId]);
12841284
}
12851285

1286-
void transFunctionPointerCallArgumentAttributes(SPIRVValue *BV, CallInst *CI) {
1286+
void SPIRVToLLVM::transFunctionPointerCallArgumentAttributes(
1287+
SPIRVValue *BV, CallInst *CI, SPIRVTypeFunction *CalledFnTy) {
12871288
std::vector<SPIRVDecorate const *> ArgumentAttributes =
12881289
BV->getDecorations(internal::DecorationArgumentAttributeINTEL);
12891290

@@ -1296,8 +1297,8 @@ void transFunctionPointerCallArgumentAttributes(SPIRVValue *BV, CallInst *CI) {
12961297
auto LlvmAttr =
12971298
Attribute::isTypeAttrKind(LlvmAttrKind)
12981299
? Attribute::get(CI->getContext(), LlvmAttrKind,
1299-
cast<PointerType>(CI->getOperand(ArgNo)->getType())
1300-
->getPointerElementType())
1300+
transType(CalledFnTy->getParameterType(ArgNo)
1301+
->getPointerElementType()))
13011302
: Attribute::get(CI->getContext(), LlvmAttrKind);
13021303
CI->addParamAttr(ArgNo, LlvmAttr);
13031304
}
@@ -1733,7 +1734,7 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
17331734
// A ptr.annotation may have been generated for the source variable.
17341735
replaceOperandWithAnnotationIntrinsicCallResult(V);
17351736

1736-
Type *Ty = V->getType()->getPointerElementType();
1737+
Type *Ty = transType(BL->getType());
17371738
LoadInst *LI = nullptr;
17381739
uint64_t AlignValue = BL->SPIRVMemoryAccess::getAlignment();
17391740
if (0 == AlignValue) {
@@ -2083,7 +2084,7 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
20832084
case OpInBoundsPtrAccessChain: {
20842085
auto AC = static_cast<SPIRVAccessChainBase *>(BV);
20852086
auto Base = transValue(AC->getBase(), F, BB);
2086-
Type *BaseTy = cast<PointerType>(Base->getType())->getPointerElementType();
2087+
Type *BaseTy = transType(AC->getBase()->getType()->getPointerElementType());
20872088
auto Index = transValue(AC->getIndices(), F, BB);
20882089
if (!AC->hasPtrIndex())
20892090
Index.insert(Index.begin(), getInt32(M, 0));
@@ -2237,11 +2238,13 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
22372238
case OpFunctionPointerCallINTEL: {
22382239
SPIRVFunctionPointerCallINTEL *BC =
22392240
static_cast<SPIRVFunctionPointerCallINTEL *>(BV);
2240-
auto V = transValue(BC->getCalledValue(), F, BB);
2241-
auto Call = CallInst::Create(
2242-
cast<FunctionType>(V->getType()->getPointerElementType()), V,
2243-
transValue(BC->getArgumentValues(), F, BB), BC->getName(), BB);
2244-
transFunctionPointerCallArgumentAttributes(BV, Call);
2241+
auto *V = transValue(BC->getCalledValue(), F, BB);
2242+
auto *SpirvFnTy = BC->getCalledValue()->getType()->getPointerElementType();
2243+
auto *FnTy = cast<FunctionType>(transType(SpirvFnTy));
2244+
auto *Call = CallInst::Create(
2245+
FnTy, V, transValue(BC->getArgumentValues(), F, BB), BC->getName(), BB);
2246+
transFunctionPointerCallArgumentAttributes(
2247+
BV, Call, static_cast<SPIRVTypeFunction *>(SpirvFnTy));
22452248
// Assuming we are calling a regular device function
22462249
Call->setCallingConv(CallingConv::SPIR_FUNC);
22472250
// Don't set attributes, because at translation time we don't know which
@@ -2404,7 +2407,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
24042407
else {
24052408
IID = Intrinsic::ptr_annotation;
24062409
auto *PtrTy = dyn_cast<PointerType>(Ty);
2407-
if (PtrTy && isa<IntegerType>(PtrTy->getPointerElementType()))
2410+
if (PtrTy &&
2411+
(PtrTy->isOpaque() ||
2412+
isa<IntegerType>(PtrTy->getNonOpaquePointerElementType())))
24082413
RetTy = PtrTy;
24092414
// Whether a struct or a pointer to some other type,
24102415
// bitcast to i8*
@@ -2812,10 +2817,8 @@ Function *SPIRVToLLVM::transFunction(SPIRVFunction *BF) {
28122817
Type *AttrTy = nullptr;
28132818
switch (LLVMKind) {
28142819
case Attribute::AttrKind::ByVal:
2815-
AttrTy = cast<PointerType>(I->getType())->getPointerElementType();
2816-
break;
28172820
case Attribute::AttrKind::StructRet:
2818-
AttrTy = cast<PointerType>(I->getType())->getPointerElementType();
2821+
AttrTy = transType(BA->getType()->getPointerElementType());
28192822
break;
28202823
default:
28212824
break; // do nothing

llvm-spirv/lib/SPIRV/SPIRVReader.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,9 @@ class SPIRVToLLVM {
243243
void transMemAliasingINTELDecorations(SPIRVValue *BV, Value *V);
244244
void transVarDecorationsToMetadata(SPIRVValue *BV, Value *V);
245245
void transFunctionDecorationsToMetadata(SPIRVFunction *BF, Function *F);
246+
void
247+
transFunctionPointerCallArgumentAttributes(SPIRVValue *BV, CallInst *CI,
248+
SPIRVTypeFunction *CalledFnTy);
246249
}; // class SPIRVToLLVM
247250

248251
} // namespace SPIRV

0 commit comments

Comments
 (0)