Skip to content

Commit 393a75e

Browse files
yangtetrisYang Bai
andauthored
[mlir][Vector] Add constant folding for vector.from_elements operation (#145849)
### Summary This PR adds a new folding pattern for **vector.from_elements** that canonicalizes it to **arith.constant** when all input operands are constants. ### Implementation Details **Leverages FoldAdaptor capabilities**: Uses adaptor.getElements() to access **pre-computed** constant attributes, avoiding redundant pattern matching on operands. ### Example Transformation ``` Before: %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 %c2_i32 = arith.constant 2 : i32 %c3_i32 = arith.constant 3 : i32 %v = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32> After: %v = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32> ``` --------- Co-authored-by: Yang Bai <yangb@nvidia.com>
1 parent 0a69c83 commit 393a75e

File tree

2 files changed

+69
-12
lines changed

2 files changed

+69
-12
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,18 @@ std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
398398
return {};
399399
}
400400

401+
/// Converts an IntegerAttr to have the specified type if needed.
402+
/// This handles cases where constant attributes have a different type than the
403+
/// target element type. If the input attribute is not an IntegerAttr or already
404+
/// has the correct type, returns it unchanged.
405+
static Attribute convertIntegerAttr(Attribute attr, Type expectedType) {
406+
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
407+
if (intAttr.getType() != expectedType)
408+
return IntegerAttr::get(expectedType, intAttr.getInt());
409+
}
410+
return attr;
411+
}
412+
401413
//===----------------------------------------------------------------------===//
402414
// CombiningKindAttr
403415
//===----------------------------------------------------------------------===//
@@ -2464,8 +2476,37 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
24642476
return {};
24652477
}
24662478

2479+
/// Fold vector.from_elements to a constant when all operands are constants.
2480+
/// Example:
2481+
/// %c1 = arith.constant 1 : i32
2482+
/// %c2 = arith.constant 2 : i32
2483+
/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2484+
/// =>
2485+
/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2486+
///
2487+
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2488+
ArrayRef<Attribute> elements) {
2489+
if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
2490+
return {};
2491+
2492+
auto destVecType = fromElementsOp.getDest().getType();
2493+
auto destEltType = destVecType.getElementType();
2494+
// Constant attributes might have a different type than the return type.
2495+
// Convert them before creating the dense elements attribute.
2496+
auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2497+
return convertIntegerAttr(attr, destEltType);
2498+
});
2499+
2500+
return DenseElementsAttr::get(destVecType, convertedElements);
2501+
}
2502+
24672503
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2468-
return foldFromElementsToElements(*this);
2504+
if (auto res = foldFromElementsToElements(*this))
2505+
return res;
2506+
if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2507+
return res;
2508+
2509+
return {};
24692510
}
24702511

24712512
/// Rewrite a vector.from_elements into a vector.splat if all elements are the
@@ -3332,17 +3373,6 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
33323373

33333374
/// Converts the expected type to an IntegerAttr if there's
33343375
/// a mismatch.
3335-
auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
3336-
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3337-
if (intAttr.getType() != expectedType)
3338-
return IntegerAttr::get(expectedType, intAttr.getInt());
3339-
}
3340-
return attr;
3341-
};
3342-
3343-
// The `convertIntegerAttr` method specifically handles the case
3344-
// for `llvm.mlir.constant` which can hold an attribute with a
3345-
// different type than the return type.
33463376
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
33473377
for (auto value : denseSource.getValues<Attribute>())
33483378
insertedValues.push_back(convertIntegerAttr(value, destEltType));

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3075,6 +3075,33 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
30753075

30763076
// -----
30773077

3078+
// CHECK-LABEL: func @from_elements_all_elements_constant(
3079+
func.func @from_elements_all_elements_constant() -> vector<2x2xi32> {
3080+
%c0_i32 = arith.constant 0 : i32
3081+
%c1_i32 = arith.constant 1 : i32
3082+
%c2_i32 = arith.constant 2 : i32
3083+
%c3_i32 = arith.constant 3 : i32
3084+
// CHECK: %[[RES:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\]\]}}> : vector<2x2xi32>
3085+
%res = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>
3086+
// CHECK: return %[[RES]]
3087+
return %res : vector<2x2xi32>
3088+
}
3089+
3090+
// -----
3091+
3092+
// CHECK-LABEL: func @from_elements_partial_elements_constant(
3093+
// CHECK-SAME: %[[A:.*]]: f32
3094+
func.func @from_elements_partial_elements_constant(%arg0: f32) -> vector<2xf32> {
3095+
// CHECK: %[[C:.*]] = arith.constant 1.000000e+00 : f32
3096+
%c = arith.constant 1.0 : f32
3097+
// CHECK: %[[RES:.*]] = vector.from_elements %[[A]], %[[C]] : vector<2xf32>
3098+
%res = vector.from_elements %arg0, %c : vector<2xf32>
3099+
// CHECK: return %[[RES]]
3100+
return %res : vector<2xf32>
3101+
}
3102+
3103+
// -----
3104+
30783105
// CHECK-LABEL: func @vector_insert_const_regression(
30793106
// CHECK: llvm.mlir.undef
30803107
// CHECK: vector.insert

0 commit comments

Comments
 (0)