@@ -89,9 +89,9 @@ static Value castF32To(Type desType, Value f32, Location loc,
89
89
if (elementType.isF32 ())
90
90
return f32 ;
91
91
if (elementType.getIntOrFloatBitWidth () < 32 )
92
- return rewriter. create < arith::TruncFOp>( loc, desType, f32 );
92
+ return arith::TruncFOp::create (rewriter, loc, desType, f32 );
93
93
if (elementType.getIntOrFloatBitWidth () > 32 )
94
- return rewriter. create < arith::ExtFOp>( loc, desType, f32 );
94
+ return arith::ExtFOp::create (rewriter, loc, desType, f32 );
95
95
llvm_unreachable (" The only 32-bit float type is f32" );
96
96
}
97
97
@@ -113,26 +113,26 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
113
113
Type outElemType = getElementTypeOrSelf (op.getOut ().getType ());
114
114
VectorType extResType = VectorType::get (2 , rewriter.getF32Type ());
115
115
if (!inVecType) {
116
- Value asFloat = rewriter. create < amdgpu::ExtPackedFp8Op>(
116
+ Value asFloat = amdgpu::ExtPackedFp8Op::create (rewriter,
117
117
loc, rewriter.getF32Type (), in, 0 );
118
118
Value result = castF32To (outElemType, asFloat, loc, rewriter);
119
119
rewriter.replaceOp (op, result);
120
120
return success ();
121
121
}
122
122
int64_t numElements = inVecType.getNumElements ();
123
123
124
- Value zero = rewriter. create < arith::ConstantOp>(
124
+ Value zero = arith::ConstantOp::create (rewriter,
125
125
loc, outElemType, rewriter.getFloatAttr (outElemType, 0.0 ));
126
126
VectorType outType = cast<VectorType>(op.getOut ().getType ());
127
127
128
128
if (inVecType.getShape ().empty ()) {
129
129
Value zerodSplat =
130
130
rewriter.createOrFold <vector::SplatOp>(loc, outType, zero);
131
131
Value scalarIn =
132
- rewriter. create < vector::ExtractOp>( loc, in, ArrayRef<int64_t >{});
132
+ vector::ExtractOp::create (rewriter, loc, in, ArrayRef<int64_t >{});
133
133
Value scalarExt =
134
- rewriter. create < arith::ExtFOp>( loc, outElemType, scalarIn);
135
- Value result = rewriter. create < vector::InsertOp>( loc, scalarExt, zerodSplat,
134
+ arith::ExtFOp::create (rewriter, loc, outElemType, scalarIn);
135
+ Value result = vector::InsertOp::create (rewriter, loc, scalarExt, zerodSplat,
136
136
ArrayRef<int64_t >{});
137
137
rewriter.replaceOp (op, result);
138
138
return success ();
@@ -145,32 +145,32 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
145
145
if (inVecType.getRank () > 1 ) {
146
146
inVecType = VectorType::get (SmallVector<int64_t >{numElements},
147
147
inVecType.getElementType ());
148
- in = rewriter. create < vector::ShapeCastOp>( loc, inVecType, in);
148
+ in = vector::ShapeCastOp::create (rewriter, loc, inVecType, in);
149
149
}
150
150
151
151
for (int64_t i = 0 ; i < numElements; i += 4 ) {
152
152
int64_t elemsThisOp = std::min (numElements, i + 4 ) - i;
153
- Value inSlice = rewriter. create < vector::ExtractStridedSliceOp>(
153
+ Value inSlice = vector::ExtractStridedSliceOp::create (rewriter,
154
154
loc, in, i, elemsThisOp, 1 );
155
155
for (int64_t j = 0 ; j < elemsThisOp; j += 2 ) {
156
156
if (i + j + 1 < numElements) { // Convert two 8-bit elements
157
- Value asFloats = rewriter. create < amdgpu::ExtPackedFp8Op>(
157
+ Value asFloats = amdgpu::ExtPackedFp8Op::create (rewriter,
158
158
loc, extResType, inSlice, j / 2 );
159
159
Type desType = VectorType::get (2 , outElemType);
160
160
Value asType = castF32To (desType, asFloats, loc, rewriter);
161
- result = rewriter. create < vector::InsertStridedSliceOp>(
161
+ result = vector::InsertStridedSliceOp::create (rewriter,
162
162
loc, asType, result, i + j, 1 );
163
163
} else { // Convert a 8-bit element
164
- Value asFloat = rewriter. create < amdgpu::ExtPackedFp8Op>(
164
+ Value asFloat = amdgpu::ExtPackedFp8Op::create (rewriter,
165
165
loc, rewriter.getF32Type (), inSlice, j / 2 * 2 );
166
166
Value asType = castF32To (outElemType, asFloat, loc, rewriter);
167
- result = rewriter. create < vector::InsertOp>( loc, asType, result, i + j);
167
+ result = vector::InsertOp::create (rewriter, loc, asType, result, i + j);
168
168
}
169
169
}
170
170
}
171
171
172
172
if (inVecType.getRank () != outType.getRank ()) {
173
- result = rewriter. create < vector::ShapeCastOp>( loc, outType, result);
173
+ result = vector::ShapeCastOp::create (rewriter, loc, outType, result);
174
174
}
175
175
176
176
rewriter.replaceOp (op, result);
@@ -182,9 +182,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
182
182
if (type.isF32 ())
183
183
return value;
184
184
if (type.getIntOrFloatBitWidth () < 32 )
185
- return rewriter. create < arith::ExtFOp>( loc, rewriter.getF32Type (), value);
185
+ return arith::ExtFOp::create (rewriter, loc, rewriter.getF32Type (), value);
186
186
if (type.getIntOrFloatBitWidth () > 32 )
187
- return rewriter. create < arith::TruncFOp>( loc, rewriter.getF32Type (), value);
187
+ return arith::TruncFOp::create (rewriter, loc, rewriter.getF32Type (), value);
188
188
llvm_unreachable (" The only 32-bit float type is f32" );
189
189
}
190
190
@@ -224,13 +224,13 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
224
224
loc, arith::CmpFPredicate::OEQ, source, negInf);
225
225
Value isNan = rewriter.createOrFold <arith::CmpFOp>(
226
226
loc, arith::CmpFPredicate::UNO, source, source);
227
- Value isNonFinite = rewriter. create < arith::OrIOp>(
228
- loc, rewriter. create < arith::OrIOp>( loc, isInf, isNegInf), isNan);
227
+ Value isNonFinite = arith::OrIOp::create (rewriter,
228
+ loc, arith::OrIOp::create (rewriter, loc, isInf, isNegInf), isNan);
229
229
230
- Value clampedBelow = rewriter. create < arith::MaximumFOp>( loc, source, minCst);
231
- Value clamped = rewriter. create < arith::MinimumFOp>( loc, clampedBelow, maxCst);
230
+ Value clampedBelow = arith::MaximumFOp::create (rewriter, loc, source, minCst);
231
+ Value clamped = arith::MinimumFOp::create (rewriter, loc, clampedBelow, maxCst);
232
232
Value res =
233
- rewriter. create < arith::SelectOp>( loc, isNonFinite, source, clamped);
233
+ arith::SelectOp::create (rewriter, loc, isNonFinite, source, clamped);
234
234
return res;
235
235
}
236
236
@@ -264,24 +264,24 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
264
264
VectorType truncResType = VectorType::get (4 , outElemType);
265
265
if (!inVectorTy) {
266
266
Value asFloat = castToF32 (in, loc, rewriter);
267
- Value asF8s = rewriter. create < amdgpu::PackedTrunc2xFp8Op>(
267
+ Value asF8s = amdgpu::PackedTrunc2xFp8Op::create (rewriter,
268
268
loc, truncResType, asFloat, /* sourceB=*/ nullptr , 0 ,
269
269
/* existing=*/ nullptr );
270
- Value result = rewriter. create < vector::ExtractOp>( loc, asF8s, 0 );
270
+ Value result = vector::ExtractOp::create (rewriter, loc, asF8s, 0 );
271
271
rewriter.replaceOp (op, result);
272
272
return success ();
273
273
}
274
274
275
275
int64_t numElements = outVecType.getNumElements ();
276
- Value zero = rewriter. create < arith::ConstantOp>(
276
+ Value zero = arith::ConstantOp::create (rewriter,
277
277
loc, outElemType, rewriter.getFloatAttr (outElemType, 0.0 ));
278
278
if (outVecType.getShape ().empty ()) {
279
279
Value scalarIn =
280
- rewriter. create < vector::ExtractOp>( loc, in, ArrayRef<int64_t >{});
280
+ vector::ExtractOp::create (rewriter, loc, in, ArrayRef<int64_t >{});
281
281
// Recurse to send the 0-D vector case to the 1-D vector case
282
282
Value scalarTrunc =
283
- rewriter. create < arith::TruncFOp>( loc, outElemType, scalarIn);
284
- Value result = rewriter. create < vector::InsertOp>( loc, scalarTrunc, zero,
283
+ arith::TruncFOp::create (rewriter, loc, outElemType, scalarIn);
284
+ Value result = vector::InsertOp::create (rewriter, loc, scalarTrunc, zero,
285
285
ArrayRef<int64_t >{});
286
286
rewriter.replaceOp (op, result);
287
287
return success ();
@@ -294,32 +294,32 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
294
294
if (inVectorTy.getRank () > 1 ) {
295
295
inVectorTy = VectorType::get (SmallVector<int64_t >{numElements},
296
296
inVectorTy.getElementType ());
297
- in = rewriter. create < vector::ShapeCastOp>( loc, inVectorTy, in);
297
+ in = vector::ShapeCastOp::create (rewriter, loc, inVectorTy, in);
298
298
}
299
299
300
300
for (int64_t i = 0 ; i < numElements; i += 4 ) {
301
301
int64_t elemsThisOp = std::min (numElements, i + 4 ) - i;
302
302
Value thisResult = nullptr ;
303
303
for (int64_t j = 0 ; j < elemsThisOp; j += 2 ) {
304
- Value elemA = rewriter. create < vector::ExtractOp>( loc, in, i + j);
304
+ Value elemA = vector::ExtractOp::create (rewriter, loc, in, i + j);
305
305
Value asFloatA = castToF32 (elemA, loc, rewriter);
306
306
Value asFloatB = nullptr ;
307
307
if (j + 1 < elemsThisOp) {
308
- Value elemB = rewriter. create < vector::ExtractOp>( loc, in, i + j + 1 );
308
+ Value elemB = vector::ExtractOp::create (rewriter, loc, in, i + j + 1 );
309
309
asFloatB = castToF32 (elemB, loc, rewriter);
310
310
}
311
- thisResult = rewriter. create < amdgpu::PackedTrunc2xFp8Op>(
311
+ thisResult = amdgpu::PackedTrunc2xFp8Op::create (rewriter,
312
312
loc, truncResType, asFloatA, asFloatB, j / 2 , thisResult);
313
313
}
314
314
if (elemsThisOp < 4 )
315
- thisResult = rewriter. create < vector::ExtractStridedSliceOp>(
315
+ thisResult = vector::ExtractStridedSliceOp::create (rewriter,
316
316
loc, thisResult, 0 , elemsThisOp, 1 );
317
- result = rewriter. create < vector::InsertStridedSliceOp>( loc, thisResult,
317
+ result = vector::InsertStridedSliceOp::create (rewriter, loc, thisResult,
318
318
result, i, 1 );
319
319
}
320
320
321
321
if (inVectorTy.getRank () != outVecType.getRank ()) {
322
- result = rewriter. create < vector::ShapeCastOp>( loc, outVecType, result);
322
+ result = vector::ShapeCastOp::create (rewriter, loc, outVecType, result);
323
323
}
324
324
325
325
rewriter.replaceOp (op, result);
@@ -347,10 +347,10 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
347
347
348
348
// Handle the case where input type is not a vector type
349
349
if (!inVectorTy) {
350
- auto sourceB = rewriter. create < LLVM::PoisonOp>( loc, rewriter.getF32Type ());
350
+ auto sourceB = LLVM::PoisonOp::create (rewriter, loc, rewriter.getF32Type ());
351
351
Value asF16s =
352
- rewriter. create < ROCDL::CvtPkRtz>( loc, truncResType, in, sourceB);
353
- Value result = rewriter. create < vector::ExtractOp>( loc, asF16s, 0 );
352
+ ROCDL::CvtPkRtz::create (rewriter, loc, truncResType, in, sourceB);
353
+ Value result = vector::ExtractOp::create (rewriter, loc, asF16s, 0 );
354
354
rewriter.replaceOp (op, result);
355
355
return success ();
356
356
}
@@ -362,33 +362,33 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
362
362
if (inVectorTy.getRank () > 1 ) {
363
363
inVectorTy = VectorType::get (SmallVector<int64_t >{numElements},
364
364
inVectorTy.getElementType ());
365
- in = rewriter. create < vector::ShapeCastOp>( loc, inVectorTy, in);
365
+ in = vector::ShapeCastOp::create (rewriter, loc, inVectorTy, in);
366
366
}
367
367
368
368
// Handle the vector case. We also handle the (uncommon) case where the vector
369
369
// length is odd
370
370
for (int64_t i = 0 ; i < numElements; i += 2 ) {
371
371
int64_t elemsThisOp = std::min (numElements, i + 2 ) - i;
372
372
Value thisResult = nullptr ;
373
- Value elemA = rewriter. create < vector::ExtractOp>( loc, in, i);
374
- Value elemB = rewriter. create < LLVM::PoisonOp>( loc, rewriter.getF32Type ());
373
+ Value elemA = vector::ExtractOp::create (rewriter, loc, in, i);
374
+ Value elemB = LLVM::PoisonOp::create (rewriter, loc, rewriter.getF32Type ());
375
375
376
376
if (elemsThisOp == 2 ) {
377
- elemB = rewriter. create < vector::ExtractOp>( loc, in, i + 1 );
377
+ elemB = vector::ExtractOp::create (rewriter, loc, in, i + 1 );
378
378
}
379
379
380
380
thisResult =
381
- rewriter. create < ROCDL::CvtPkRtz>( loc, truncResType, elemA, elemB);
381
+ ROCDL::CvtPkRtz::create (rewriter, loc, truncResType, elemA, elemB);
382
382
// Place back the truncated result into the possibly larger vector. If we
383
383
// are operating on a size 2 vector, these operations should be folded away
384
- thisResult = rewriter. create < vector::ExtractStridedSliceOp>(
384
+ thisResult = vector::ExtractStridedSliceOp::create (rewriter,
385
385
loc, thisResult, 0 , elemsThisOp, 1 );
386
- result = rewriter. create < vector::InsertStridedSliceOp>( loc, thisResult,
386
+ result = vector::InsertStridedSliceOp::create (rewriter, loc, thisResult,
387
387
result, i, 1 );
388
388
}
389
389
390
390
if (inVectorTy.getRank () != outVecType.getRank ()) {
391
- result = rewriter. create < vector::ShapeCastOp>( loc, outVecType, result);
391
+ result = vector::ShapeCastOp::create (rewriter, loc, outVecType, result);
392
392
}
393
393
394
394
rewriter.replaceOp (op, result);
0 commit comments