@@ -13,6 +13,8 @@ include "mlir/Dialect/Mesh/IR/MeshBase.td"
13
13
include "mlir/Interfaces/InferTypeOpInterface.td"
14
14
include "mlir/Interfaces/SideEffectInterfaces.td"
15
15
include "mlir/IR/BuiltinTypes.td"
16
+ include "mlir/IR/CommonAttrConstraints.td"
17
+ include "mlir/IR/CommonTypeConstraints.td"
16
18
include "mlir/IR/SymbolInterfaces.td"
17
19
18
20
//===----------------------------------------------------------------------===//
@@ -77,6 +79,18 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
77
79
$sym_name `(` `rank` `=` $rank (`,` `dim_sizes` `=` $dim_sizes^)? `)`
78
80
attr-dict
79
81
}];
82
+ let extraClassDeclaration = [{
83
+ // The `dim_sizes` attribute may have size less than the rank of the mesh.
84
+ // Returns the shape of the mesh with missing trailing dimensions
85
+ // explicitly set as dynamic.
86
+ ::mlir::SmallVector<int64_t> canonicalDimSizes();
87
+
88
+ template <typename OutIt>
89
+ void canonicalDimSizes(OutIt outIt) {
90
+ std::copy(getDimSizes().begin(), getDimSizes().end(), outIt);
91
+ std::fill_n(outIt, getRank() - getDimSizes().size(), 0);
92
+ }
93
+ }];
80
94
let hasVerifier = 1;
81
95
}
82
96
@@ -171,4 +185,219 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
171
185
}];
172
186
}
173
187
188
+ //===----------------------------------------------------------------------===//
189
+ // collective communication ops
190
+ //===----------------------------------------------------------------------===//
191
+
192
// Base class for the collective communication ops.
//
// Every collective verifies its `mesh` symbol reference, hence the
// SymbolUserOpInterface methods are declared for all derived ops.
// `commonArgs` is a `dag` field that derived ops prepend to their own
// argument list via `!con`:
//   * `mesh`      - flat symbol reference to the `mesh.cluster` this
//                   collective operates on.
//   * `mesh_axes` - the mesh axes that define the device groups taking
//                   part in the collective; defaults to the empty array.
class Mesh_CollectiveCommunicationOpBase<
    string mnemonic, list<Trait> traits = []> :
    Mesh_Op<mnemonic,
      !listconcat(traits,
      [DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
  dag commonArgs = (ins
    FlatSymbolRefAttr:$mesh,
    DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
  );
}
202
+
203
// All-gather: the devices in each group selected by `mesh_axes` contribute
// their local tensors, which are concatenated along the `gather_axis`
// tensor axis (per the diagram, the concatenated result is shared by the
// devices of the group).
def Mesh_AllGatherOp : Mesh_CollectiveCommunicationOpBase<"all_gather", [
    // Gathering only concatenates; element type and rank are preserved.
    SameOperandsAndResultElementType,
    SameOperandsAndResultRank
  ]> {
  let summary = "All-gather over a device mesh.";
  let description = [{
    Gathers along the `gather_axis` tensor axis.

    Example:
    ```mlir
    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
    ...
    %1 = mesh.all_gather %0 on @mesh0 mesh_axes = [1] gather_axis = 1
      : tensor<2x2xi8> -> tensor<2x4xi8>
    ```
    Input:
    ```
                     +-------+-------+
    device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                     |  3  4 |  7  8 |
                     +-------+-------+
    device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                     | 11 12 | 15 16 |
                     +-------+-------+
    ```
    Result:
    ```
    gather tensor
    axis 1
    ------------>
    +-------------+
    |  1  2  5  6 | <- devices (0, 0) and (0, 1)
    |  3  4  7  8 |
    +-------------+
    |  9 10 13 14 | <- devices (1, 0) and (1, 1)
    | 11 12 15 16 |
    +-------------+
    ```
  }];
  // A 0-ranked tensor has no axis to gather along, hence Non0.
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$gather_axis
  ));
  let results = (outs
    AnyNon0RankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? `gather_axis` `=` $gather_axis
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
255
+
256
// All-reduce: element-wise reduction across the device groups selected by
// `mesh_axes`, using the method given by the `reduction` attribute
// (defaults to sum).
def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
    // Only the shape must match: the result element type is the
    // accumulation type and may differ from the input element type.
    SameOperandsAndResultShape]> {
  let summary = "All-reduce over a device mesh.";
  let description = [{
    The accumulation element type is specified by the result type and
    it does not need to match the input element type.
    The input element is converted to the result element type before
    performing the reduction.

    Attributes:
    `reduction`: Indicates the reduction method.

    Example:
    ```
    %1 = mesh.all_reduce %0 on @mesh0 mesh_axes = [1, 0] reduction = <max>
      : tensor<3x4xf32> -> tensor<3x4xf64>
    ```
  }];
  let arguments = !con(commonArgs, (ins
    AnyRankedTensor:$input,
    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction
  ));
  let results = (outs
    AnyRankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)? (`reduction` `=` $reduction^)?
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
287
+
288
// All-to-all: each device splits its tensor along `split_axis` into as many
// pieces as there are devices in the group defined by `mesh_axes`, exchanges
// the pieces, and concatenates the received pieces along `concat_axis`.
def Mesh_AllToAllOp : Mesh_CollectiveCommunicationOpBase<"all_to_all", [
    // Pieces are only redistributed; element type and rank are preserved.
    SameOperandsAndResultElementType,
    SameOperandsAndResultRank]> {
  let summary = "All-to-all over a device mesh.";
  let description = [{
    Performs an all-to-all on tensor pieces split along `split_axis`.
    The resulting pieces are concatenated along `concat_axis` on each device.

    Example:
    ```
    mesh.cluster @mesh0(rank = 1, dim_sizes = [3])
    ...
    %1 = mesh.all_to_all %0 on @mesh0 mesh_axes = [0]
      split_axis = 0 concat_axis = 0
      : tensor<3x2xi8> -> tensor<3x2xi8>
    ```
    Input:
    ```
     device  device  device
      (0)     (1)     (2)
    +-------+-------+-------+  | split and concat along
    | 11 12 | 21 22 | 31 32 |  | tensor axis 0
    | 13 14 | 23 24 | 33 34 |  ↓
    | 15 16 | 25 26 | 35 36 |
    +-------+-------+-------+
    ```
    Result:
    ```
     device  device  device
      (0)     (1)     (2)
    +-------+-------+-------+
    | 11 12 | 13 14 | 15 16 |
    | 21 22 | 23 24 | 25 26 |
    | 31 32 | 33 34 | 35 36 |
    +-------+-------+-------+
    ```
  }];
  // A 0-ranked tensor has no axis to split/concatenate along, hence Non0.
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    IndexAttr:$split_axis,
    IndexAttr:$concat_axis
  ));
  let results = (outs
    AnyNon0RankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    `split_axis` `=` $split_axis
    `concat_axis` `=` $concat_axis
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
341
+
342
// Reduce-scatter: like all-reduce within each device group, but instead of
// replicating the reduced tensor, it is split along `scatter_axis` and the
// pieces are distributed across the group.
def Mesh_ReduceScatterOp : Mesh_CollectiveCommunicationOpBase<"reduce_scatter", [
    // Scattering only splits along an axis; rank is preserved. The result
    // element type is the accumulation type and may differ from the input.
    SameOperandsAndResultRank]> {
  let summary = "Reduce-scatter over a device mesh.";
  let description = [{
    After the reduction, the result is scattered within each device group.
    The tensor is split along `scatter_axis` and the pieces distributed
    across the device group.
    Example:
    ```
    mesh.cluster @mesh0(rank = 2, dim_sizes = [2, 2])
    ...
    %1 = mesh.reduce_scatter %0 on @mesh0 mesh_axes = [1]
      reduction = <max> scatter_axis = 0
      : tensor<2x4xf32> -> tensor<1x4xf64>
    ```
    Input:
    ```
                          device
                          (0, 1)
                             ↓
                     +-------+-------+  | scatter tensor
    device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                     |  3  4 |  7  8 |  ↓
                     +-------+-------+
    device (1, 0) -> |  9 10 | 13 14 |
                     | 11 12 | 15 16 |
                     +-------+-------+
                             ↑
                          device
                          (1, 1)
    ```
    Result:
    ```
    +-------+
    |  6  8 | <- devices (0, 0)
    +-------+
    | 10 12 | <- devices (0, 1)
    +-------+
    | 22 24 | <- devices (1, 0)
    +-------+
    | 26 28 | <- devices (1, 1)
    +-------+
    ```
  }];
  // A 0-ranked tensor has no axis to scatter along, hence Non0.
  let arguments = !con(commonArgs, (ins
    AnyNon0RankedTensor:$input,
    DefaultValuedAttr<Mesh_PartialAttr, "::mlir::mesh::Partial::Sum">:$reduction,
    IndexAttr:$scatter_axis
  ));
  // Non0 for consistency with the input: SameOperandsAndResultRank already
  // implies the result rank equals the (non-zero) input rank.
  let results = (outs
    AnyNon0RankedTensor:$result
  );
  let assemblyFormat = [{
    $input `on` $mesh (`mesh_axes` `=` $mesh_axes^)?
    (`reduction` `=` $reduction^)?
    `scatter_axis` `=` $scatter_axis
    attr-dict `:` type($input) `->` type($result)
  }];
  let hasCanonicalizer = 1;
}
402
+
174
403
#endif // MLIR_DIALECT_MESH_IR_MESHOPS_TD
0 commit comments