|
7 | 7 | using namespace mlir::xegpu::uArch;
|
8 | 8 | using namespace mlir::xegpu::uArch::PVCuArch;
|
9 | 9 |
|
10 |
| -namespace llvm { |
11 |
| -namespace yaml { |
12 |
| -template <> |
13 |
| -struct MappingTraits<XeCoreInfo> { |
14 |
| - static void mapping(IO &io, XeCoreInfo &xe_core) { |
15 |
| - io.mapRequired("num_threads", xe_core.num_threads); |
16 |
| - io.mapRequired("shared_memory", xe_core.shared_memory); |
17 |
| - io.mapRequired("num_vector_units", xe_core.num_vector_units); |
18 |
| - io.mapRequired("num_matrix_units", xe_core.num_matrix_units); |
| 10 | +namespace mlir { |
| 11 | +namespace xegpu { |
| 12 | +namespace uArch { |
| 13 | +namespace PVCuArch { |
| 14 | +bool DPASInstruction::checkSupportedMMATypes(mlir::Type AType, mlir::Type BType, |
| 15 | + mlir::Type CType, |
| 16 | + mlir::Type DType) { |
| 17 | + if (AType.isF16() || BType.isF16()) { |
| 18 | + if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) || |
| 19 | + (!DType.isF32() && !DType.isF16())) |
| 20 | + llvm::errs() |
| 21 | + << "Unsupported dpas combinations of Dst, Acc, A and B matrices, " |
| 22 | + << "Supported types are:\n" |
| 23 | + << " Dst | Acc | A | B \n" |
| 24 | + << " f, hf | f, hf | hf | hf \n" |
| 25 | + << "AType: " << AType << " BType: " << BType << " CType: " << CType |
| 26 | + << " DType: " << DType; |
| 27 | + return false; |
| 28 | + } else if (AType.isBF16() || BType.isBF16()) { |
| 29 | + if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) || |
| 30 | + (!DType.isF32() && !DType.isBF16())) |
| 31 | + llvm::errs() |
| 32 | + << "Unsupported dpas combinations of Dst, Acc, A and B matrices, " |
| 33 | + << "Supported types are:\n" |
| 34 | + << " Dst | Acc | A | B \n" |
| 35 | + << " f, bf | f, bf | bf | bf \n" |
| 36 | + << "AType: " << AType << " BType: " << BType << " CType: " << CType |
| 37 | + << " DType: " << DType; |
| 38 | + return false; |
| 39 | + } else if (AType.isTF32() || BType.isTF32()) { |
| 40 | + if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) || |
| 41 | + (!DType.isF32())) |
| 42 | + llvm::errs() |
| 43 | + << "Unsupported dpas combinations of Dst, Acc, A and B matrices, " |
| 44 | + << "Supported types are:\n" |
| 45 | + << " Dst | Acc | A | B \n" |
| 46 | + << " f | f | tf32 | tf32 \n" |
| 47 | + << "AType: " << AType << " BType: " << BType << " CType: " << CType |
| 48 | + << " DType: " << DType; |
| 49 | + return false; |
| 50 | + } else if (!(AType.isInteger(2) || AType.isInteger(4) || |
| 51 | + AType.isInteger(8)) && |
| 52 | + !(BType.isInteger(2) || BType.isInteger(4) || |
| 53 | + BType.isInteger(8))) { |
| 54 | + llvm::errs() |
| 55 | + << "Unsupported dpas combinations of Dst, Acc, A and B matrices, " |
| 56 | + << "Supported types are:\n" |
| 57 | + << " Dst | Acc | A | B " |
| 58 | + " \n" |
| 59 | + << " ud, d | ud,d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 " |
| 60 | + << "AType: " << AType << " BType: " << BType << " CType: " << CType |
| 61 | + << " DType: " << DType; |
| 62 | + return false; |
19 | 63 | }
|
20 |
| -}; |
21 | 64 |
|
22 |
| -template <> |
23 |
| -struct MappingTraits<Xe2Plus> { |
24 |
| - static void mapping(IO &io, Xe2Plus &xe2plus) { |
25 |
| - io.mapRequired("xe_core", xe2plus.xe_core); |
| 65 | + return true; |
| 66 | +} |
| 67 | + |
| 68 | +std::vector<uint> DPASInstruction::getSupportedM(mlir::Type type) { |
| 69 | + return {1, 2, 3, 4, 5, 6, 7, 8}; |
| 70 | +} |
| 71 | + |
| 72 | +std::vector<uint> DPASInstruction::getSupportedK(mlir::Type type) { |
| 73 | + // assert if data type is not int or float type |
| 74 | + assert(type.isIntOrFloat() && "Matrix type must be int or float"); |
| 75 | + auto bitWidth = type.getIntOrFloatBitWidth(); |
| 76 | + uint kSize = -1; |
| 77 | + switch (bitWidth) { |
| 78 | + case 2: |
| 79 | + kSize = 64; |
| 80 | + break; |
| 81 | + case 4: |
| 82 | + kSize = 64; |
| 83 | + break; |
| 84 | + case 8: |
| 85 | + kSize = 32; |
| 86 | + break; |
| 87 | + case 16: |
| 88 | + kSize = 16; |
| 89 | + break; |
| 90 | + case 32: |
| 91 | + kSize = 8; |
| 92 | + break; |
| 93 | + default: |
| 94 | + llvm_unreachable("Invalid int or float"); |
| 95 | + } |
| 96 | +} |
| 97 | + |
| 98 | +std::vector<uint> DPASInstruction::getSupportedN(mlir::Type type) { |
| 99 | + return {16}; |
| 100 | +} |
| 101 | + |
| 102 | +std::vector<std::pair<unsigned, unsigned>> |
| 103 | +DPASInstruction::getSupportedMatrix(mlir::Type type, MatrixType matrixType) { |
| 104 | + auto combineVectors = [](const std::vector<unsigned> &a, |
| 105 | + const std::vector<unsigned> &b) |
| 106 | + -> std::vector<std::pair<unsigned, unsigned>> { |
| 107 | + std::vector<std::pair<unsigned, unsigned>> result; |
| 108 | + for (unsigned x : a) { |
| 109 | + for (unsigned y : b) { |
| 110 | + result.emplace_back(x, y); |
| 111 | + } |
| 112 | + } |
| 113 | + return result; |
| 114 | + }; |
| 115 | + |
| 116 | + auto M = getSupportedM(type); |
| 117 | + auto K = getSupportedK(type); |
| 118 | + auto N = getSupportedN(type); |
| 119 | + std::vector<std::pair<unsigned, unsigned>> resultMatrix; |
| 120 | + |
| 121 | + switch (matrixType) { |
| 122 | + case MatrixType::MatrixA: |
| 123 | + resultMatrix = combineVectors(M, K); |
| 124 | + break; |
| 125 | + case MatrixType::MatrixB: |
| 126 | + resultMatrix = combineVectors(K, N); |
| 127 | + break; |
| 128 | + case MatrixType::MatrixC: |
| 129 | + resultMatrix = combineVectors(M, N); |
| 130 | + break; |
| 131 | + case MatrixType::MatrixD: |
| 132 | + resultMatrix = combineVectors(M, N); |
| 133 | + break; |
| 134 | + default: |
| 135 | + break; |
26 | 136 | }
|
27 |
| -}; |
28 |
| -} // namespace yaml |
29 |
| -} // namespace llvm |
| 137 | +} |
| 138 | + |
| 139 | +} // namespace PVCuArch |
| 140 | +} // namespace uArch |
| 141 | +} // namespace xegpu |
| 142 | +} // namespace mlir |
| 143 | + |
| 144 | +// namespace llvm { |
| 145 | +// namespace yaml { |
| 146 | +// template <> |
| 147 | +// struct MappingTraits<XeCoreInfo> { |
| 148 | +// static void mapping(IO &io, XeCoreInfo &xe_core) { |
| 149 | +// io.mapRequired("num_threads", xe_core.num_threads); |
| 150 | +// io.mapRequired("shared_memory", xe_core.shared_memory); |
| 151 | +// io.mapRequired("num_vector_units", xe_core.num_vector_units); |
| 152 | +// io.mapRequired("num_matrix_units", xe_core.num_matrix_units); |
| 153 | +// } |
| 154 | +// }; |
| 155 | + |
| 156 | +// template <> |
| 157 | +// struct MappingTraits<Xe2Plus> { |
| 158 | +// static void mapping(IO &io, Xe2Plus &xe2plus) { |
| 159 | +// io.mapRequired("xe_core", xe2plus.xe_core); |
| 160 | +// } |
| 161 | +// }; |
| 162 | +// } // namespace yaml |
| 163 | +// } // namespace llvm |
30 | 164 |
|
31 | 165 | // namespace mlir {
|
32 | 166 | // namespace xe_gpu {
|
|
0 commit comments