@@ -1134,26 +1134,28 @@ class LowerMatrixIntrinsics {
       if (FusedInsts.count(Inst))
         continue;
 
-      IRBuilder<> Builder(Inst);
-
       const ShapeInfo &SI = ShapeMap.at(Inst);
 
       Value *Op1;
       Value *Op2;
+      MatrixTy Result;
       if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
-        VisitBinaryOperator(BinOp, SI);
+        Result = VisitBinaryOperator(BinOp, SI);
       else if (auto *Cast = dyn_cast<CastInst>(Inst))
-        VisitCastInstruction(Cast, SI);
+        Result = VisitCastInstruction(Cast, SI);
       else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
-        VisitUnaryOperator(UnOp, SI);
-      else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Inst))
-        VisitIntrinsicInst(Intr, SI);
+        Result = VisitUnaryOperator(UnOp, SI);
+      else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
+        Result = VisitIntrinsicInst(Intr, SI);
       else if (match(Inst, m_Load(m_Value(Op1))))
-        VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
+        Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1);
       else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
-        VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
+        Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2);
       else
         continue;
+
+      IRBuilder<> Builder(Inst);
+      finalizeLowering(Inst, Result, Builder);
       Changed = true;
     }
 
@@ -1193,25 +1195,24 @@ class LowerMatrixIntrinsics {
   }
 
   /// Replace intrinsic calls.
-  void VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &Shape) {
-    switch (Inst->getIntrinsicID()) {
+  MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI) {
+    assert(Inst->getCalledFunction() &&
+           Inst->getCalledFunction()->isIntrinsic());
+
+    switch (Inst->getCalledFunction()->getIntrinsicID()) {
     case Intrinsic::matrix_multiply:
-      LowerMultiply(Inst);
-      return;
+      return LowerMultiply(Inst);
     case Intrinsic::matrix_transpose:
-      LowerTranspose(Inst);
-      return;
+      return LowerTranspose(Inst);
     case Intrinsic::matrix_column_major_load:
-      LowerColumnMajorLoad(Inst);
-      return;
+      return LowerColumnMajorLoad(Inst);
     case Intrinsic::matrix_column_major_store:
-      LowerColumnMajorStore(Inst);
-      return;
+      return LowerColumnMajorStore(Inst);
     case Intrinsic::abs:
     case Intrinsic::fabs: {
       IRBuilder<> Builder(Inst);
       MatrixTy Result;
-      MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder);
+      MatrixTy M = getMatrix(Inst->getOperand(0), SI, Builder);
       Builder.setFastMathFlags(getFastMathFlags(Inst));
 
       for (auto &Vector : M.vectors()) {
@@ -1229,16 +1230,14 @@ class LowerMatrixIntrinsics {
         }
       }
 
-      finalizeLowering(Inst,
-                       Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
-                                               Result.getNumVectors()),
-                       Builder);
-      return;
+      return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                     Result.getNumVectors());
     }
     default:
-      llvm_unreachable(
-          "only intrinsics supporting shape info should be seen here");
+      break;
     }
+    llvm_unreachable(
+        "only intrinsics supporting shape info should be seen here");
   }
 
   /// Compute the alignment for a column/row \p Idx with \p Stride between them.
@@ -1304,26 +1303,24 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower a load instruction with shape information.
-  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
-                 bool IsVolatile, ShapeInfo Shape) {
+  MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
+                     Value *Stride, bool IsVolatile, ShapeInfo Shape) {
     IRBuilder<> Builder(Inst);
-    finalizeLowering(Inst,
-                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
-                                Shape, Builder),
-                     Builder);
+    return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape,
+                      Builder);
   }
 
   /// Lowers llvm.matrix.column.major.load.
   ///
   /// The intrinsic loads a matrix from memory using a stride between columns.
-  void LowerColumnMajorLoad(CallInst *Inst) {
+  MatrixTy LowerColumnMajorLoad(CallInst *Inst) {
     assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
            "Intrinsic only supports column-major layout!");
     Value *Ptr = Inst->getArgOperand(0);
     Value *Stride = Inst->getArgOperand(1);
-    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
-              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
-              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
+    return LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
+                     cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
+                     {Inst->getArgOperand(3), Inst->getArgOperand(4)});
   }
 
   /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
@@ -1366,28 +1363,27 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower a store instruction with shape information.
-  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
-                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
+  MatrixTy LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr,
+                      MaybeAlign A, Value *Stride, bool IsVolatile,
+                      ShapeInfo Shape) {
     IRBuilder<> Builder(Inst);
     auto StoreVal = getMatrix(Matrix, Shape, Builder);
-    finalizeLowering(Inst,
-                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
-                                 IsVolatile, Builder),
-                     Builder);
+    return storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, IsVolatile,
+                       Builder);
   }
 
   /// Lowers llvm.matrix.column.major.store.
   ///
   /// The intrinsic store a matrix back memory using a stride between columns.
-  void LowerColumnMajorStore(CallInst *Inst) {
+  MatrixTy LowerColumnMajorStore(CallInst *Inst) {
     assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
            "Intrinsic only supports column-major layout!");
     Value *Matrix = Inst->getArgOperand(0);
     Value *Ptr = Inst->getArgOperand(1);
     Value *Stride = Inst->getArgOperand(2);
-    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
-               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
-               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
+    return LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
+                      cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
+                      {Inst->getArgOperand(4), Inst->getArgOperand(5)});
   }
 
   // Set elements I..I+NumElts-1 to Block
@@ -2162,7 +2158,7 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lowers llvm.matrix.multiply.
-  void LowerMultiply(CallInst *MatMul) {
+  MatrixTy LowerMultiply(CallInst *MatMul) {
     IRBuilder<> Builder(MatMul);
     auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
     ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
@@ -2184,11 +2180,11 @@ class LowerMatrixIntrinsics {
 
     emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
                        getFastMathFlags(MatMul));
-    finalizeLowering(MatMul, Result, Builder);
+    return Result;
   }
 
   /// Lowers llvm.matrix.transpose.
-  void LowerTranspose(CallInst *Inst) {
+  MatrixTy LowerTranspose(CallInst *Inst) {
     MatrixTy Result;
     IRBuilder<> Builder(Inst);
     Value *InputVal = Inst->getArgOperand(0);
@@ -2218,28 +2214,26 @@ class LowerMatrixIntrinsics {
     // TODO: Improve estimate of operations needed for transposes. Currently we
     // just count the insertelement/extractelement instructions, but do not
     // account for later simplifications/combines.
-    finalizeLowering(
-        Inst,
-        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
-            .addNumExposedTransposes(1),
-        Builder);
+    return Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
+        .addNumExposedTransposes(1);
   }
 
   /// Lower load instructions.
-  void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
-                 IRBuilder<> &Builder) {
-    LowerLoad(Inst, Ptr, Inst->getAlign(), Builder.getInt64(SI.getStride()),
-              Inst->isVolatile(), SI);
+  MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr) {
+    IRBuilder<> Builder(Inst);
+    return LowerLoad(Inst, Ptr, Inst->getAlign(),
+                     Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
   }
 
-  void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
-                  Value *Ptr, IRBuilder<> &Builder) {
-    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
-               Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
+  MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
+                      Value *Ptr) {
+    IRBuilder<> Builder(Inst);
+    return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
+                      Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
   }
 
   /// Lower binary operators.
-  void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
+  MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
     Value *Lhs = Inst->getOperand(0);
     Value *Rhs = Inst->getOperand(1);
 
@@ -2258,14 +2252,12 @@ class LowerMatrixIntrinsics {
       Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
                                            B.getVector(I)));
 
-    finalizeLowering(Inst,
-                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
-                                             Result.getNumVectors()),
-                     Builder);
+    return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                   Result.getNumVectors());
   }
 
   /// Lower unary operators.
-  void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
+  MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
     Value *Op = Inst->getOperand(0);
 
     IRBuilder<> Builder(Inst);
@@ -2288,14 +2280,12 @@ class LowerMatrixIntrinsics {
     for (unsigned I = 0; I < SI.getNumVectors(); ++I)
       Result.addVector(BuildVectorOp(M.getVector(I)));
 
-    finalizeLowering(Inst,
-                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
-                                             Result.getNumVectors()),
-                     Builder);
+    return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                   Result.getNumVectors());
   }
 
   /// Lower cast instructions.
-  void VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) {
+  MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) {
     Value *Op = Inst->getOperand(0);
 
     IRBuilder<> Builder(Inst);
@@ -2312,10 +2302,8 @@ class LowerMatrixIntrinsics {
     for (auto &Vector : M.vectors())
       Result.addVector(Builder.CreateCast(Inst->getOpcode(), Vector, NewVTy));
 
-    finalizeLowering(Inst,
-                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
-                                             Result.getNumVectors()),
-                     Builder);
+    return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                   Result.getNumVectors());
   }
 
   /// Helper to linearize a matrix expression tree into a string. Currently
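
Note: the pattern this diff applies (each lowering helper returns a MatrixTy instead of calling finalizeLowering itself, and the dispatch loop performs the single finalization step) can be sketched in isolation. The snippet below is only an illustration of that control flow under assumed names; Result, Inst, visitAdd, visitNeg, and finalize are hypothetical stand-ins, not the LLVM types or APIs.

#include <cstdio>

// Hypothetical stand-ins for illustration only; not the LLVM classes.
struct Result {
  int NumComputeOps = 0;
};

struct Inst {
  char Kind; // 'a' = add, 'n' = negate, anything else = unhandled
};

// Each visitor computes and returns its result; none of them finalizes.
static Result visitAdd(const Inst &) { return {4}; }
static Result visitNeg(const Inst &) { return {2}; }

// The single finalization point, analogous to finalizeLowering being
// called once from the dispatch loop instead of from every helper.
static void finalize(const Inst &I, const Result &R) {
  std::printf("finalized '%c' with %d compute ops\n", I.Kind, R.NumComputeOps);
}

int main() {
  Inst Insts[] = {{'a'}, {'n'}, {'x'}};
  for (const Inst &I : Insts) {
    Result R;
    if (I.Kind == 'a')
      R = visitAdd(I);
    else if (I.Kind == 'n')
      R = visitNeg(I);
    else
      continue;       // unhandled kinds skip finalization entirely
    finalize(I, R);   // exactly one finalization per handled instruction
  }
  return 0;
}

Centralizing the finalize call this way makes it easier to guarantee every handled instruction is finalized exactly once, and lets the individual helpers be reasoned about (or reused) without finalization side effects.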