Skip to content

Commit 4a68562

Browse files
authored
[mlir][spirv] Reject coop matrix operands on unsupported arithmetic ops (#147230)
Cooperative matrix operands are only supported for `add/sub/mul/div` binary arithmetic ops, but currently all binary arithmetic ops accept cooperative matrix operands, including `mod/rem`. This change fixes this behaviour.
1 parent 517cda1 commit 4a68562

File tree

2 files changed

+81
-21
lines changed

2 files changed

+81
-21
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,25 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
2424
SPIRV_BinaryOp<mnemonic, type, type,
2525
!listconcat(traits,
2626
[Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
27-
// In addition to normal types arithmetic instructions can support cooperative
28-
// matrix.
27+
let arguments = (ins
28+
SPIRV_ScalarOrVectorOf<type>:$operand1,
29+
SPIRV_ScalarOrVectorOf<type>:$operand2
30+
);
31+
32+
let results = (outs
33+
SPIRV_ScalarOrVectorOf<type>:$result
34+
);
35+
let assemblyFormat = "operands attr-dict `:` type($result)";
36+
}
37+
38+
class SPIRV_ArithmeticBinaryOpWithCoopMatrix<string mnemonic, Type type,
39+
list<Trait> traits = []> :
40+
// Operands type same as result type.
41+
SPIRV_BinaryOp<mnemonic, type, type,
42+
!listconcat(traits,
43+
[Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
44+
// In addition to normal types these arithmetic instructions can support
45+
// cooperative matrix.
2946
let arguments = (ins
3047
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
3148
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
@@ -82,7 +99,7 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
8299

83100
// -----
84101

85-
def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> {
102+
def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FAdd", SPIRV_Float, [Commutative]> {
86103
let summary = "Floating-point addition of Operand 1 and Operand 2.";
87104

88105
let description = [{
@@ -104,7 +121,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]>
104121

105122
// -----
106123

107-
def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> {
124+
def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FDiv", SPIRV_Float, []> {
108125
let summary = "Floating-point division of Operand 1 divided by Operand 2.";
109126

110127
let description = [{
@@ -154,7 +171,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> {
154171

155172
// -----
156173

157-
def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> {
174+
def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FMul", SPIRV_Float, [Commutative]> {
158175
let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
159176

160177
let description = [{
@@ -229,7 +246,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> {
229246

230247
// -----
231248

232-
def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
249+
def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FSub", SPIRV_Float, []> {
233250
let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
234251

235252
let description = [{
@@ -251,9 +268,9 @@ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
251268

252269
// -----
253270

254-
def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd",
255-
SPIRV_Integer,
256-
[Commutative, UsableInSpecConstantOp]> {
271+
def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IAdd",
272+
SPIRV_Integer,
273+
[Commutative, UsableInSpecConstantOp]> {
257274
let summary = "Integer addition of Operand 1 and Operand 2.";
258275

259276
let description = [{
@@ -322,9 +339,9 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
322339

323340
// -----
324341

325-
def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
326-
SPIRV_Integer,
327-
[Commutative, UsableInSpecConstantOp]> {
342+
def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IMul",
343+
SPIRV_Integer,
344+
[Commutative, UsableInSpecConstantOp]> {
328345
let summary = "Integer multiplication of Operand 1 and Operand 2.";
329346

330347
let description = [{
@@ -354,9 +371,9 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
354371

355372
// -----
356373

357-
def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
358-
SPIRV_Integer,
359-
[UsableInSpecConstantOp]> {
374+
def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"ISub",
375+
SPIRV_Integer,
376+
[UsableInSpecConstantOp]> {
360377
let summary = "Integer subtraction of Operand 2 from Operand 1.";
361378

362379
let description = [{
@@ -460,9 +477,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
460477

461478
// -----
462479

463-
def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
464-
SPIRV_Integer,
465-
[UsableInSpecConstantOp]> {
480+
def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"SDiv",
481+
SPIRV_Integer,
482+
[UsableInSpecConstantOp]> {
466483
let summary = "Signed-integer division of Operand 1 divided by Operand 2.";
467484

468485
let description = [{
@@ -622,9 +639,9 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
622639

623640
// -----
624641

625-
def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
626-
SPIRV_Integer,
627-
[UnsignedOp, UsableInSpecConstantOp]> {
642+
def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"UDiv",
643+
SPIRV_Integer,
644+
[UnsignedOp, UsableInSpecConstantOp]> {
628645
let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
629646

630647
let description = [{

mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,3 +577,46 @@ spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, Matrix
577577
%p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16
578578
spirv.Return
579579
}
580+
581+
// -----
582+
583+
// These binary arithmetic instructions do not support coop matrix operands.
584+
585+
spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
586+
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
587+
%p = spirv.FMod %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
588+
spirv.Return
589+
}
590+
591+
// -----
592+
593+
spirv.func @frem(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
594+
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
595+
%p = spirv.FRem %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
596+
spirv.Return
597+
}
598+
599+
// -----
600+
spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
601+
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
602+
%p = spirv.SMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
603+
spirv.Return
604+
}
605+
606+
// -----
607+
608+
spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
609+
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
610+
%p = spirv.SRem %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
611+
spirv.Return
612+
}
613+
614+
// -----
615+
616+
spirv.func @umod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
617+
// expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
618+
%p = spirv.UMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
619+
spirv.Return
620+
}
621+
622+
// -----

0 commit comments

Comments
 (0)