@@ -398,6 +398,18 @@ std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
398
398
return {};
399
399
}
400
400
401
+ // / Converts an IntegerAttr to have the specified type if needed.
402
+ // / This handles cases where constant attributes have a different type than the
403
+ // / target element type. If the input attribute is not an IntegerAttr or already
404
+ // / has the correct type, returns it unchanged.
405
+ static Attribute convertIntegerAttr (Attribute attr, Type expectedType) {
406
+ if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
407
+ if (intAttr.getType () != expectedType)
408
+ return IntegerAttr::get (expectedType, intAttr.getInt ());
409
+ }
410
+ return attr;
411
+ }
412
+
401
413
// ===----------------------------------------------------------------------===//
402
414
// CombiningKindAttr
403
415
// ===----------------------------------------------------------------------===//
@@ -2464,8 +2476,37 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2464
2476
return {};
2465
2477
}
2466
2478
2479
+ // / Fold vector.from_elements to a constant when all operands are constants.
2480
+ // / Example:
2481
+ // / %c1 = arith.constant 1 : i32
2482
+ // / %c2 = arith.constant 2 : i32
2483
+ // / %v = vector.from_elements %c1, %c2 : vector<2xi32>
2484
+ // / =>
2485
+ // / %v = arith.constant dense<[1, 2]> : vector<2xi32>
2486
+ // /
2487
+ static OpFoldResult foldFromElementsToConstant (FromElementsOp fromElementsOp,
2488
+ ArrayRef<Attribute> elements) {
2489
+ if (llvm::any_of (elements, [](Attribute attr) { return !attr; }))
2490
+ return {};
2491
+
2492
+ auto destVecType = fromElementsOp.getDest ().getType ();
2493
+ auto destEltType = destVecType.getElementType ();
2494
+ // Constant attributes might have a different type than the return type.
2495
+ // Convert them before creating the dense elements attribute.
2496
+ auto convertedElements = llvm::map_to_vector (elements, [&](Attribute attr) {
2497
+ return convertIntegerAttr (attr, destEltType);
2498
+ });
2499
+
2500
+ return DenseElementsAttr::get (destVecType, convertedElements);
2501
+ }
2502
+
2467
2503
OpFoldResult FromElementsOp::fold (FoldAdaptor adaptor) {
2468
- return foldFromElementsToElements (*this );
2504
+ if (auto res = foldFromElementsToElements (*this ))
2505
+ return res;
2506
+ if (auto res = foldFromElementsToConstant (*this , adaptor.getElements ()))
2507
+ return res;
2508
+
2509
+ return {};
2469
2510
}
2470
2511
2471
2512
// / Rewrite a vector.from_elements into a vector.splat if all elements are the
@@ -3332,17 +3373,6 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
3332
3373
3333
3374
// / Converts the expected type to an IntegerAttr if there's
3334
3375
// / a mismatch.
3335
- auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
3336
- if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3337
- if (intAttr.getType () != expectedType)
3338
- return IntegerAttr::get (expectedType, intAttr.getInt ());
3339
- }
3340
- return attr;
3341
- };
3342
-
3343
- // The `convertIntegerAttr` method specifically handles the case
3344
- // for `llvm.mlir.constant` which can hold an attribute with a
3345
- // different type than the return type.
3346
3376
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3347
3377
for (auto value : denseSource.getValues <Attribute>())
3348
3378
insertedValues.push_back (convertIntegerAttr (value, destEltType));
0 commit comments