Skip to content

Commit d1339a8

Browse files
jroelofstomtor
authored andcommitted
[Matrix] Hoist finalizeLowering into caller. NFC (llvm#143038)
1 parent 8b8d986 commit d1339a8

File tree

1 file changed

+66
-78
lines changed

1 file changed

+66
-78
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 66 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,26 +1134,28 @@ class LowerMatrixIntrinsics {
11341134
if (FusedInsts.count(Inst))
11351135
continue;
11361136

1137-
IRBuilder<> Builder(Inst);
1138-
11391137
const ShapeInfo &SI = ShapeMap.at(Inst);
11401138

11411139
Value *Op1;
11421140
Value *Op2;
1141+
MatrixTy Result;
11431142
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
1144-
VisitBinaryOperator(BinOp, SI);
1143+
Result = VisitBinaryOperator(BinOp, SI);
11451144
else if (auto *Cast = dyn_cast<CastInst>(Inst))
1146-
VisitCastInstruction(Cast, SI);
1145+
Result = VisitCastInstruction(Cast, SI);
11471146
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
1148-
VisitUnaryOperator(UnOp, SI);
1149-
else if (IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Inst))
1150-
VisitIntrinsicInst(Intr, SI);
1147+
Result = VisitUnaryOperator(UnOp, SI);
1148+
else if (auto *Intr = dyn_cast<IntrinsicInst>(Inst))
1149+
Result = VisitIntrinsicInst(Intr, SI);
11511150
else if (match(Inst, m_Load(m_Value(Op1))))
1152-
VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
1151+
Result = VisitLoad(cast<LoadInst>(Inst), SI, Op1);
11531152
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
1154-
VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
1153+
Result = VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2);
11551154
else
11561155
continue;
1156+
1157+
IRBuilder<> Builder(Inst);
1158+
finalizeLowering(Inst, Result, Builder);
11571159
Changed = true;
11581160
}
11591161

@@ -1193,25 +1195,24 @@ class LowerMatrixIntrinsics {
11931195
}
11941196

11951197
/// Replace intrinsic calls.
1196-
void VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &Shape) {
1197-
switch (Inst->getIntrinsicID()) {
1198+
MatrixTy VisitIntrinsicInst(IntrinsicInst *Inst, const ShapeInfo &SI) {
1199+
assert(Inst->getCalledFunction() &&
1200+
Inst->getCalledFunction()->isIntrinsic());
1201+
1202+
switch (Inst->getCalledFunction()->getIntrinsicID()) {
11981203
case Intrinsic::matrix_multiply:
1199-
LowerMultiply(Inst);
1200-
return;
1204+
return LowerMultiply(Inst);
12011205
case Intrinsic::matrix_transpose:
1202-
LowerTranspose(Inst);
1203-
return;
1206+
return LowerTranspose(Inst);
12041207
case Intrinsic::matrix_column_major_load:
1205-
LowerColumnMajorLoad(Inst);
1206-
return;
1208+
return LowerColumnMajorLoad(Inst);
12071209
case Intrinsic::matrix_column_major_store:
1208-
LowerColumnMajorStore(Inst);
1209-
return;
1210+
return LowerColumnMajorStore(Inst);
12101211
case Intrinsic::abs:
12111212
case Intrinsic::fabs: {
12121213
IRBuilder<> Builder(Inst);
12131214
MatrixTy Result;
1214-
MatrixTy M = getMatrix(Inst->getOperand(0), Shape, Builder);
1215+
MatrixTy M = getMatrix(Inst->getOperand(0), SI, Builder);
12151216
Builder.setFastMathFlags(getFastMathFlags(Inst));
12161217

12171218
for (auto &Vector : M.vectors()) {
@@ -1229,16 +1230,14 @@ class LowerMatrixIntrinsics {
12291230
}
12301231
}
12311232

1232-
finalizeLowering(Inst,
1233-
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1234-
Result.getNumVectors()),
1235-
Builder);
1236-
return;
1233+
return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
1234+
Result.getNumVectors());
12371235
}
12381236
default:
1239-
llvm_unreachable(
1240-
"only intrinsics supporting shape info should be seen here");
1237+
break;
12411238
}
1239+
llvm_unreachable(
1240+
"only intrinsics supporting shape info should be seen here");
12421241
}
12431242

