Skip to content

Commit d03f30f

Browse files
authored
[mlir][TOSA] restore unrealized casts when lowering rescale ops (llvm#141096)
Along with the changes to the rescale op attributes, commit 7208649 dropped the builtin casts between unsigned and signless types. However, explicitly unsigned types are still legal input and output values from the TOSA IR perspective. This change adds back the casts when the unsigned<->signless semantics are explicit in the underlying tensor types. This prevents the conversion routine from trying to generate illegal `arith` casts, which are constrained to signless types. Whether the `arith` casts themselves are signed or unsigned should still depend on the rescale's `*_unsigned` attribute values. --------- Signed-off-by: Artem Gindinson <gindinson@roofline.ai>
1 parent dff6aee commit d03f30f

File tree

2 files changed

+148
-30
lines changed

2 files changed

+148
-30
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,6 +1492,15 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14921492
: blockArgs[multiplierArg];
14931493
Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
14941494

1495+
if (valueTy.isUnsignedInteger()) {
1496+
value = nestedBuilder
1497+
.create<UnrealizedConversionCastOp>(
1498+
nestedLoc,
1499+
nestedBuilder.getIntegerType(
1500+
valueTy.getIntOrFloatBitWidth()),
1501+
value)
1502+
.getResult(0);
1503+
}
14951504
if (valueTy.getIntOrFloatBitWidth() < 32) {
14961505
if (op.getInputUnsigned()) {
14971506
value = nestedBuilder.create<arith::ExtUIOp>(
@@ -1537,6 +1546,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
15371546
value);
15381547
}
15391548

1549+
if (outIntType.isUnsignedInteger()) {
1550+
value = nestedBuilder
1551+
.create<UnrealizedConversionCastOp>(nestedLoc,
1552+
outIntType, value)
1553+
.getResult(0);
1554+
}
15401555
nestedBuilder.create<linalg::YieldOp>(loc, value);
15411556
});
15421557

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 133 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,25 +1152,60 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
11521152
// -----
11531153
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
11541154

1155-
// CHECK-LABEL: @rescale_i8_unsigned_output
1155+
// CHECK-LABEL: @rescale_i8_unsigned_output_explicit
11561156
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1157-
func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
1157+
func.func @rescale_i8_unsigned_output_explicit(%arg0 : tensor<2xi8>) -> () {
1158+
// CHECK: [[C0:%.+]] = arith.constant 19689
1159+
// CHECK: [[C1:%.+]] = arith.constant 15
1160+
// CHECK: [[INIT:%.+]] = tensor.empty()
1161+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xui8>)
1162+
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: ui8):
1163+
// CHECK-DAG: [[C17:%.+]] = arith.constant 17
1164+
// CHECK-DAG: [[C234:%.+]] = arith.constant 234
1165+
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
1166+
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
1167+
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
1168+
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
1169+
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
1170+
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
1171+
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
1172+
// CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1173+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1174+
// CHECK: [[TRUNC_ITOU:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
1175+
// CHECK: linalg.yield [[TRUNC_ITOU]]
1176+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
1177+
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
1178+
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
1179+
%output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
1180+
%1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8>
1181+
1182+
// CHECK: return
1183+
return
1184+
}
1185+
1186+
// -----
1187+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
1188+
1189+
// CHECK-LABEL: @rescale_i8_unsigned_output_implicit
1190+
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1191+
func.func @rescale_i8_unsigned_output_implicit(%arg0 : tensor<2xi8>) -> () {
11581192
// CHECK: [[C0:%.+]] = arith.constant 19689
11591193
// CHECK: [[C1:%.+]] = arith.constant 15
11601194
// CHECK: [[INIT:%.+]] = tensor.empty()
11611195
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
11621196
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
1163-
// CHECK: [[C17:%.+]] = arith.constant 17
1164-
// CHECK: [[C234:%.+]] = arith.constant 234
1197+
// CHECK-DAG: [[C17:%.+]] = arith.constant 17
1198+
// CHECK-DAG: [[C234:%.+]] = arith.constant 234
11651199
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
11661200
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
11671201
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
11681202
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
11691203
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
11701204
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
11711205
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
1172-
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1173-
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1206+
// CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1207+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1208+
// CHECK-NOT: builtin.unrealized_conversion_cast [[TRUNC]]
11741209
// CHECK: linalg.yield [[TRUNC]]
11751210
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
11761211
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
@@ -1182,6 +1217,39 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
11821217
return
11831218
}
11841219

