
Commit 3b5187f

Tobias Gysi authored and memfrob committed
[mlir][linalg] Update OpDSL to use the newly introduced min and max ops.
Implement min and max using the newly introduced std operations instead of relying on compare and select.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D111170
1 parent 06e2a47 commit 3b5187f
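For context, the change only touches the scalar region that OpDSL/Linalg builds for min and max reductions. A sketch of the f32 max case, taken from the updated test expectations below (the SSA value names are illustrative): the generated region body goes from

  %cond = cmpf ogt, %out, %in : f32
  %max = select %cond, %out, %in : f32
  linalg.yield %max : f32

to a single op:

  %max = maxf %out, %in : f32
  linalg.yield %max : f32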

5 files changed: +23 additions, -52 deletions


mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 6 additions & 15 deletions
@@ -276,18 +276,20 @@ class RegionBuilderHelper {
  }

  Value applyfn__max(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
-      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OGT);
+      return builder.create<MaxFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
-      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::sgt);
+      return builder.create<MaxSIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  Value applyfn__min(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
-      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OLT);
+      return builder.create<MinFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
-      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::slt);
+      return builder.create<MinSIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

@@ -324,17 +326,6 @@ class RegionBuilderHelper {
  MLIRContext *context;
  Block &block;

-  Value emitCmpFAndSelect(Value lhs, Value rhs, CmpFPredicate predicate) {
-    OpBuilder builder = getBuilder();
-    Value condition = builder.create<CmpFOp>(lhs.getLoc(), predicate, lhs, rhs);
-    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-  }
-  Value emitCmpIAndSelect(Value lhs, Value rhs, CmpIPredicate predicate) {
-    OpBuilder builder = getBuilder();
-    Value condition = builder.create<CmpIOp>(lhs.getLoc(), predicate, lhs, rhs);
-    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-  }
-
  bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
  bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }

mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py

Lines changed: 4 additions & 18 deletions
@@ -319,20 +319,16 @@ def _eval_mul(self, lhs: Value, rhs: Value) -> Value:

  def _eval_max(self, lhs: Value, rhs: Value) -> Value:
    if _is_floating_point_type(lhs.type):
-      ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
-      return _emit_cmpf_and_select(lhs, rhs, ogt_attr)
+      return std.MaxFOp(lhs.type, lhs, rhs).result
    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
-      return _emit_cmpi_and_select(lhs, rhs, sgt_attr)
+      return std.MaxSIOp(lhs.type, lhs, rhs).result
    raise NotImplementedError("Unsupported 'max' operand: {lhs}")

  def _eval_min(self, lhs: Value, rhs: Value) -> Value:
    if _is_floating_point_type(lhs.type):
-      olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
-      return _emit_cmpf_and_select(lhs, rhs, olt_attr)
+      return std.MinFOp(lhs.type, lhs, rhs).result
    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
-      slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
-      return _emit_cmpi_and_select(lhs, rhs, slt_attr)
+      return std.MinSIOp(lhs.type, lhs, rhs).result
    raise NotImplementedError("Unsupported 'min' operand: {lhs}")


@@ -413,13 +409,3 @@ def _get_floating_point_width(t: Type) -> int:
  if BF16Type.isinstance(t):
    return 16
  raise NotImplementedError(f"Unhandled floating point type switch {t}")
-
-
-def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
-  cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result
-  return std.SelectOp(lhs.type, cond, lhs, rhs).result
-
-
-def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
-  cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result
-  return std.SelectOp(lhs.type, cond, lhs, rhs).result
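On the OpDSL side, the emitter previously hard-coded the CmpF/CmpI predicates as signless 64-bit integer attributes (2 for ogt and slt, 4 for sgt and olt) and assembled the compare-and-select pair by hand; it now emits the dedicated std ops directly. For the mixed-type case exercised by the pooling tests below, the generated region for an i32 max reduction over an f32 input reduces to a cast plus one maxsi; a sketch with illustrative SSA names, mirroring the updated CHECK lines:

  %in_cast = fptosi %in : f32 to i32
  %max = maxsi %out, %in_cast : i32
  linalg.yield %max : i32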

mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir

Lines changed: 6 additions & 10 deletions
@@ -38,8 +38,7 @@ func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: ten

// CHECK-LABEL: @generalize_pooling_nhwc_max_f32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
-// CHECK-NEXT: %[[COND:.+]] = cmpf ogt, %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: %[[MAX:.+]] = maxf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT: linalg.yield %[[MAX]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

@@ -53,8 +52,7 @@ func @generalize_pooling_nhwc_max_i32(%input : tensor<1x4x16x1xi32>, %shape: ten

// CHECK-LABEL: @generalize_pooling_nhwc_max_i32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
-// CHECK-NEXT: %[[COND:.+]] = cmpi sgt, %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT: %[[MAX:.+]] = maxsi %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT: linalg.yield %[[MAX]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>

@@ -68,9 +66,8 @@ func @generalize_pooling_nhwc_min_f32(%input : tensor<1x4x16x1xf32>, %shape: ten

// CHECK-LABEL: @generalize_pooling_nhwc_min_f32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
-// CHECK-NEXT: %[[COND:.+]] = cmpf olt, %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
-// CHECK-NEXT: linalg.yield %[[MAX]] : f32
+// CHECK-NEXT: %[[MIN:.+]] = minf %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: linalg.yield %[[MIN]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

@@ -83,9 +80,8 @@ func @generalize_pooling_nhwc_min_i32(%input : tensor<1x4x16x1xi32>, %shape: ten

// CHECK-LABEL: @generalize_pooling_nhwc_min_i32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
-// CHECK-NEXT: %[[COND:.+]] = cmpi slt, %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
-// CHECK-NEXT: linalg.yield %[[MAX]] : i32
+// CHECK-NEXT: %[[MIN:.+]] = minsi %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT: linalg.yield %[[MIN]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>

// -----

mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py

Lines changed: 4 additions & 6 deletions
@@ -242,8 +242,7 @@ def test_f32i32_conv(input, filter, init_result):
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
# CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
-# CHECK-NEXT: %[[COND:.+]] = cmpi sgt, %[[OUT]], %[[IN_CAST:.+]] : i32
-# CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN_CAST:.+]] : i32
+# CHECK-NEXT: %[[MAX:.+]] = maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
# CHECK-NEXT: linalg.yield %[[MAX]] : i32
# CHECK-NEXT: -> tensor<2x4xi32>
@builtin.FuncOp.from_py_func(
@@ -258,8 +257,7 @@ def test_f32i32_max_pooling(input, shape, init_result):
# CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
-# CHECK-NEXT: %[[COND:.+]] = cmpf ogt, %[[OUT]], %[[IN:.+]] : f32
-# CHECK-NEXT: %[[MAX:.+]] = select %[[COND]], %[[OUT]], %[[IN:.+]] : f32
+# CHECK-NEXT: %[[MAX:.+]] = maxf %[[OUT]], %[[IN:.+]] : f32
# CHECK-NEXT: linalg.yield %[[MAX]] : f32
# CHECK-NEXT: -> tensor<2x4xf32>
@builtin.FuncOp.from_py_func(
@@ -270,7 +268,7 @@ def test_f32f32_max_pooling(input, shape, init_result):
    input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])

# CHECK-LABEL: @test_f32i32_min_pooling
-# CHECK: = cmpi slt,
+# CHECK: = minsi
@builtin.FuncOp.from_py_func(
    RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
    RankedTensorType.get((2, 4), i32))
@@ -279,7 +277,7 @@ def test_f32i32_min_pooling(input, shape, init_result):
    input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])

# CHECK-LABEL: @test_f32f32_min_pooling
-# CHECK: = cmpf olt,
+# CHECK: = minf
@builtin.FuncOp.from_py_func(
    RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
    RankedTensorType.get((2, 4), f32))

mlir/test/python/integration/dialects/linalg/opsrun.py

Lines changed: 3 additions & 3 deletions
@@ -118,7 +118,7 @@ def log(*args):

def transform(module, boilerplate):
  import mlir.conversions
-  import mlir.dialects.linalg.passes
+  import mlir.all_passes_registration
  import mlir.transforms

  # TODO: Allow cloning functions from one module to another.
@@ -128,8 +128,8 @@ def transform(module, boilerplate):
      boilerplate)
  pm = PassManager.parse(
      "builtin.func(convert-linalg-to-loops, lower-affine, " +
-      "convert-scf-to-std), convert-vector-to-llvm," +
-      "convert-memref-to-llvm,convert-std-to-llvm," +
+      "convert-scf-to-std, std-expand), convert-vector-to-llvm," +
+      "convert-memref-to-llvm, convert-std-to-llvm," +
      "reconcile-unrealized-casts")
  pm.run(mod)
  return mod
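The integration test pipeline changes accordingly: the per-dialect pass registration import is replaced by mlir.all_passes_registration, and std-expand is added after the conversion to the std dialect, presumably because the new min/max ops are expanded back into compare-and-select form before convert-std-to-llvm. A sketch of the assumed expansion for the signed-integer case (value names are illustrative):

  // Before std-expand (sketch):
  %0 = maxsi %a, %b : i32
  // After std-expand (sketch):
  %c = cmpi sgt, %a, %b : i32
  %0 = select %c, %a, %b : i32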