12441243
/// Compute the alignment for a column/row \p Idx with \p Stride between them.
@@ -1304,26 +1303,24 @@ class LowerMatrixIntrinsics {
13041303
}
13051304

13061305
/// Lower a load instruction with shape information.
1307-
void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
1308-
bool IsVolatile, ShapeInfo Shape) {
1306+
MatrixTy LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align,
1307+
Value *Stride, bool IsVolatile, ShapeInfo Shape) {
13091308
IRBuilder<> Builder(Inst);
1310-
finalizeLowering(Inst,
1311-
loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
1312-
Shape, Builder),
1313-
Builder);
1309+
return loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, Shape,
1310+
Builder);
13141311
}
13151312

13161313
/// Lowers llvm.matrix.column.major.load.
13171314
///
13181315
/// The intrinsic loads a matrix from memory using a stride between columns.
1319-
void LowerColumnMajorLoad(CallInst *Inst) {
1316+
MatrixTy LowerColumnMajorLoad(CallInst *Inst) {
13201317
assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
13211318
"Intrinsic only supports column-major layout!");
13221319
Value *Ptr = Inst->getArgOperand(0);
13231320
Value *Stride = Inst->getArgOperand(1);
1324-
LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1325-
cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1326-
{Inst->getArgOperand(3), Inst->getArgOperand(4)});
1321+
return LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
1322+
cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
1323+
{Inst->getArgOperand(3), Inst->getArgOperand(4)});
13271324
}
13281325

13291326
/// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
@@ -1366,28 +1363,27 @@ class LowerMatrixIntrinsics {
13661363
}
13671364

13681365
/// Lower a store instruction with shape information.
1369-
void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
1370-
Value *Stride, bool IsVolatile, ShapeInfo Shape) {
1366+
MatrixTy LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr,
1367+
MaybeAlign A, Value *Stride, bool IsVolatile,
1368+
ShapeInfo Shape) {
13711369
IRBuilder<> Builder(Inst);
13721370
auto StoreVal = getMatrix(Matrix, Shape, Builder);
1373-
finalizeLowering(Inst,
1374-
storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
1375-
IsVolatile, Builder),
1376-
Builder);
1371+
return storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, IsVolatile,
1372+
Builder);
13771373
}
13781374

13791375
/// Lowers llvm.matrix.column.major.store.
13801376
///
13811377
/// The intrinsic store a matrix back memory using a stride between columns.
1382-
void LowerColumnMajorStore(CallInst *Inst) {
1378+
MatrixTy LowerColumnMajorStore(CallInst *Inst) {
13831379
assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
13841380
"Intrinsic only supports column-major layout!");
13851381
Value *Matrix = Inst->getArgOperand(0);
13861382
Value *Ptr = Inst->getArgOperand(1);
13871383
Value *Stride = Inst->getArgOperand(2);
1388-
LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1389-
cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1390-
{Inst->getArgOperand(4), Inst->getArgOperand(5)});
1384+
return LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
1385+
cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
1386+
{Inst->getArgOperand(4), Inst->getArgOperand(5)});
13911387
}
13921388

13931389
// Set elements I..I+NumElts-1 to Block
@@ -2162,7 +2158,7 @@ class LowerMatrixIntrinsics {
21622158
}
21632159

