@@ -1057,17 +1057,21 @@ static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1057
1057
return call;
1058
1058
}
1059
1059
1060
- class ControlBarrierPattern
1061
- : public SPIRVToLLVMConversion<spirv::ControlBarrierOp > {
1060
+ template < typename BarrierOpTy>
1061
+ class ControlBarrierPattern : public SPIRVToLLVMConversion <BarrierOpTy > {
1062
1062
public:
1063
- using SPIRVToLLVMConversion<spirv::ControlBarrierOp>::SPIRVToLLVMConversion;
1063
+ using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1064
+
1065
+ using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion;
1066
+
1067
+ static constexpr StringRef getFuncName ();
1064
1068
1065
1069
LogicalResult
1066
- matchAndRewrite (spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
1070
+ matchAndRewrite (BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1067
1071
ConversionPatternRewriter &rewriter) const override {
1068
- constexpr StringLiteral funcName = " _Z22__spirv_ControlBarrieriii " ;
1072
+ constexpr StringRef funcName = getFuncName () ;
1069
1073
Operation *symbolTable =
1070
- controlBarrierOp->getParentWithTrait <OpTrait::SymbolTable>();
1074
+ controlBarrierOp->template getParentWithTrait <OpTrait::SymbolTable>();
1071
1075
1072
1076
Type i32 = rewriter.getI32Type ();
1073
1077
@@ -1266,6 +1270,24 @@ class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
1266
1270
}
1267
1271
};
1268
1272
1273
+ template <>
1274
+ constexpr StringRef
1275
+ ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1276
+ return " _Z22__spirv_ControlBarrieriii" ;
1277
+ }
1278
+
1279
+ template <>
1280
+ constexpr StringRef
1281
+ ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1282
+ return " _Z33__spirv_ControlBarrierArriveINTELiii" ;
1283
+ }
1284
+
1285
+ template <>
1286
+ constexpr StringRef
1287
+ ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1288
+ return " _Z31__spirv_ControlBarrierWaitINTELiii" ;
1289
+ }
1290
+
1269
1291
// / Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1270
1292
// / should be reachable for conversion to succeed. The structure of the loop in
1271
1293
// / LLVM dialect will be the following:
@@ -1899,7 +1921,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
1899
1921
ReturnPattern, ReturnValuePattern,
1900
1922
1901
1923
// Barrier ops
1902
- ControlBarrierPattern,
1924
+ ControlBarrierPattern<spirv::ControlBarrierOp>,
1925
+ ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1926
+ ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1903
1927
1904
1928
// Group reduction operations
1905
1929
GroupReducePattern<spirv::GroupIAddOp>,
0 commit comments