
Commit 5f7c8c1

[mlir][mesh] Add collective communication operations (#71960)
Add all-gather, all-reduce, all-to-all and reduce-scatter. These operations have device mesh semantics.
1 parent ac75171 commit 5f7c8c1

File tree

8 files changed: +1179 -3 lines changed

mlir/docs/Dialects/Mesh.md

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
# 'mesh' Dialect

The `mesh` dialect contains a set of attributes, operations, and interfaces
that are useful for representing sharding and communication on a device mesh
cluster.

[TOC]

## Collective Communication Operations

There are a number of operations in the Mesh dialect to facilitate
communication between devices in a mesh.
It is assumed that the user is familiar with collective operations.
[Wikipedia](https://en.wikipedia.org/wiki/Collective_operation) has a good
explanation.
The main addition is that the collectives in this dialect have mesh
semantics.

The operation attributes `mesh` and `mesh_axes` specify a list of device mesh
axes that partition the devices into disjoint groups.
The collective operation is performed between devices in the same group.
Devices that have the same coordinates outside of the axes in `mesh_axes` are
in the same group.
For example, if we have a device mesh of size `2x3x4x5` and the partition
mesh axes list is `[0, 1]`, then the devices are partitioned into the groups
`{ { (i, j, k, m) | 0<=i<2, 0<=j<3 } | 0<=k<4, 0<=m<5 }`.
Devices (1, 0, 2, 3) and (1, 1, 2, 3) will be in the same group.
Device (1, 0, 2, 4) will be in another group.

Some collective operations, like all-to-all and all-gather, care about the
order of devices.
The order of devices in a device group is induced by the order of axes in
`mesh_axes`.
The axes are ordered from outer to inner.
If we have an axis list `[3, 1]`, then device `(i, 1, k, 0)` will precede
both devices `(i, 0, k, 1)` and `(i, 2, k, 0)`.
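For example, with the `2x3x4x5` mesh above, an all-reduce over mesh axes
`[0, 1]` runs independently in each of the `4*5 = 20` groups of `2*3 = 6`
devices. A minimal sketch of how this looks in IR, using the op syntax
introduced by this commit (`@mesh4d`, the SSA names, and the tensor type are
illustrative):

```mlir
mesh.cluster @mesh4d(rank = 4, dim_sizes = [2, 3, 4, 5])
...
// Devices that agree on their coordinates along axes 2 and 3 form one
// group; the reduction is performed within each such group.
%sum = mesh.all_reduce %0 on @mesh4d mesh_axes = [0, 1]
  : tensor<8xf32> -> tensor<8xf32>
```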
## Operations

[include "Dialects/MeshOps.md"]

## Attributes

[include "Dialects/MeshAttributes.md"]

mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td

Lines changed: 5 additions & 3 deletions
@@ -23,9 +23,7 @@ def Mesh_Dialect : Dialect {
  let cppNamespace = "::mlir::mesh";

  let description = [{
-    The `mesh` dialect contains a set of attributes, operations, interfaces that
-    are useful for representing sharding and communication on device mesh
-    cluster.
+    See [Mesh dialect documentation](mlir/docs/Dialects/Mesh.md).
  }];

  let dependentDialects = [
@@ -49,6 +47,10 @@ def Mesh_Partial : I32EnumAttr<"Partial", "partial type of a distributed tensor"
  let cppNamespace = "::mlir::mesh";
}

+def Mesh_PartialAttr : EnumAttr<Mesh_Dialect, Mesh_Partial, "partial"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
// Mesh_IteratorType and Mesh_Partial are used to annotate different aspects of
// distributed tensors. Mesh_IteratorType annotates loops in an operation, while
// Mesh_Partial indicates whether a tensor is sharded on a specific dimension or

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 3 additions & 0 deletions
@@ -10,9 +10,12 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include <algorithm>

#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.h.inc"

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td

Lines changed: 229 additions & 0 deletions
@@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/CommonAttrConstraints.td"
+include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/SymbolInterfaces.td"

//===----------------------------------------------------------------------===//
@@ -77,6 +79,18 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
    $sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
    attr-dict
  }];
+  let extraClassDeclaration = [{
+    // The `dim_sizes` attribute may have size less than the rank of the mesh.
+    // Returns the shape of the mesh with missing trailing dimensions
+    // explicitly set as dynamic.
+    ::mlir::SmallVector<int64_t> canonicalDimSizes();
+
+    template <typename OutIt>
+    void canonicalDimSizes(OutIt outIt) {
+      std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
+      std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
+    }
+  }];
  let hasVerifier = 1;
}
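The canonicalization above can be seen on a mesh whose trailing dimensions
are omitted; a sketch (`@mesh_dyn` is an illustrative name):

```mlir
// rank = 4 with only two dim_sizes entries: the mesh shape is 2x3x?x?,
// so canonicalDimSizes() returns [2, 3, 0, 0].
mesh.cluster @mesh_dyn(rank = 4, dim_sizes = [2, 3])
```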

@@ -171,4 +185,219 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
  }];
}

//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//

class Mesh_CollectiveCommunicationOpBase<
    string mnemonic, list<Trait> traits = []> :
    Mesh_Op<mnemonic,
      !listconcat(traits,
        [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
  dag commonArgs = (ins
    FlatSymbolRefAttr:$mesh,
    DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
  );
}

def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
    SameOperandsAndResultElementType,
    SameOperandsAndResultRank
  ]> {
  let summary = "All-gather over a device mesh.";
  let description = [{
    Gathers along the `gather_axis` tensor axis.

    Example:
    ```mlir
    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
    ...
    %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
      : tensor<2x2xi8> -> tensor<2x4xi8>
    ```
    Input:
    ```
                     +-------+-------+
    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                     |  3  4 |  7  8 |
                     +-------+-------+
    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                     | 11 12 | 15 16 |
                     +-------+-------+
    ```
    Result:
    ```
    gather tensor
    axis 1
    ------------>
    +-------------+
    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
    |  3  4  7  8 |
    +-------------+
    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
    | 11 12 15 16 |
    +-------------+
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$gather_axis
  ));
  let results = (outs
    AnyNon0RankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
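Gathering over multiple mesh axes works the same way: the device group is the
product of the listed axes. A sketch on the same `@mesh0` (illustrative
shapes), where the group spans all four devices, so the `gather_axis`
dimension grows by a factor of 4:

```mlir
%1 = mesh.all_gather %0 on @mesh0 mesh_axes = [0, 1] gather_axis = 0
  : tensor<2x2xi8> -> tensor<8x2xi8>
```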

def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
    SameOperandsAndResultShape]> {
  let summary = "All-reduce over a device mesh.";
  let description = [{
    The accumulation element type is specified by the result type and
    it does not need to match the input element type.
    The input element is converted to the result element type before
    performing the reduction.

    Attributes:
    `reduction`: Indicates the reduction method.

    Example:
    ```
    %1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
      : tensor<3x4xf32> -> tensor<3x4xf64>
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyRankedTensor:$input,
    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
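Since `reduction` defaults to sum and is optional in the assembly format, a
plain sum all-reduce can omit the attribute entirely. A minimal sketch
(illustrative types):

```mlir
%1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [0]
  : tensor<4xf32> -> tensor<4xf32>
```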

def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
    SameOperandsAndResultElementType,
    SameOperandsAndResultRank]> {
  let summary = "All-to-all over a device mesh.";
  let description = [{
    Performs an all-to-all on tensor pieces split along `split_axis`.
    The resulting pieces are concatenated along `concat_axis` on each device.

    Example:
    ```
    mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
    ...
    %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
        split_axis = 0 concat_axis = 0
      : tensor<3x2xi8> -> tensor<3x2xi8>
    ```
    Input:
    ```
     device  device  device
      (0)     (1)     (2)
    +-------+-------+-------+  | split and concat along
    | 11 12 | 21 22 | 31 32 |  | tensor axis 0
    | 13 14 | 23 24 | 33 34 |  ↓
    | 15 16 | 25 26 | 35 36 |
    +-------+-------+-------+
    ```
    Result:
    ```
     device  device  device
      (0)     (1)     (2)
    +-------+-------+-------+
    | 11 12 | 13 14 | 15 16 |
    | 21 22 | 23 24 | 25 26 |
    | 31 32 | 33 34 | 35 36 |
    +-------+-------+-------+
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$split_axis,
    IndexAttr:$concat_axis
  ));
  let results = (outs
    AnyNon0RankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `split_axis` `=` $split_axis
    `concat_axis` `=` $concat_axis
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
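When `split_axis` and `concat_axis` differ, the result shape changes
accordingly. A sketch on the same 3-device mesh (illustrative types): each
`2x3` tensor is split into three `2x1` columns, which are then restacked
along axis 0:

```mlir
%1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
    split_axis = 1 concat_axis = 0
  : tensor<2x3xi8> -> tensor<6x1xi8>
```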

def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
    SameOperandsAndResultRank]> {
  let summary = "Reduce-scatter over a device mesh.";
  let description = [{
    After the reduction, the result is scattered within each device group:
    the tensor is split along `scatter_axis` and the pieces are distributed
    across the device group.

    Example:
    ```
    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
    ...
    %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
        reduction = <sum> scatter_axis = 0
      : tensor<2x2xf32> -> tensor<1x2xf64>
    ```
    Input:
    ```
                              device
                              (0, 1)

                     +-------+-------+  | scatter tensor
    device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                     |  3  4 |  7  8 |  ↓
                     +-------+-------+
    device (1, 0) -> |  9 10 | 13 14 |
                     | 11 12 | 15 16 |
                     +-------+-------+

                              device
                              (1, 1)
    ```
    Result:
    ```
    +-------+
    |  6  8 | <- device (0, 0)
    +-------+
    | 10 12 | <- device (0, 1)
    +-------+
    | 22 24 | <- device (1, 0)
    +-------+
    | 26 28 | <- device (1, 1)
    +-------+
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
    IndexAttr:$scatter_axis
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    (`reduction` `=` $reduction^)?
    `scatter_axis` `=` $scatter_axis
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
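Scattering along a different tensor axis splits that axis instead. A sketch
on the same mesh (illustrative types), where axis 1 of the reduced `2x2`
tensor is split across the two devices of each group:

```mlir
%1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1] scatter_axis = 1
  : tensor<2x2xf32> -> tensor<2x1xf64>
```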

#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
