@@ -1022,6 +1022,51 @@ struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
1022
1022
}
1023
1023
};
1024
1024
1025
+ struct VectorToElementOpConvert final
1026
+ : OpConversionPattern<vector::ToElementsOp> {
1027
+ using OpConversionPattern::OpConversionPattern;
1028
+
1029
+ LogicalResult
1030
+ matchAndRewrite (vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1031
+ ConversionPatternRewriter &rewriter) const override {
1032
+
1033
+ SmallVector<Value> results (toElementsOp->getNumResults ());
1034
+ Location loc = toElementsOp.getLoc ();
1035
+
1036
+ // Input vectors of size 1 are converted to scalars by the type converter.
1037
+ // We cannot use `spirv::CompositeExtractOp` directly in this case.
1038
+ // For a scalar source, the result is just the scalar itself.
1039
+ if (isa<spirv::ScalarType>(adaptor.getSource ().getType ())) {
1040
+ results[0 ] = adaptor.getSource ();
1041
+ rewriter.replaceOp (toElementsOp, results);
1042
+ return success ();
1043
+ }
1044
+
1045
+ Type srcElementType = toElementsOp.getElements ().getType ().front ();
1046
+ Type elementType = getTypeConverter ()->convertType (srcElementType);
1047
+ if (!elementType)
1048
+ return rewriter.notifyMatchFailure (
1049
+ toElementsOp,
1050
+ llvm::formatv (" failed to convert element type '{0}' to SPIR-V" ,
1051
+ srcElementType));
1052
+
1053
+ for (auto [idx, element] : llvm::enumerate (toElementsOp.getElements ())) {
1054
+ // Create an CompositeExtract operation only for results that are not
1055
+ // dead.
1056
+ if (element.use_empty ())
1057
+ continue ;
1058
+
1059
+ Value result = rewriter.create <spirv::CompositeExtractOp>(
1060
+ loc, elementType, adaptor.getSource (),
1061
+ rewriter.getI32ArrayAttr ({static_cast <int32_t >(idx)}));
1062
+ results[idx] = result;
1063
+ }
1064
+
1065
+ rewriter.replaceOp (toElementsOp, results);
1066
+ return success ();
1067
+ }
1068
+ };
1069
+
1025
1070
} // namespace
1026
1071
#define CL_INT_MAX_MIN_OPS \
1027
1072
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -1039,8 +1084,8 @@ void mlir::populateVectorToSPIRVPatterns(
1039
1084
VectorExtractElementOpConvert, VectorExtractOpConvert,
1040
1085
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
1041
1086
VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
1042
- VectorInsertElementOpConvert, VectorInsertOpConvert ,
1043
- VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1087
+ VectorToElementOpConvert, VectorInsertElementOpConvert ,
1088
+ VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
1044
1089
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
1045
1090
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
1046
1091
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
0 commit comments