Commit 479ce0e

[TOSA] TOSA updates for LLVM hash b3b0070
1: [TOSA] Update rescale input_/output_zp and double_round attribute
   * Update tosa.rescale input_/output_zp as inputs according to TOSA 1.0
   * Update the double_round bool attribute to rounding_mode, in alignment with TOSA 1.0. rounding_mode supports "SINGLE_ROUND", "INEXACT_ROUND", and "DOUBLE_ROUND". Existing double_round behaviours are mapped as follows:
     - double_round = true  -> rounding_mode = "DOUBLE_ROUND"
     - double_round = false -> rounding_mode = "SINGLE_ROUND"

2: [TOSA] Update tosa.negate's zero-points to inputs

Update LIT tests and XFAIL sets
1 parent e4a2f86 commit 479ce0e
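As a quick orientation for downstream users of buildRescale, here is a minimal sketch (not part of this commit; the helper name toRoundingMode is hypothetical) of how the legacy double_round flag maps onto the new TOSA 1.0 rounding_mode strings:

```cpp
#include "llvm/ADT/StringRef.h"

// Hypothetical helper illustrating the mapping described in the commit
// message: the old double_round bool becomes a rounding_mode string.
inline llvm::StringRef toRoundingMode(bool double_round) {
  return double_round ? "DOUBLE_ROUND" : "SINGLE_ROUND";
}
```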

File tree

5 files changed: +66 −35 lines changed


include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h (+1 −1)

@@ -26,7 +26,7 @@ namespace tosa {
 // rounding mode
 Value buildRescale(PatternRewriter &rewriter, Operation *op,
                    ShapedType output_type, Value input_val, double scale,
-                   int64_t input_zp, int64_t output_zp, bool double_round,
+                   int64_t input_zp, int64_t output_zp, StringRef rounding_mode,
                    bool scale32);

 // Creates TOSA rescale op with int32 output
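For context, a sketch of a call site using the updated signature (variable names are illustrative; the actual call sites appear in the hunks below). The slot that used to take the double_round bool now takes a rounding_mode string:

```cpp
// Illustrative call; the last two arguments are rounding_mode and scale32.
Value rescaled = buildRescale(rewriter, op, output_type, input_val,
                              /*scale=*/output_scale,
                              /*input_zp=*/0, /*output_zp=*/output_zp,
                              /*rounding_mode=*/"SINGLE_ROUND",
                              /*scale32=*/true);
```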

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp (+1 −1)

@@ -777,7 +777,7 @@ std::optional<Value> convertReduceOpCommon(
       RankedTensorType output_rescale_type =
           RankedTensorType::get(shape_vec, output_type.getElementType());
       val = buildRescale(rewriter, op, output_rescale_type, val, output_scale,
-                         0, output_zp, false, true);
+                         0, output_zp, "SINGLE_ROUND", true);
     }

     // Optionally squeeze out the reduced axes.

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp (+44 −20)

@@ -34,14 +34,15 @@ Value buildRescaleMultiplier(bool scale32, PatternRewriter &rewriter,
 // rounding mode
 Value buildRescale(PatternRewriter &rewriter, Operation *op,
                    ShapedType output_type, Value input_val, double scale,
-                   int64_t input_zp, int64_t output_zp, bool double_round,
+                   int64_t input_zp, int64_t output_zp, StringRef rounding_mode,
                    bool scale32) {
   int32_t multiplier;
   int32_t shift;

   int32_t scale_width = scale32 ? 32 : 16;

-  computeMultiplierAndShift(scale, multiplier, shift, scale_width);
+  if (!computeMultiplierAndShift(scale, multiplier, shift, scale_width))
+    op->emitError("buildRescale: shift must be in the range 2 <= shift <= 62");

   Value multiplier_val =
       buildRescaleMultiplier(scale32, rewriter, op, {multiplier});

@@ -52,11 +53,23 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op,
   bool input_unsigned = input_val.getType().isUnsignedInteger();
   bool output_unsigned = output_type.isUnsignedInteger();

+  // Create input_zp matches the input type and output_zp matches the output
+  // type of RescaleOp
+  const auto input_zp_val = tosa::createZeroPointTensor(
+      rewriter, op->getLoc(), dyn_cast<TensorType>(input_val.getType()),
+      input_zp);
+  if (!input_zp_val.has_value())
+    op->emitError("Failed to create input zero-point tensor for RescaleOp.");
+
+  const auto output_zp_val = tosa::createZeroPointTensor(
+      rewriter, op->getLoc(), output_type, output_zp);
+  if (!output_zp_val.has_value())
+    op->emitError("Failed to create output zero-point tensor for RescaleOp.");
+
   auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
       rewriter, op->getLoc(), output_type, input_val, multiplier_val, shift_val,
-      rewriter.getI32IntegerAttr(static_cast<int32_t>(input_zp)),
-      rewriter.getI32IntegerAttr(static_cast<int32_t>(output_zp)),
-      rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round),
+      input_zp_val.value(), output_zp_val.value(),
+      rewriter.getBoolAttr(scale32), rewriter.getStringAttr(rounding_mode),
       rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
       rewriter.getBoolAttr(output_unsigned));

@@ -73,7 +86,7 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
   auto output_type = input_type.clone(rewriter.getI32Type());

   return buildRescale(rewriter, op, output_type, input_val, input_scale,
-                      input_zp, 0, false, true);
+                      input_zp, 0, "SINGLE_ROUND", true);
 }

 // Creates a TOSA rescale op based on conv2d parameters.

@@ -96,6 +109,16 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
   bool input_unsigned = input_qtype.isUnsignedInteger();
   bool output_unsigned = output_qtype.isUnsignedInteger();

+  const auto input_zp_val = tosa::createZeroPointTensor(
+      rewriter, op->getLoc(), input_type, static_cast<int64_t>(0));
+  if (!input_zp_val.has_value())
+    op->emitError("Failed to create input zero-point tensor for RescaleOp.");
+
+  const auto output_zp_val = tosa::createZeroPointTensor(
+      rewriter, op->getLoc(), output_type, output_zp);
+  if (!output_zp_val.has_value())
+    op->emitError("Failed to create output zero-point tensor for RescaleOp.");
+
   if (auto weight_per_tensor_qtype =
           dyn_cast<mlir::quant::UniformQuantizedType>(
               weight_type.getElementType())) {

@@ -107,7 +130,11 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,

     double op_tensor_scale = (input_scale * weight_scale) / output_scale;

-    computeMultiplierAndShift(op_tensor_scale, multiplier, shift, scale_width);
+    if (!computeMultiplierAndShift(op_tensor_scale, multiplier, shift,
+                                   scale_width))
+      op->emitError(
+          "buildRescaleOpConvOutput: shift must be in the range 2 <= shift <= "
+          "62");

     Value multiplier_val =
         buildRescaleMultiplier(scale32, rewriter, op, {multiplier});

@@ -117,10 +144,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,

     auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
         rewriter, op->getLoc(), output_type, conv_val, multiplier_val,
-        shift_val, rewriter.getI32IntegerAttr(0),
-        rewriter.getI32IntegerAttr(output_zp), rewriter.getBoolAttr(scale32),
-        rewriter.getBoolAttr(true), rewriter.getBoolAttr(false),
-        rewriter.getBoolAttr(input_unsigned),
+        shift_val, input_zp_val.value(), output_zp_val.value(),
+        rewriter.getBoolAttr(scale32), rewriter.getStringAttr("DOUBLE_ROUND"),
+        rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
         rewriter.getBoolAttr(output_unsigned));

     return rescale_op.getResult();

@@ -136,17 +162,16 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
         weight_per_channel_qtype.getScales().begin(),
         weight_per_channel_qtype.getScales().end());

-    int64_t output_zp = output_qtype.getZeroPoint();
-    double output_scale = output_qtype.getScale();
-
     for (double weight_scale : weight_scale_arr) {
       int32_t multiplier;
       int32_t shift;

       double op_channel_scale = (input_scale * weight_scale) / output_scale;

-      computeMultiplierAndShift(op_channel_scale, multiplier, shift,
-                                scale_width);
+      if (!computeMultiplierAndShift(op_channel_scale, multiplier, shift, 32))
+        op->emitError(
+            "buildRescaleOpConvOutput: shift must be in the range 2 <= shift "
+            "<= 62");

       multiplier_arr.push_back(multiplier);
       shift_arr.push_back(static_cast<int8_t>(shift));

@@ -161,10 +186,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,

     auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
         rewriter, op->getLoc(), output_type, conv_val, multiplier_val,
-        shift_val, rewriter.getI32IntegerAttr(0),
-        rewriter.getI32IntegerAttr(output_zp), rewriter.getBoolAttr(scale32),
-        rewriter.getBoolAttr(true), rewriter.getBoolAttr(true),
-        rewriter.getBoolAttr(input_unsigned),
+        shift_val, input_zp_val.value(), output_zp_val.value(),
+        rewriter.getBoolAttr(scale32), rewriter.getStringAttr("DOUBLE_ROUND"),
+        rewriter.getBoolAttr(true), rewriter.getBoolAttr(input_unsigned),
         rewriter.getBoolAttr(output_unsigned));

     return rescale_op.getResult();
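The zero-point handling added throughout this file follows one pattern; here is a condensed sketch (variable names are illustrative; the createZeroPointTensor signature is taken from the hunks above):

```cpp
// Materialize the zero points as small constant tensors and pass them to
// tosa::RescaleOp as operands instead of I32 attributes.
const auto input_zp_val = tosa::createZeroPointTensor(
    rewriter, op->getLoc(), dyn_cast<TensorType>(input_val.getType()), input_zp);
const auto output_zp_val = tosa::createZeroPointTensor(
    rewriter, op->getLoc(), output_type, output_zp);
if (!input_zp_val.has_value() || !output_zp_val.has_value())
  op->emitError("Failed to create zero-point tensor for RescaleOp.");
// input_zp_val.value() / output_zp_val.value() then replace the former
// rewriter.getI32IntegerAttr(...) arguments passed to the RescaleOp builder.
```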

projects/pt1/e2e_testing/xfail_sets.py (+2 −0)

@@ -1717,6 +1717,8 @@
     "ScatterSrcModule_basic",
     "ScatterSrcStaticModule_basic",
     "HBC_basic",
+    # 1D inputs cause generated tosa.negate ops to crash downstream
+    "NllLossModule_1D_basic",
 }

 # Write the TOSA set as a "passing" set as it is very early in development

test/Conversion/TorchToTosa/basic.mlir (+18 −13)

@@ -94,11 +94,14 @@ func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten
 // -----

 // CHECK-LABEL: func.func @torch.aten.neg$basic(
-// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.negate %[[ARG_BUILTIN]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_4:.*]] = tosa.negate %[[VAL_1]], %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+// CHECK: }
 func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>

@@ -1555,20 +1558,22 @@ func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.v
 // -----

 // CHECK-LABEL: func.func @torch.aten.min.dim$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
 // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32>
 // CHECK: %[[VAL_3:.*]] = torch.constant.bool true
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 2
-// CHECK: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[3, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[3, 2]> : tensor<2xindex>} : () -> !tosa.shape<2>
 // CHECK: %[[VAL_6:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32>
 // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32>
-// CHECK: %[[VAL_8:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32>
-// CHECK: %[[VAL_9:.*]] = tosa.argmax %[[VAL_8]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
-// CHECK: %[[VAL_10:.*]] = tosa.const_shape {values = dense<[3, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
-// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_10]] : (tensor<3x2xi64>, !tosa.shape<3>) -> tensor<3x2x1xi64>
-// CHECK: %[[VAL_12:.*]] = torch_c.to_builtin_tensor %[[VAL_7]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
-// CHECK: return %[[VAL_12]] : tensor<3x2x1xf32>
+// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_10:.*]] = tosa.negate %[[VAL_2]], %[[VAL_8]], %[[VAL_9]] : (tensor<3x2x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3x2x3xf32>
+// CHECK: %[[VAL_11:.*]] = tosa.argmax %[[VAL_10]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64>
+// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<[3, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]], %[[VAL_12]] : (tensor<3x2xi64>, !tosa.shape<3>) -> tensor<3x2x1xi64>
+// CHECK: %[[VAL_14:.*]] = torch_c.to_builtin_tensor %[[VAL_7]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32>
+// CHECK: return %[[VAL_14]] : tensor<3x2x1xf32>
 // CHECK: }
 func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> {
   %0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32>
