From 5fc7342a836cea6409719bdbdb5b69e8e4e7c570 Mon Sep 17 00:00:00 2001 From: Jerry Shih Date: Mon, 6 May 2024 01:55:50 -0700 Subject: [PATCH] [mlir][linalg][nfc] Fix `linalg.matmul_transpose_a` def. The `matmul_transpose_a` input data format should be `KxM * KxN` instead of current `KxN * KxM` format. It's a NFC fix. --- .../mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml | 2 +- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index fad234a9dcae9..abb79278eddd4 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -1336,7 +1336,7 @@ structured_op: !LinalgStructuredOpConfig name: C kind: output_tensor type_var: U - shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)> + shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> - !LinalgOperandDefConfig name: cast kind: type_fn_attr diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 43410aaa6af1b..59b3ba914eaab 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -429,8 +429,8 @@ def quantized_matmul( @linalg_structured_op def matmul_transpose_a( - A=TensorDef(T1, S.K, S.N), - B=TensorDef(T2, S.K, S.M), + A=TensorDef(T1, S.K, S.M), + B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True), cast=TypeFnAttrDef(default=TypeFn.cast_signed), ):