Skip to content

Commit 274152c

Browse files
authored
[mlir][vector][spirv] Lower vector.to_elements to SPIR-V (#146618)
Implement `vector.to_elements` lowering to SPIR-V. Fixes: #145929
1 parent a880c8e commit 274152c

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,51 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
10221022
}
10231023
};
10241024

1025+
struct VectorToElementOpConvert final
1026+
: OpConversionPattern<vector::ToElementsOp> {
1027+
using OpConversionPattern::OpConversionPattern;
1028+
1029+
LogicalResult
1030+
matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1031+
ConversionPatternRewriter &rewriter) const override {
1032+
1033+
SmallVector<Value> results(toElementsOp->getNumResults());
1034+
Location loc = toElementsOp.getLoc();
1035+
1036+
// Input vectors of size 1 are converted to scalars by the type converter.
1037+
// We cannot use `spirv::CompositeExtractOp` directly in this case.
1038+
// For a scalar source, the result is just the scalar itself.
1039+
if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
1040+
results[0] = adaptor.getSource();
1041+
rewriter.replaceOp(toElementsOp, results);
1042+
return success();
1043+
}
1044+
1045+
Type srcElementType = toElementsOp.getElements().getType().front();
1046+
Type elementType = getTypeConverter()->convertType(srcElementType);
1047+
if (!elementType)
1048+
return rewriter.notifyMatchFailure(
1049+
toElementsOp,
1050+
llvm::formatv("failed to convert element type '{0}' to SPIR-V",
1051+
srcElementType));
1052+
1053+
for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
1054+
// Create an CompositeExtract operation only for results that are not
1055+
// dead.
1056+
if (element.use_empty())
1057+
continue;
1058+
1059+
Value result = rewriter.create<spirv::CompositeExtractOp>(
1060+
loc, elementType, adaptor.getSource(),
1061+
rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
1062+
results[idx] = result;
1063+
}
1064+
1065+
rewriter.replaceOp(toElementsOp, results);
1066+
return success();
1067+
}
1068+
};
1069+
10251070
} // namespace
10261071
#define CL_INT_MAX_MIN_OPS \
10271072
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -1039,8 +1084,8 @@ void mlir::populateVectorToSPIRVPatterns(
10391084
VectorExtractElementOpConvert, VectorExtractOpConvert,
10401085
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
10411086
VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1042-
VectorInsertElementOpConvert, VectorInsertOpConvert,
1043-
VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1087+
VectorToElementOpConvert, VectorInsertElementOpConvert,
1088+
VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
10441089
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
10451090
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
10461091
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,41 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
246246

247247
// -----
248248

249+
// CHECK-LABEL: func.func @to_elements_one_element
250+
// CHECK-SAME: %[[A:.*]]: vector<1xf32>)
251+
// CHECK: %[[ELEM0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1xf32> to f32
252+
// CHECK: return %[[ELEM0]] : f32
253+
func.func @to_elements_one_element(%a: vector<1xf32>) -> (f32) {
254+
%0:1 = vector.to_elements %a : vector<1xf32>
255+
return %0#0 : f32
256+
}
257+
258+
// CHECK-LABEL: func.func @to_elements_no_dead_elements
259+
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
260+
// CHECK: %[[ELEM0:.*]] = spirv.CompositeExtract %[[A]][0 : i32] : vector<4xf32>
261+
// CHECK: %[[ELEM1:.*]] = spirv.CompositeExtract %[[A]][1 : i32] : vector<4xf32>
262+
// CHECK: %[[ELEM2:.*]] = spirv.CompositeExtract %[[A]][2 : i32] : vector<4xf32>
263+
// CHECK: %[[ELEM3:.*]] = spirv.CompositeExtract %[[A]][3 : i32] : vector<4xf32>
264+
// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
265+
func.func @to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
266+
%0:4 = vector.to_elements %a : vector<4xf32>
267+
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
268+
}
269+
270+
// CHECK-LABEL: func.func @to_elements_dead_elements
271+
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
272+
// CHECK-NOT: spirv.CompositeExtract %[[A]][0 : i32]
273+
// CHECK: %[[ELEM1:.*]] = spirv.CompositeExtract %[[A]][1 : i32] : vector<4xf32>
274+
// CHECK-NOT: spirv.CompositeExtract %[[A]][2 : i32]
275+
// CHECK: %[[ELEM3:.*]] = spirv.CompositeExtract %[[A]][3 : i32] : vector<4xf32>
276+
// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
277+
func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
278+
%0:4 = vector.to_elements %a : vector<4xf32>
279+
return %0#1, %0#3 : f32, f32
280+
}
281+
282+
// -----
283+
249284
// CHECK-LABEL: @from_elements_0d
250285
// CHECK-SAME: %[[ARG0:.+]]: f32
251286
// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]

0 commit comments

Comments
 (0)