@@ -115,9 +115,9 @@ static Value castF32To(Type desType, Value f32, Location loc,
115
115
if (elementType.isF32 ())
116
116
return f32 ;
117
117
if (elementType.getIntOrFloatBitWidth () < 32 )
118
- return rewriter. create < arith::TruncFOp>( loc, desType, f32 );
118
+ return arith::TruncFOp::create (rewriter, loc, desType, f32 );
119
119
if (elementType.getIntOrFloatBitWidth () > 32 )
120
- return rewriter. create < arith::ExtFOp>( loc, desType, f32 );
120
+ return arith::ExtFOp::create (rewriter, loc, desType, f32 );
121
121
llvm_unreachable (" The only 32-bit float type is f32" );
122
122
}
123
123
@@ -139,26 +139,26 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
139
139
Type outElemType = getElementTypeOrSelf (op.getOut ().getType ());
140
140
VectorType extResType = VectorType::get (2 , rewriter.getF32Type ());
141
141
if (!inVecType) {
142
- Value asFloat = rewriter. create < amdgpu::ExtPackedFp8Op>(
142
+ Value asFloat = amdgpu::ExtPackedFp8Op::create (rewriter,
143
143
loc, rewriter.getF32Type (), in, 0 );
144
144
Value result = castF32To (outElemType, asFloat, loc, rewriter);
145
145
rewriter.replaceOp (op, result);
146
146
return success ();
147
147
}
148
148
int64_t numElements = inVecType.getNumElements ();
149
149
150
- Value zero = rewriter. create < arith::ConstantOp>(
150
+ Value zero = arith::ConstantOp::create (rewriter,
151
151
loc, outElemType, rewriter.getFloatAttr (outElemType, 0.0 ));
152
152
VectorType outType = cast<VectorType>(op.getOut ().getType ());
153
153
154
154
if (inVecType.getShape ().empty ()) {
155
155
Value zerodSplat =
156
156
rewriter.createOrFold <vector::SplatOp>(loc, outType, zero);
157
157
Value scalarIn =
158
- rewriter. create < vector::ExtractOp>( loc, in, ArrayRef<int64_t >{});
158
+ vector::ExtractOp::create (rewriter, loc, in, ArrayRef<int64_t >{});
159
159
Value scalarExt =
160
- rewriter. create < arith::ExtFOp>( loc, outElemType, scalarIn);
161
- Value result = rewriter. create < vector::InsertOp>( loc, scalarExt, zerodSplat,
160
+ arith::ExtFOp::create (rewriter, loc, outElemType, scalarIn);
161
+ Value result = vector::InsertOp::create (rewriter, loc, scalarExt, zerodSplat,
162
162
ArrayRef<int64_t >{});
163
163
rewriter.replaceOp (op, result);
164
164
return success ();
@@ -171,32 +171,32 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
171
171
if (inVecType.getRank () > 1 ) {
172
172
inVecType = VectorType::get (SmallVector<int64_t >{numElements},
173
173
inVecType.getElementType ());
174
- in = rewriter. create < vector::ShapeCastOp>( loc, inVecType, in);
174
+ in = vector::ShapeCastOp::create (rewriter, loc, inVecType, in);
175
175
}
176
176
177
177
for (int64_t i = 0 ; i < numElements; i += 4 ) {
178
178
int64_t elemsThisOp = std::min (numElements, i + 4 ) - i;
179
- Value inSlice = rewriter. create < vector::ExtractStridedSliceOp>(
179
+ Value inSlice = vector::ExtractStridedSliceOp::create (rewriter,
180
180
loc, in, i, elemsThisOp, 1 );
181
181
for (int64_t j = 0 ; j < elemsThisOp; j += 2 ) {
182
182
if (i + j + 1 < numElements) { // Convert two 8-bit elements
183
- Value asFloats = rewriter. create < amdgpu::ExtPackedFp8Op>(
183
+ Value asFloats = amdgpu::ExtPackedFp8Op::create (rewriter,
184
184
loc, extResType, inSlice, j / 2 );
185
185
Type desType = VectorType::get (2 , outElemType);
186
186
Value asType = castF32To (desType, asFloats, loc, rewriter);
187
- result = rewriter. create < vector::InsertStridedSliceOp>(
187
+ result = vector::InsertStridedSliceOp::create (rewriter,
188
188
loc, asType, result, i + j, 1 );
189
189
} else { // Convert a 8-bit element
190
- Value asFloat = rewriter. create < amdgpu::ExtPackedFp8Op>(
190
+ Value asFloat = amdgpu::ExtPackedFp8Op::create (rewriter,
191
191
loc, rewriter.getF32Type (), inSlice, j / 2 * 2 );
192
192
Value asType = castF32To (outElemType, asFloat, loc, rewriter);
193
- result = rewriter. create < vector::InsertOp>( loc, asType, result, i + j);
193
+ result = vector::InsertOp::create (rewriter, loc, asType, result, i + j);
194
194
}
195
195
}
196
196
}
197
197
198
198
if (inVecType.getRank () != outType.getRank ()) {
199
- result = rewriter. create < vector::ShapeCastOp>( loc, outType, result);
199
+ result = vector::ShapeCastOp::create (rewriter, loc, outType, result);
200
200
}
201
201
202
202
rewriter.replaceOp (op, result);
@@ -208,9 +208,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
208
208
if (type.isF32 ())
209
209
return value;
210
210
if (type.getIntOrFloatBitWidth () < 32 )
211
- return rewriter. create < arith::ExtFOp>( loc, rewriter.getF32Type (), value);
211
+ return arith::ExtFOp::create (rewriter, loc, rewriter.getF32Type (), value);
212
212
if (type.getIntOrFloatBitWidth () > 32 )
213
- return rewriter. create < arith::TruncFOp>( loc, rewriter.getF32Type (), value);
213
+ return arith::TruncFOp::create (rewriter, loc, rewriter.getF32Type (), value);
214
214
llvm_unreachable (" The only 32-bit float type is f32" );
215
215
}
216
216
@@ -250,13 +250,13 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
250
250
loc, arith::CmpFPredicate::OEQ, source, negInf);
251
251
Value isNan = rewriter.createOrFold <arith::CmpFOp>(
252
252
loc, arith::CmpFPredicate::UNO, source, source);
253
- Value isNonFinite = rewriter. create < arith::OrIOp>(
254
- loc, rewriter. create < arith::OrIOp>( loc, isInf, isNegInf), isNan);
253
+ Value isNonFinite = arith::OrIOp::create (rewriter,
254
+ loc, arith::OrIOp::create (rewriter, loc, isInf, isNegInf), isNan);
255
255
256
- Value clampedBelow = rewriter. create < arith::MaximumFOp>( loc, source, minCst);
257
- Value clamped = rewriter. create < arith::MinimumFOp>( loc, clampedBelow, maxCst);
256
+ Value clampedBelow = arith::MaximumFOp::create (rewriter, loc, source, minCst);
257
+ Value clamped = arith::MinimumFOp::create (rewriter, loc, clampedBelow, maxCst);
258
258
Value res =
259
- rewriter. create < arith::SelectOp>( loc, isNonFinite, source, clamped);
259
+ arith::SelectOp::create (rewriter, loc, isNonFinite, source, clamped);
260
260
return res;
261
261
}
262
262
@@ -290,24 +290,24 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
290
290
VectorType truncResType = VectorType::get (4 , outElemType);
291
291
if (!inVectorTy) {
292
292
Value asFloat = castToF32 (in, loc, rewriter);
293
- Value asF8s = rewriter. create < amdgpu::PackedTrunc2xFp8Op>(
293
+ Value asF8s = amdgpu::PackedTrunc2xFp8Op::create (rewriter,
294
294
loc, truncResType, asFloat, /* sourceB=*/ nullptr , 0 ,
295
295
/* existing=*/ nullptr );
296
- Value result = rewriter. create < vector::ExtractOp>( loc, asF8s, 0 );
296
+ Value result = vector::ExtractOp::create (rewriter, loc, asF8s, 0 );
297
297
rewriter.replaceOp (op, result);
298
298
return success ();
299
299
}
300
300
301
301
int64_t numElements = outVecType.getNumElements ();
302
- Value zero = rewriter. create < arith::ConstantOp>(
302
+ Value zero = arith::ConstantOp::create (rewriter,
303
303
loc, outElemType, rewriter.getFloatAttr (outElemType, 0.0 ));
304
304
if (outVecType.getShape ().empty ()) {
305
305
Value scalarIn =
306
- rewriter. create < vector::ExtractOp>( loc, in, ArrayRef<int64_t >{});
306
+ vector::ExtractOp::create (rewriter, loc, in, ArrayRef<int64_t >{});
307
307
// Recurse to send the 0-D vector case to the 1-D vector case
308
308
Value scalarTrunc =
309
- rewriter. create < arith::TruncFOp>( loc, outElemType, scalarIn);
310
- Value result = rewriter. create < vector::InsertOp>( loc, scalarTrunc, zero,
309
+ arith::TruncFOp::create (rewriter, loc, outElemType, scalarIn);
310
+ Value result = vector::InsertOp::create (rewriter, loc, scalarTrunc, zero,
311
311
ArrayRef<int64_t >{});
312
312
rewriter.replaceOp (op, result);
313
313
return success ();
@@ -320,32 +320,32 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
320
320
if (inVectorTy.getRank () > 1 ) {
321
321
inVectorTy = VectorType::get (SmallVector<int64_t >{numElements},
322
322
inVectorTy.getElementType ());
323
- in = rewriter. create < vector::ShapeCastOp>( loc, inVectorTy, in);
323
+ in = vector::ShapeCastOp::create (rewriter, loc, inVectorTy, in);
324
324
}
325
325
326
326
for (int64_t i = 0 ; i < numElements; i += 4 ) {
327
327
int64_t elemsThisOp = std::min (numElements, i + 4 ) - i;
328
328
Value thisResult = nullptr ;
329
329
for (int64_t j = 0 ; j < elemsThisOp; j += 2 ) {
330
- Value elemA = rewriter. create < vector::ExtractOp>( loc, in, i + j);
330
+ Value elemA = vector::ExtractOp::create (rewriter, loc, in, i + j);
331
331
Value asFloatA = castToF32 (elemA, loc, rewriter);
332
332
Value asFloatB = nullptr ;
333
333
if (j + 1 < elemsThisOp) {
334
- Value elemB = rewriter. create < vector::ExtractOp>( loc, in, i + j + 1 );
334
+ Value elemB = vector::ExtractOp::create (rewriter, loc, in, i + j + 1 );
335
335
asFloatB = castToF32 (elemB, loc, rewriter);
336
336
}
337
- thisResult = rewriter. create < amdgpu::PackedTrunc2xFp8Op>(
337
+ thisResult = amdgpu::PackedTrunc2xFp8Op::create (rewriter,
338
338
loc, truncResType, asFloatA, asFloatB, j / 2 , thisResult);
339
339
}
340
340
if (elemsThisOp < 4 )
341
- thisResult = rewriter. create < vector::ExtractStridedSliceOp>(
341
+ thisResult = vector::ExtractStridedSliceOp::create (rewriter,
342
342
loc, thisResult, 0 , elemsThisOp, 1 );
343
- result = rewriter. create < vector::InsertStridedSliceOp>( loc, thisResult,
343
+ result = vector::InsertStridedSliceOp::create (rewriter, loc, thisResult,
344
344
result, i, 1 );
345
345
}
346
346
347
347
if (inVectorTy.getRank () != outVecType.getRank ()) {
348
- result = rewriter. create < vector::ShapeCastOp>( loc, outVecType, result);
348
+ result = vector::ShapeCastOp::create (rewriter, loc, outVecType, result);
349
349
}
350
350
351
351
rewriter.replaceOp (op, result);
@@ -373,10 +373,10 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
373
373
374
374
// Handle the case where input type is not a vector type
375
375
if (!inVectorTy) {
376
- auto sourceB = rewriter. create < LLVM::PoisonOp>( loc, rewriter.getF32Type ());
376
+ auto sourceB = LLVM::PoisonOp::create (rewriter, loc, rewriter.getF32Type ());
377
377
Value asF16s =
378
- rewriter. create < ROCDL::CvtPkRtz>( loc, truncResType, in, sourceB);
379
- Value result = rewriter. create < vector::ExtractOp>( loc, asF16s, 0 );
378
+ ROCDL::CvtPkRtz::create (rewriter, loc, truncResType, in, sourceB);
379
+ Value result = vector::ExtractOp::create (rewriter, loc, asF16s, 0 );
380
380
rewriter.replaceOp (op, result);
381
381
return success ();
382
382
}
@@ -388,33 +388,33 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
388
388
if (inVectorTy.getRank () > 1 ) {
389
389
inVectorTy = VectorType::get (SmallVector<int64_t >{numElements},
390
390
inVectorTy.getElementType ());
391
- in = rewriter. create < vector::ShapeCastOp>( loc, inVectorTy, in);
391
+ in = vector::ShapeCastOp::create (rewriter, loc, inVectorTy, in);
392
392
}
393
393
394
394
// Handle the vector case. We also handle the (uncommon) case where the vector
395
395
// length is odd
396
396
for (int64_t i = 0 ; i < numElements; i += 2 ) {
397
397
int64_t elemsThisOp = std::min (numElements, i + 2 ) - i;
398
398
Value thisResult = nullptr ;
399
- Value elemA = rewriter. create < vector::ExtractOp>( loc, in, i);
400
- Value elemB = rewriter. create < LLVM::PoisonOp>( loc, rewriter.getF32Type ());
399
+ Value elemA = vector::ExtractOp::create (rewriter, loc, in, i);
400
+ Value elemB = LLVM::PoisonOp::create (rewriter, loc, rewriter.getF32Type ());
401
401
402
402
if (elemsThisOp == 2 ) {
403
- elemB = rewriter. create < vector::ExtractOp>( loc, in, i + 1 );
403
+ elemB = vector::ExtractOp::create (rewriter, loc, in, i + 1 );
404
404
}
405
405
406
406
thisResult =
407
- rewriter. create < ROCDL::CvtPkRtz>( loc, truncResType, elemA, elemB);
407
+ ROCDL::CvtPkRtz::create (rewriter, loc, truncResType, elemA, elemB);
408
408
// Place back the truncated result into the possibly larger vector. If we
409
409
// are operating on a size 2 vector, these operations should be folded away
410
- thisResult = rewriter. create < vector::ExtractStridedSliceOp>(
410
+ thisResult = vector::ExtractStridedSliceOp::create (rewriter,
411
411
loc, thisResult, 0 , elemsThisOp, 1 );
412
- result = rewriter. create < vector::InsertStridedSliceOp>( loc, thisResult,
412
+ result = vector::InsertStridedSliceOp::create (rewriter, loc, thisResult,
413
413
result, i, 1 );
414
414
}
415
415
416
416
if (inVectorTy.getRank () != outVecType.getRank ()) {
417
- result = rewriter. create < vector::ShapeCastOp>( loc, outVecType, result);
417
+ result = vector::ShapeCastOp::create (rewriter, loc, outVecType, result);
418
418
}
419
419
420
420
rewriter.replaceOp (op, result);
0 commit comments