1220+
// -----
1221+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
1222+
1223+
// CHECK-LABEL: @rescale_i48_unsigned_output_implicit
1224+
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1225+
func.func @rescale_i48_unsigned_output_implicit(%arg0 : tensor<2xi48>) -> () {
1226+
// CHECK: [[C19689:%.+]] = arith.constant 19689
1227+
// CHECK: [[C15:%.+]] = arith.constant 15
1228+
// CHECK: [[INIT:%.+]] = tensor.empty()
1229+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
1230+
// CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
1231+
// CHECK-NOT: builtin.unrealized_conversion_cast [[IN]]
1232+
// CHECK-DAG: [[C0:%.+]] = arith.constant 0
1233+
// CHECK-DAG: [[C234:%.+]] = arith.constant 234
1234+
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
1235+
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
1236+
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
1237+
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
1238+
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
1239+
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
1240+
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1241+
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1242+
// CHECK: linalg.yield [[TRUNC]]
1243+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
1244+
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
1245+
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
1246+
%output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
1247+
%1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
1248+
1249+
// CHECK: return
1250+
return
1251+
}
1252+
11851253
// -----
11861254

11871255
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
@@ -1230,19 +1298,52 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
12301298
}
12311299

12321300
// -----
1301+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
12331302

1303+
// CHECK-LABEL: @rescale_i8_unsigned_input_explicit
1304+
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1305+
func.func @rescale_i8_unsigned_input_explicit(%arg0 : tensor<2xui8>) -> () {
1306+
// CHECK: [[C0:%.+]] = arith.constant 19689
1307+
// CHECK: [[C1:%.+]] = arith.constant 15
1308+
// CHECK: [[INIT:%.+]] = tensor.empty()
1309+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8>) outs([[INIT]] : tensor<2xi8>)
1310+
// CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: i8):
1311+
// CHECK-DAG: [[C17:%.+]] = arith.constant 17
1312+
// CHECK-DAG: [[C22:%.+]] = arith.constant 22
1313+
// CHECK-DAG: [[IN_UTOI:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
1314+
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN_UTOI]]
1315+
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
1316+
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
1317+
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
1318+
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
1319+
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
1320+
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
1321+
// CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1322+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1323+
// CHECK: linalg.yield [[TRUNC]]
1324+
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
1325+
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
1326+
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
1327+
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
1328+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
1329+
1330+
return
1331+
}
1332+
1333+
// -----
12341334
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
12351335

1236-
// CHECK-LABEL: @rescale_i8_unsigned_input
1336+
// CHECK-LABEL: @rescale_i8_unsigned_input_implicit
12371337
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1238-
func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
1338+
func.func @rescale_i8_unsigned_input_implicit(%arg0 : tensor<2xi8>) -> () {
12391339
// CHECK: [[C0:%.+]] = arith.constant 19689
12401340
// CHECK: [[C1:%.+]] = arith.constant 15
12411341
// CHECK: [[INIT:%.+]] = tensor.empty()
12421342
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
12431343
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
1244-
// CHECK: [[C128:%.+]] = arith.constant 128
1245-
// CHECK: [[C22:%.+]] = arith.constant 22
1344+
// CHECK-NOT: builtin.unrealized_conversion_cast [[IN]]
1345+
// CHECK-DAG: [[C128:%.+]] = arith.constant 128
1346+
// CHECK-DAG: [[C22:%.+]] = arith.constant 22
12461347
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
12471348
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
12481349
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
@@ -1265,32 +1366,34 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
12651366
// -----
12661367
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
12671368

1268-
// CHECK-LABEL: @rescale_i48_unsigned_output
1369+
// CHECK-LABEL: @rescale_i8_unsigned_input_output_explicit
12691370
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
1270-
func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
1271-
// CHECK: [[C19689:%.+]] = arith.constant 19689
1272-
// CHECK: [[C15:%.+]] = arith.constant 15
1371+
func.func @rescale_i8_unsigned_input_output_explicit(%arg0 : tensor<2xui8>) -> () {
1372+
// CHECK: [[C0:%.+]] = arith.constant 19689
1373+
// CHECK: [[C1:%.+]] = arith.constant 15
12731374
// CHECK: [[INIT:%.+]] = tensor.empty()
1274-
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
1275-
// CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
1276-
// CHECK: [[C0:%.+]] = arith.constant 0
1277-
// CHECK: [[C234:%.+]] = arith.constant 234
1278-
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
1279-
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
1280-
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
1281-
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
1282-
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
1375+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8>) outs([[INIT]] : tensor<2xui8>)
1376+
// CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: ui8):
1377+
// CHECK-DAG: [[C17:%.+]] = arith.constant 17
1378+
// CHECK-DAG: [[C22:%.+]] = arith.constant 22
1379+
// CHECK-DAG: [[IN_UTOI:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
1380+
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN_UTOI]]
1381+
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
1382+
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
1383+
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
1384+
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
1385+
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
12831386
// CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
1284-
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1285-
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1286-
// CHECK: linalg.yield [[TRUNC]]
1387+
// CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
1388+
// CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
1389+
// CHECK: [[TRUNC_ITOU:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
1390+
// CHECK: linalg.yield [[TRUNC_ITOU]]
12871391
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
12881392
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
1289-
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
1290-
%output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
1291-
%1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
1393+
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
1394+
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
1395+
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8>
12921396

1293-
// CHECK: return
12941397
return
12951398
}
12961399

0 commit comments

Comments
 (0)