@@ -34,14 +34,15 @@ Value buildRescaleMultiplier(bool scale32, PatternRewriter &rewriter,
34
34
// rounding mode
35
35
Value buildRescale (PatternRewriter &rewriter, Operation *op,
36
36
ShapedType output_type, Value input_val, double scale,
37
- int64_t input_zp, int64_t output_zp, bool double_round ,
37
+ int64_t input_zp, int64_t output_zp, StringRef rounding_mode ,
38
38
bool scale32) {
39
39
int32_t multiplier;
40
40
int32_t shift;
41
41
42
42
int32_t scale_width = scale32 ? 32 : 16 ;
43
43
44
- computeMultiplierAndShift (scale, multiplier, shift, scale_width);
44
+ if (!computeMultiplierAndShift (scale, multiplier, shift, scale_width))
45
+ op->emitError (" buildRescale: shift must be in the range 2 <= shift <= 62" );
45
46
46
47
Value multiplier_val =
47
48
buildRescaleMultiplier (scale32, rewriter, op, {multiplier});
@@ -52,11 +53,23 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op,
52
53
bool input_unsigned = input_val.getType ().isUnsignedInteger ();
53
54
bool output_unsigned = output_type.isUnsignedInteger ();
54
55
56
+ // Create input_zp matches the input type and output_zp matches the output
57
+ // type of RescaleOp
58
+ const auto input_zp_val = tosa::createZeroPointTensor (
59
+ rewriter, op->getLoc (), dyn_cast<TensorType>(input_val.getType ()),
60
+ input_zp);
61
+ if (!input_zp_val.has_value ())
62
+ op->emitError (" Failed to create input zero-point tensor for RescaleOp." );
63
+
64
+ const auto output_zp_val = tosa::createZeroPointTensor (
65
+ rewriter, op->getLoc (), output_type, output_zp);
66
+ if (!output_zp_val.has_value ())
67
+ op->emitError (" Failed to create output zero-point tensor for RescaleOp." );
68
+
55
69
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
56
70
rewriter, op->getLoc (), output_type, input_val, multiplier_val, shift_val,
57
- rewriter.getI32IntegerAttr (static_cast <int32_t >(input_zp)),
58
- rewriter.getI32IntegerAttr (static_cast <int32_t >(output_zp)),
59
- rewriter.getBoolAttr (scale32), rewriter.getBoolAttr (double_round),
71
+ input_zp_val.value (), output_zp_val.value (),
72
+ rewriter.getBoolAttr (scale32), rewriter.getStringAttr (rounding_mode),
60
73
rewriter.getBoolAttr (false ), rewriter.getBoolAttr (input_unsigned),
61
74
rewriter.getBoolAttr (output_unsigned));
62
75
@@ -73,7 +86,7 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
73
86
auto output_type = input_type.clone (rewriter.getI32Type ());
74
87
75
88
return buildRescale (rewriter, op, output_type, input_val, input_scale,
76
- input_zp, 0 , false , true );
89
+ input_zp, 0 , " SINGLE_ROUND " , true );
77
90
}
78
91
79
92
// Creates a TOSA rescale op based on conv2d parameters.
@@ -96,6 +109,16 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
96
109
bool input_unsigned = input_qtype.isUnsignedInteger ();
97
110
bool output_unsigned = output_qtype.isUnsignedInteger ();
98
111
112
+ const auto input_zp_val = tosa::createZeroPointTensor (
113
+ rewriter, op->getLoc (), input_type, static_cast <int64_t >(0 ));
114
+ if (!input_zp_val.has_value ())
115
+ op->emitError (" Failed to create input zero-point tensor for RescaleOp." );
116
+
117
+ const auto output_zp_val = tosa::createZeroPointTensor (
118
+ rewriter, op->getLoc (), output_type, output_zp);
119
+ if (!output_zp_val.has_value ())
120
+ op->emitError (" Failed to create output zero-point tensor for RescaleOp." );
121
+
99
122
if (auto weight_per_tensor_qtype =
100
123
dyn_cast<mlir::quant::UniformQuantizedType>(
101
124
weight_type.getElementType ())) {
@@ -107,7 +130,11 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
107
130
108
131
double op_tensor_scale = (input_scale * weight_scale) / output_scale;
109
132
110
- computeMultiplierAndShift (op_tensor_scale, multiplier, shift, scale_width);
133
+ if (!computeMultiplierAndShift (op_tensor_scale, multiplier, shift,
134
+ scale_width))
135
+ op->emitError (
136
+ " buildRescaleOpConvOutput: shift must be in the range 2 <= shift <= "
137
+ " 62" );
111
138
112
139
Value multiplier_val =
113
140
buildRescaleMultiplier (scale32, rewriter, op, {multiplier});
@@ -117,10 +144,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
117
144
118
145
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
119
146
rewriter, op->getLoc (), output_type, conv_val, multiplier_val,
120
- shift_val, rewriter.getI32IntegerAttr (0 ),
121
- rewriter.getI32IntegerAttr (output_zp), rewriter.getBoolAttr (scale32),
122
- rewriter.getBoolAttr (true ), rewriter.getBoolAttr (false ),
123
- rewriter.getBoolAttr (input_unsigned),
147
+ shift_val, input_zp_val.value (), output_zp_val.value (),
148
+ rewriter.getBoolAttr (scale32), rewriter.getStringAttr (" DOUBLE_ROUND" ),
149
+ rewriter.getBoolAttr (false ), rewriter.getBoolAttr (input_unsigned),
124
150
rewriter.getBoolAttr (output_unsigned));
125
151
126
152
return rescale_op.getResult ();
@@ -136,17 +162,16 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
136
162
weight_per_channel_qtype.getScales ().begin (),
137
163
weight_per_channel_qtype.getScales ().end ());
138
164
139
- int64_t output_zp = output_qtype.getZeroPoint ();
140
- double output_scale = output_qtype.getScale ();
141
-
142
165
for (double weight_scale : weight_scale_arr) {
143
166
int32_t multiplier;
144
167
int32_t shift;
145
168
146
169
double op_channel_scale = (input_scale * weight_scale) / output_scale;
147
170
148
- computeMultiplierAndShift (op_channel_scale, multiplier, shift,
149
- scale_width);
171
+ if (!computeMultiplierAndShift (op_channel_scale, multiplier, shift, 32 ))
172
+ op->emitError (
173
+ " buildRescaleOpConvOutput: shift must be in the range 2 <= shift "
174
+ " <= 62" );
150
175
151
176
multiplier_arr.push_back (multiplier);
152
177
shift_arr.push_back (static_cast <int8_t >(shift));
@@ -161,10 +186,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
161
186
162
187
auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
163
188
rewriter, op->getLoc (), output_type, conv_val, multiplier_val,
164
- shift_val, rewriter.getI32IntegerAttr (0 ),
165
- rewriter.getI32IntegerAttr (output_zp), rewriter.getBoolAttr (scale32),
166
- rewriter.getBoolAttr (true ), rewriter.getBoolAttr (true ),
167
- rewriter.getBoolAttr (input_unsigned),
189
+ shift_val, input_zp_val.value (), output_zp_val.value (),
190
+ rewriter.getBoolAttr (scale32), rewriter.getStringAttr (" DOUBLE_ROUND" ),
191
+ rewriter.getBoolAttr (true ), rewriter.getBoolAttr (input_unsigned),
168
192
rewriter.getBoolAttr (output_unsigned));
169
193
170
194
return rescale_op.getResult ();
0 commit comments