21642160
/// Lowers llvm.matrix.multiply.
2165-
void LowerMultiply(CallInst *MatMul) {
2161+
MatrixTy LowerMultiply(CallInst *MatMul) {
21662162
IRBuilder<> Builder(MatMul);
21672163
auto *EltType = cast<FixedVectorType>(MatMul->getType())->getElementType();
21682164
ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
@@ -2184,11 +2180,11 @@ class LowerMatrixIntrinsics {
21842180

21852181
emitMatrixMultiply(Result, Lhs, Rhs, Builder, false, false,
21862182
getFastMathFlags(MatMul));
2187-
finalizeLowering(MatMul, Result, Builder);
2183+
return Result;
21882184
}
21892185

21902186
/// Lowers llvm.matrix.transpose.
2191-
void LowerTranspose(CallInst *Inst) {
2187+
MatrixTy LowerTranspose(CallInst *Inst) {
21922188
MatrixTy Result;
21932189
IRBuilder<> Builder(Inst);
21942190
Value *InputVal = Inst->getArgOperand(0);
@@ -2218,28 +2214,26 @@ class LowerMatrixIntrinsics {
22182214
// TODO: Improve estimate of operations needed for transposes. Currently we
22192215
// just count the insertelement/extractelement instructions, but do not
22202216
// account for later simplifications/combines.
2221-
finalizeLowering(
2222-
Inst,
2223-
Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
2224-
.addNumExposedTransposes(1),
2225-
Builder);
2217+
return Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
2218+
.addNumExposedTransposes(1);
22262219
}
22272220

22282221
/// Lower load instructions.
2229-
void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
2230-
IRBuilder<> &Builder) {
2231-
LowerLoad(Inst, Ptr, Inst->getAlign(), Builder.getInt64(SI.getStride()),
2232-
Inst->isVolatile(), SI);
2222+
MatrixTy VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr) {
2223+
IRBuilder<> Builder(Inst);
2224+
return LowerLoad(Inst, Ptr, Inst->getAlign(),
2225+
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
22332226
}
22342227

2235-
void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2236-
Value *Ptr, IRBuilder<> &Builder) {
2237-
LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2238-
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
2228+
MatrixTy VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
2229+
Value *Ptr) {
2230+
IRBuilder<> Builder(Inst);
2231+
return LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
2232+
Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
22392233
}
22402234

22412235
/// Lower binary operators.
2242-
void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
2236+
MatrixTy VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
22432237
Value *Lhs = Inst->getOperand(0);
22442238
Value *Rhs = Inst->getOperand(1);
22452239

@@ -2258,14 +2252,12 @@ class LowerMatrixIntrinsics {
22582252
Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
22592253
B.getVector(I)));
22602254

2261-
finalizeLowering(Inst,
2262-
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2263-
Result.getNumVectors()),
2264-
Builder);
2255+
return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2256+
Result.getNumVectors());
22652257
}
22662258

22672259
/// Lower unary operators.
2268-
void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
2260+
MatrixTy VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
22692261
Value *Op = Inst->getOperand(0);
22702262

22712263
IRBuilder<> Builder(Inst);
@@ -2288,14 +2280,12 @@ class LowerMatrixIntrinsics {
22882280
for (unsigned I = 0; I < SI.getNumVectors(); ++I)
22892281
Result.addVector(BuildVectorOp(M.getVector(I)));
22902282

2291-
finalizeLowering(Inst,
2292-
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2293-
Result.getNumVectors()),
2294-
Builder);
2283+
return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2284+
Result.getNumVectors());
22952285
}
22962286

22972287
/// Lower cast instructions.
2298-
void VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) {
2288+
MatrixTy VisitCastInstruction(CastInst *Inst, const ShapeInfo &Shape) {
22992289
Value *Op = Inst->getOperand(0);
23002290

23012291
IRBuilder<> Builder(Inst);
@@ -2312,10 +2302,8 @@ class LowerMatrixIntrinsics {
23122302
for (auto &Vector : M.vectors())
23132303
Result.addVector(Builder.CreateCast(Inst->getOpcode(), Vector, NewVTy));
23142304

2315-
finalizeLowering(Inst,
2316-
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2317-
Result.getNumVectors()),
2318-
Builder);
2305+
return Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
2306+
Result.getNumVectors());
23192307
}
23202308

23212309
/// Helper to linearize a matrix expression tree into a string. Currently

0 commit comments

Comments
 (0)