Skip to content

Commit 3264290

Browse files
authored
[mlir][spirv] Deserialize OpConstantComposite of type Cooperative Matrix (#142786)
Depends on #142784.
1 parent b27ab06 commit 3264290

File tree

5 files changed

+102
-8
lines changed

5 files changed

+102
-8
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,13 @@ void spirv::ConstantOp::print(OpAsmPrinter &printer) {
558558

559559
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
560560
Type opType) {
561+
if (isa<spirv::CooperativeMatrixType>(opType)) {
562+
auto denseAttr = dyn_cast<DenseElementsAttr>(value);
563+
if (!denseAttr || !denseAttr.isSplat())
564+
return op.emitOpError("expected a splat dense attribute for cooperative "
565+
"matrix constant, but found ")
566+
<< denseAttr;
567+
}
561568
if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
562569
auto valueType = llvm::cast<TypedAttr>(value).getType();
563570
if (valueType != opType)

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,11 +1468,11 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
14681468
}
14691469

14701470
auto resultID = operands[1];
1471-
if (auto vectorType = dyn_cast<VectorType>(resultType)) {
1472-
auto attr = DenseElementsAttr::get(vectorType, elements);
1471+
if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1472+
auto attr = DenseElementsAttr::get(shapedType, elements);
14731473
// For normal constants, we just record the attribute (and its type) for
14741474
// later materialization at use sites.
1475-
constantMap.try_emplace(resultID, attr, resultType);
1475+
constantMap.try_emplace(resultID, attr, shapedType);
14761476
} else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
14771477
auto attr = opBuilder.getArrayAttr(elements);
14781478
constantMap.try_emplace(resultID, attr, resultType);

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -845,18 +845,44 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
845845
return 0;
846846
}
847847

848+
int64_t numberOfConstituents = shapedType.getDimSize(dim);
848849
uint32_t resultID = getNextID();
849850
SmallVector<uint32_t, 4> operands = {typeID, resultID};
850-
operands.reserve(shapedType.getDimSize(dim) + 2);
851851
auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
852-
for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
853-
index[dim] = i;
852+
853+
// "If the Result Type is a cooperative matrix type, then there must be only
854+
// one Constituent, with scalar type matching the cooperative matrix Component
855+
// Type, and all components of the matrix are initialized to that value."
856+
// (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
857+
if (isa<spirv::CooperativeMatrixType>(constType)) {
858+
if (!valueAttr.isSplat()) {
859+
emitError(
860+
loc,
861+
"cannot serialize a non-splat value for a cooperative matrix type");
862+
return 0;
863+
}
864+
// numberOfConstituents is 1, so we only need one more elements in the
865+
// SmallVector, so the total is 3 (1 + 2).
866+
operands.reserve(3);
867+
// We set dim directly to `shapedType.getRank()` so the recursive call
868+
// directly returns the scalar type.
854869
if (auto elementID = prepareDenseElementsConstant(
855-
loc, elementType, valueAttr, dim + 1, index)) {
870+
loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
856871
operands.push_back(elementID);
857872
} else {
858873
return 0;
859874
}
875+
} else {
876+
operands.reserve(numberOfConstituents + 2);
877+
for (int i = 0; i < numberOfConstituents; ++i) {
878+
index[dim] = i;
879+
if (auto elementID = prepareDenseElementsConstant(
880+
loc, elementType, valueAttr, dim + 1, index)) {
881+
operands.push_back(elementID);
882+
} else {
883+
return 0;
884+
}
885+
}
860886
}
861887
spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
862888
encodeInstructionInto(typesGlobalValues, opcode, operands);

mlir/test/Dialect/SPIRV/IR/structure-ops.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ func.func @const() -> () {
6262
// CHECK: spirv.Constant dense<1.000000e+00> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
6363
// CHECK: spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
6464
// CHECK: spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
65+
// CHECK: spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
66+
// CHECK: spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
67+
// CHECK: spirv.Constant dense<0> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
68+
// CHECK: spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
6569

6670
%0 = spirv.Constant true
6771
%1 = spirv.Constant 42 : i32
@@ -73,6 +77,10 @@ func.func @const() -> () {
7377
%7 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
7478
%8 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
7579
%9 = spirv.Constant [[dense<3.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1xvector<2xf32>>>
80+
%10 = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
81+
%11 = spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
82+
%12 = spirv.Constant dense<0> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
83+
%13 = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
7684
return
7785
}
7886

@@ -132,6 +140,31 @@ func.func @value_result_num_elements_mismatch() -> () {
132140

133141
// -----
134142

143+
func.func @coop_matrix_const_non_splat() -> () {
144+
// expected-error @+1 {{expected a splat dense attribute for cooperative matrix constant, but found}}
145+
%0 = spirv.Constant dense<[[1.0, 2.0], [3.0, 4.0]]> : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>
146+
return
147+
}
148+
149+
// -----
150+
151+
func.func @coop_matrix_const_non_dense() -> () {
152+
// expected-error @+2 {{floating point value not valid for specified type}}
153+
%0 = spirv.Constant 0.000000e+00 : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
154+
return
155+
}
156+
157+
// -----
158+
159+
func.func @coop_matrix_const_wrong_type() -> () {
160+
// expected-error @below {{unexpected decimal integer literal for a floating point value}}
161+
// expected-note @+1 {{add a trailing dot to make the literal a float}}
162+
%0 = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
163+
return
164+
}
165+
166+
// -----
167+
135168
//===----------------------------------------------------------------------===//
136169
// spirv.EntryPoint
137170
//===----------------------------------------------------------------------===//

mlir/test/Target/SPIRV/constant.mlir

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
1+
// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
22

33
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
44
// CHECK-LABEL: @bool_const
@@ -277,4 +277,32 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
277277
%signed_minus_one = spirv.Constant -1 : si16
278278
spirv.ReturnValue %signed_minus_one : si16
279279
}
280+
281+
// CHECK-LABEL: @coop_matrix_const_zero_f32
282+
spirv.func @coop_matrix_const_zero_f32() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
283+
// CHECK: {{%.*}} = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
284+
%coop = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
285+
spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
286+
}
287+
288+
// CHECK-LABEL: @coop_matrix_const_non_zero_f32
289+
spirv.func @coop_matrix_const_non_zero_f32() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
290+
// CHECK: {{%.*}} = spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
291+
%coop = spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
292+
spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
293+
}
294+
295+
// CHECK-LABEL: @coop_matrix_const_zero_i8
296+
spirv.func @coop_matrix_const_zero_i8() -> (!spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>) "None" {
297+
// CHECK: {{%.*}} = spirv.Constant dense<0> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
298+
%coop = spirv.Constant dense<0> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
299+
spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
300+
}
301+
302+
// CHECK-LABEL: @coop_matrix_const_non_zero_i8
303+
spirv.func @coop_matrix_const_non_zero_i8() -> (!spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>) "None" {
304+
// CHECK: {{%.*}} = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
305+
%coop = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
306+
spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
307+
}
280308
}

0 commit comments

Comments
 (0)