Skip to content

Commit 8584b21

Browse files
authored
Lower allreduce (#144716)
Adding lowering mesh.allreduce to mpi.allreduce. Minor restructuring to increase code reuse.
1 parent 7e77aae commit 8584b21

File tree

14 files changed

+482
-241
lines changed

14 files changed

+482
-241
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,8 @@ def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
905905
shard/partition sizes depend on the rank.
906906
}];
907907
let dependentDialects = [
908+
"affine::AffineDialect",
909+
"arith::ArithDialect",
908910
"memref::MemRefDialect",
909911
"mpi::MPIDialect",
910912
"scf::SCFDialect",

mlir/include/mlir/Dialect/MPI/IR/MPI.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/IR/Dialect.h"
1313
#include "mlir/IR/OpDefinition.h"
1414
#include "mlir/IR/OpImplementation.h"
15+
#include "mlir/Interfaces/SideEffectInterfaces.h"
1516

1617
//===----------------------------------------------------------------------===//
1718
// MPIDialect

mlir/include/mlir/Dialect/MPI/IR/MPI.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
230230
def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
231231
def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;
232232

233-
def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
233+
def MPI_ReductionOpEnum : I32EnumAttr<"MPI_ReductionOpEnum", "MPI operation class", [
234234
MPI_OpNull,
235235
MPI_OpMax,
236236
MPI_OpMin,

mlir/include/mlir/Dialect/MPI/IR/MPIOps.td

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
include "mlir/Dialect/MPI/IR/MPI.td"
1313
include "mlir/Dialect/MPI/IR/MPITypes.td"
14+
include "mlir/Interfaces/SideEffectInterfaces.td"
1415

1516
class MPI_Op<string mnemonic, list<Trait> traits = []>
1617
: Op<MPI_Dialect, mnemonic, traits>;
@@ -41,7 +42,7 @@ def MPI_InitOp : MPI_Op<"init", []> {
4142
// CommWorldOp
4243
//===----------------------------------------------------------------------===//
4344

44-
def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
45+
def MPI_CommWorldOp : MPI_Op<"comm_world", [Pure]> {
4546
let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`";
4647
let description = [{
4748
This operation returns the predefined MPI_COMM_WORLD communicator.
@@ -56,7 +57,7 @@ def MPI_CommWorldOp : MPI_Op<"comm_world", []> {
5657
// CommRankOp
5758
//===----------------------------------------------------------------------===//
5859

59-
def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
60+
def MPI_CommRankOp : MPI_Op<"comm_rank", [Pure]> {
6061
let summary = "Get the current rank, equivalent to "
6162
"`MPI_Comm_rank(comm, &rank)`";
6263
let description = [{
@@ -72,13 +73,14 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
7273
);
7374

7475
let assemblyFormat = "`(` $comm `)` attr-dict `:` type(results)";
76+
let hasCanonicalizer = 1;
7577
}
7678

7779
//===----------------------------------------------------------------------===//
7880
// CommSizeOp
7981
//===----------------------------------------------------------------------===//
8082

81-
def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
83+
def MPI_CommSizeOp : MPI_Op<"comm_size", [Pure]> {
8284
let summary = "Get the size of the group associated to the communicator, "
8385
"equivalent to `MPI_Comm_size(comm, &size)`";
8486
let description = [{
@@ -100,7 +102,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
100102
// CommSplitOp
101103
//===----------------------------------------------------------------------===//
102104

103-
def MPI_CommSplitOp : MPI_Op<"comm_split", []> {
105+
def MPI_CommSplitOp : MPI_Op<"comm_split", [Pure]> {
104106
let summary = "Partition the group associated with the given communicator into "
105107
"disjoint subgroups";
106108
let description = [{
@@ -281,7 +283,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
281283
let arguments = (
282284
ins AnyMemRef : $sendbuf,
283285
AnyMemRef : $recvbuf,
284-
MPI_OpClassEnum : $op,
286+
MPI_ReductionOpEnum : $op,
285287
MPI_Comm : $comm
286288
);
287289

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,11 @@ void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
212212
OpOperand &operand,
213213
OpBuilder &builder);
214214

215+
/// Converts a vector of OpFoldResults (ints) into vector of Values of the
216+
/// provided type.
217+
SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
218+
llvm::ArrayRef<int64_t> statics,
219+
ValueRange dynamics, Type type = Type());
215220
} // namespace mesh
216221
} // namespace mlir
217222

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -584,11 +584,11 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
584584
```
585585
}];
586586
let arguments = !con(commonArgs, (ins
587-
AnyRankedTensor:$input,
587+
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input,
588588
DefaultValuedAttr<Mesh_ReductionKindAttr, "::mlir::mesh::ReductionKind::Sum">:$reduction
589589
));
590590
let results = (outs
591-
AnyRankedTensor:$result
591+
AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result
592592
);
593593
let assemblyFormat = [{
594594
$input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?

mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
6262
auto isEndomorphismOp = [reduction](Operation *op,
6363
std::optional<Operation *> referenceOp) {
6464
auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
65-
if (!allReduceOp ||
66-
allReduceOp.getInput().getType().getElementType() !=
67-
allReduceOp.getResult().getType().getElementType() ||
65+
auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
66+
auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
67+
if (!allReduceOp || inType.getElementType() != outType.getElementType() ||
6868
allReduceOp.getReduction() != reduction) {
6969
return false;
7070
}
@@ -83,9 +83,9 @@ void populateAllReduceEndomorphismSimplificationPatterns(
8383
}
8484

8585
auto refAllReduceOp = llvm::dyn_cast<AllReduceOp>(referenceOp.value());
86+
auto refType = cast<ShapedType>(refAllReduceOp.getResult().getType());
8687
return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
87-
allReduceOp.getInput().getType().getElementType() ==
88-
refAllReduceOp.getInput().getType().getElementType();
88+
inType.getElementType() == refType.getElementType();
8989
};
9090
auto isAlgebraicOp = [](Operation *op) {
9191
return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));

mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
4242
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
4343
ArrayRef<MeshAxis> meshAxes,
4444
ImplicitLocOpBuilder &builder);
45+
// Get process linear index from a multi-index along the given mesh axes .
46+
TypedValue<IndexType>
47+
createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex,
48+
ArrayRef<MeshAxis> meshAxes,
49+
ImplicitLocOpBuilder &builder);
4550

4651
} // namespace mesh
4752
} // namespace mlir

mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class MPIImplTraits {
116116
/// enum value.
117117
virtual Value getMPIOp(const Location loc,
118118
ConversionPatternRewriter &rewriter,
119-
mpi::MPI_OpClassEnum opAttr) = 0;
119+
mpi::MPI_ReductionOpEnum opAttr) = 0;
120120
};
121121

122122
//===----------------------------------------------------------------------===//
@@ -199,49 +199,49 @@ class MPICHImplTraits : public MPIImplTraits {
199199
}
200200

201201
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
202-
mpi::MPI_OpClassEnum opAttr) override {
202+
mpi::MPI_ReductionOpEnum opAttr) override {
203203
int32_t op = MPI_NO_OP;
204204
switch (opAttr) {
205-
case mpi::MPI_OpClassEnum::MPI_OP_NULL:
205+
case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
206206
op = MPI_NO_OP;
207207
break;
208-
case mpi::MPI_OpClassEnum::MPI_MAX:
208+
case mpi::MPI_ReductionOpEnum::MPI_MAX:
209209
op = MPI_MAX;
210210
break;
211-
case mpi::MPI_OpClassEnum::MPI_MIN:
211+
case mpi::MPI_ReductionOpEnum::MPI_MIN:
212212
op = MPI_MIN;
213213
break;
214-
case mpi::MPI_OpClassEnum::MPI_SUM:
214+
case mpi::MPI_ReductionOpEnum::MPI_SUM:
215215
op = MPI_SUM;
216216
break;
217-
case mpi::MPI_OpClassEnum::MPI_PROD:
217+
case mpi::MPI_ReductionOpEnum::MPI_PROD:
218218
op = MPI_PROD;
219219
break;
220-
case mpi::MPI_OpClassEnum::MPI_LAND:
220+
case mpi::MPI_ReductionOpEnum::MPI_LAND:
221221
op = MPI_LAND;
222222
break;
223-
case mpi::MPI_OpClassEnum::MPI_BAND:
223+
case mpi::MPI_ReductionOpEnum::MPI_BAND:
224224
op = MPI_BAND;
225225
break;
226-
case mpi::MPI_OpClassEnum::MPI_LOR:
226+
case mpi::MPI_ReductionOpEnum::MPI_LOR:
227227
op = MPI_LOR;
228228
break;
229-
case mpi::MPI_OpClassEnum::MPI_BOR:
229+
case mpi::MPI_ReductionOpEnum::MPI_BOR:
230230
op = MPI_BOR;
231231
break;
232-
case mpi::MPI_OpClassEnum::MPI_LXOR:
232+
case mpi::MPI_ReductionOpEnum::MPI_LXOR:
233233
op = MPI_LXOR;
234234
break;
235-
case mpi::MPI_OpClassEnum::MPI_BXOR:
235+
case mpi::MPI_ReductionOpEnum::MPI_BXOR:
236236
op = MPI_BXOR;
237237
break;
238-
case mpi::MPI_OpClassEnum::MPI_MINLOC:
238+
case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
239239
op = MPI_MINLOC;
240240
break;
241-
case mpi::MPI_OpClassEnum::MPI_MAXLOC:
241+
case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
242242
op = MPI_MAXLOC;
243243
break;
244-
case mpi::MPI_OpClassEnum::MPI_REPLACE:
244+
case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
245245
op = MPI_REPLACE;
246246
break;
247247
}
@@ -336,49 +336,49 @@ class OMPIImplTraits : public MPIImplTraits {
336336
}
337337

338338
Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
339-
mpi::MPI_OpClassEnum opAttr) override {
339+
mpi::MPI_ReductionOpEnum opAttr) override {
340340
StringRef op;
341341
switch (opAttr) {
342-
case mpi::MPI_OpClassEnum::MPI_OP_NULL:
342+
case mpi::MPI_ReductionOpEnum::MPI_OP_NULL:
343343
op = "ompi_mpi_no_op";
344344
break;
345-
case mpi::MPI_OpClassEnum::MPI_MAX:
345+
case mpi::MPI_ReductionOpEnum::MPI_MAX:
346346
op = "ompi_mpi_max";
347347
break;
348-
case mpi::MPI_OpClassEnum::MPI_MIN:
348+
case mpi::MPI_ReductionOpEnum::MPI_MIN:
349349
op = "ompi_mpi_min";
350350
break;
351-
case mpi::MPI_OpClassEnum::MPI_SUM:
351+
case mpi::MPI_ReductionOpEnum::MPI_SUM:
352352
op = "ompi_mpi_sum";
353353
break;
354-
case mpi::MPI_OpClassEnum::MPI_PROD:
354+
case mpi::MPI_ReductionOpEnum::MPI_PROD:
355355
op = "ompi_mpi_prod";
356356
break;
357-
case mpi::MPI_OpClassEnum::MPI_LAND:
357+
case mpi::MPI_ReductionOpEnum::MPI_LAND:
358358
op = "ompi_mpi_land";
359359
break;
360-
case mpi::MPI_OpClassEnum::MPI_BAND:
360+
case mpi::MPI_ReductionOpEnum::MPI_BAND:
361361
op = "ompi_mpi_band";
362362
break;
363-
case mpi::MPI_OpClassEnum::MPI_LOR:
363+
case mpi::MPI_ReductionOpEnum::MPI_LOR:
364364
op = "ompi_mpi_lor";
365365
break;
366-
case mpi::MPI_OpClassEnum::MPI_BOR:
366+
case mpi::MPI_ReductionOpEnum::MPI_BOR:
367367
op = "ompi_mpi_bor";
368368
break;
369-
case mpi::MPI_OpClassEnum::MPI_LXOR:
369+
case mpi::MPI_ReductionOpEnum::MPI_LXOR:
370370
op = "ompi_mpi_lxor";
371371
break;
372-
case mpi::MPI_OpClassEnum::MPI_BXOR:
372+
case mpi::MPI_ReductionOpEnum::MPI_BXOR:
373373
op = "ompi_mpi_bxor";
374374
break;
375-
case mpi::MPI_OpClassEnum::MPI_MINLOC:
375+
case mpi::MPI_ReductionOpEnum::MPI_MINLOC:
376376
op = "ompi_mpi_minloc";
377377
break;
378-
case mpi::MPI_OpClassEnum::MPI_MAXLOC:
378+
case mpi::MPI_ReductionOpEnum::MPI_MAXLOC:
379379
op = "ompi_mpi_maxloc";
380380
break;
381-
case mpi::MPI_OpClassEnum::MPI_REPLACE:
381+
case mpi::MPI_ReductionOpEnum::MPI_REPLACE:
382382
op = "ompi_mpi_replace";
383383
break;
384384
}

0 commit comments

Comments
 (0)