@@ -1283,7 +1283,8 @@ void SPIRVToLLVM::addMemAliasMetadata(Instruction *I, SPIRVId AliasListId,
1283
1283
I->setMetadata (AliasMDKind, MDAliasListMap[AliasListId]);
1284
1284
}
1285
1285
1286
- void transFunctionPointerCallArgumentAttributes (SPIRVValue *BV, CallInst *CI) {
1286
+ void SPIRVToLLVM::transFunctionPointerCallArgumentAttributes (
1287
+ SPIRVValue *BV, CallInst *CI, SPIRVTypeFunction *CalledFnTy) {
1287
1288
std::vector<SPIRVDecorate const *> ArgumentAttributes =
1288
1289
BV->getDecorations (internal::DecorationArgumentAttributeINTEL);
1289
1290
@@ -1296,8 +1297,8 @@ void transFunctionPointerCallArgumentAttributes(SPIRVValue *BV, CallInst *CI) {
1296
1297
auto LlvmAttr =
1297
1298
Attribute::isTypeAttrKind (LlvmAttrKind)
1298
1299
? Attribute::get (CI->getContext (), LlvmAttrKind,
1299
- cast<PointerType>(CI-> getOperand (ArgNo)-> getType () )
1300
- ->getPointerElementType ())
1300
+ transType (CalledFnTy-> getParameterType (ArgNo)
1301
+ ->getPointerElementType () ))
1301
1302
: Attribute::get (CI->getContext (), LlvmAttrKind);
1302
1303
CI->addParamAttr (ArgNo, LlvmAttr);
1303
1304
}
@@ -1733,7 +1734,7 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
1733
1734
// A ptr.annotation may have been generated for the source variable.
1734
1735
replaceOperandWithAnnotationIntrinsicCallResult (V);
1735
1736
1736
- Type *Ty = V ->getType ()-> getPointerElementType ( );
1737
+ Type *Ty = transType (BL ->getType ());
1737
1738
LoadInst *LI = nullptr ;
1738
1739
uint64_t AlignValue = BL->SPIRVMemoryAccess ::getAlignment ();
1739
1740
if (0 == AlignValue) {
@@ -2083,7 +2084,7 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
2083
2084
case OpInBoundsPtrAccessChain: {
2084
2085
auto AC = static_cast <SPIRVAccessChainBase *>(BV);
2085
2086
auto Base = transValue (AC->getBase (), F, BB);
2086
- Type *BaseTy = cast<PointerType>(Base ->getType ()) ->getPointerElementType ();
2087
+ Type *BaseTy = transType (AC-> getBase () ->getType ()->getPointerElementType () );
2087
2088
auto Index = transValue (AC->getIndices (), F, BB);
2088
2089
if (!AC->hasPtrIndex ())
2089
2090
Index.insert (Index.begin (), getInt32 (M, 0 ));
@@ -2237,11 +2238,13 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
2237
2238
case OpFunctionPointerCallINTEL: {
2238
2239
SPIRVFunctionPointerCallINTEL *BC =
2239
2240
static_cast <SPIRVFunctionPointerCallINTEL *>(BV);
2240
- auto V = transValue (BC->getCalledValue (), F, BB);
2241
- auto Call = CallInst::Create (
2242
- cast<FunctionType>(V->getType ()->getPointerElementType ()), V,
2243
- transValue (BC->getArgumentValues (), F, BB), BC->getName (), BB);
2244
- transFunctionPointerCallArgumentAttributes (BV, Call);
2241
+ auto *V = transValue (BC->getCalledValue (), F, BB);
2242
+ auto *SpirvFnTy = BC->getCalledValue ()->getType ()->getPointerElementType ();
2243
+ auto *FnTy = cast<FunctionType>(transType (SpirvFnTy));
2244
+ auto *Call = CallInst::Create (
2245
+ FnTy, V, transValue (BC->getArgumentValues (), F, BB), BC->getName (), BB);
2246
+ transFunctionPointerCallArgumentAttributes (
2247
+ BV, Call, static_cast <SPIRVTypeFunction *>(SpirvFnTy));
2245
2248
// Assuming we are calling a regular device function
2246
2249
Call->setCallingConv (CallingConv::SPIR_FUNC);
2247
2250
// Don't set attributes, because at translation time we don't know which
@@ -2404,7 +2407,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
2404
2407
else {
2405
2408
IID = Intrinsic::ptr_annotation;
2406
2409
auto *PtrTy = dyn_cast<PointerType>(Ty);
2407
- if (PtrTy && isa<IntegerType>(PtrTy->getPointerElementType ()))
2410
+ if (PtrTy &&
2411
+ (PtrTy->isOpaque () ||
2412
+ isa<IntegerType>(PtrTy->getNonOpaquePointerElementType ())))
2408
2413
RetTy = PtrTy;
2409
2414
// Whether a struct or a pointer to some other type,
2410
2415
// bitcast to i8*
@@ -2812,10 +2817,8 @@ Function *SPIRVToLLVM::transFunction(SPIRVFunction *BF) {
2812
2817
Type *AttrTy = nullptr ;
2813
2818
switch (LLVMKind) {
2814
2819
case Attribute::AttrKind::ByVal:
2815
- AttrTy = cast<PointerType>(I->getType ())->getPointerElementType ();
2816
- break ;
2817
2820
case Attribute::AttrKind::StructRet:
2818
- AttrTy = cast<PointerType>(I ->getType ()) ->getPointerElementType ();
2821
+ AttrTy = transType (BA ->getType ()->getPointerElementType () );
2819
2822
break ;
2820
2823
default :
2821
2824
break ; // do nothing
0 commit comments