Commit 75524de

[mlir][xegpu] Relax rank restriction of TensorDescType (#145916)
1 parent d286540

15 files changed: +277 additions, -178 deletions


mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
Lines changed: 20 additions & 5 deletions

@@ -42,9 +42,18 @@ def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_td
   }];

   let parameters = (ins
-    OptionalParameter<"MemorySpaceAttr">: $memory_space,
-    OptionalParameter<"IntegerAttr", "1">: $array_length,
-    OptionalParameter<"BoolAttr", "true">: $boundary_check
+    DefaultValuedParameter<
+      "MemorySpaceAttr",
+      "MemorySpaceAttr::get($_ctxt, xegpu::MemorySpace::Global)",
+      "Data memory location">: $memory_space,
+    DefaultValuedParameter<
+      "IntegerAttr",
+      "IntegerAttr::get(IntegerType::get($_ctxt, 64), 1)",
+      "Number of continuous blocks to load">: $array_length,
+    DefaultValuedParameter<
+      "BoolAttr",
+      "BoolAttr::get($_ctxt, true)",
+      "Checking the out of boundary access">: $boundary_check
   );

   let builders = [

@@ -67,8 +76,8 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
     TensorDesc is located, `Global` device memory or `Shared` local memory.
     It is default to `Global`.

-    2. `chunk_size`: indicates number of contiguous elements accessed for each
-       offset, default is 1. It is used with `scattered` attr only.
+    2. `chunk_size`: Specifies the number of contiguous elements accessed per offset.
+       The default value is 1.
   }];

   let parameters = (ins

@@ -91,6 +100,12 @@ def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scat
     )>
   ];

+  let extraClassDeclaration = [{
+    int64_t getChunkSizeAsInt() {
+      return getChunkSize().getInt();
+    }
+  }];
+
   let genVerifyDecl = 1;
 }

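With DefaultValuedParameter, a BlockTensorDescAttr always materializes concrete values for memory_space, array_length, and boundary_check instead of leaving them absent. A hedged sketch of how the defaults read in IR, with illustrative shapes (the attribute mnemonics come from this file; the exact printed form is assumed):

    // Block descriptor with every default in effect: global memory,
    // array_length = 1, boundary_check = true. The encoding can be elided.
    !xegpu.tensor_desc<8x16xf16>

    // Scattered descriptor with a non-default chunk_size: each of the 16
    // offsets covers 8 contiguous f32 elements.
    !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
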
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Lines changed: 4 additions & 1 deletion

@@ -287,7 +287,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
     transpose is another Intel hardware feature, which will do transpose
     operation when loading the data if the bit width of the data type is
     fp32 or fp64. It implies that vnni and transpose cannot exit at the
-    same time.
+    same time. It is only available to 1D or 2D blocked tensor_desc.

     In SIMT mode, result vector represents the data to be loaded by each work-item.

@@ -343,6 +343,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     by the TensorDesc. It takes a set of optional cache hints for each level
     of cache, L1, L2 and L3. If hardware does not have a correspoding cache,
     Corresponding cache hint attribute will be masked.
+    It is only available to 1D or 2D blocked tensor_desc.

     In SIMT mode, the input vector represents the data to be stored by each work-item.

@@ -757,6 +758,8 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
   let assemblyFormat = [{
    $TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` type($offsets)
  }];
+
+  let hasVerifier = 1;
 }

 def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]> {

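The added sentences confine load_nd and store_nd to 1D or 2D blocked descriptors even though TensorDescType itself now admits higher ranks (see the type and verifier changes below). A hedged sketch of a conforming load, with assumed op syntax and illustrative shapes:

    // 2D blocked load: still legal. A 3D tensor_desc would now be rejected
    // by the op verifier rather than by the type verifier.
    %v = xegpu.load_nd %td : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
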
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
Lines changed: 11 additions & 12 deletions

@@ -17,12 +17,12 @@ def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64,
 def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
-def XeGPU_DpasOprType: VectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
-def XeGPU_DpasResType: VectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
-def XeGPU_OffsetType: VectorOfRankAndType<[1], [Index]>;
-def XeGPU_MaskType: AnyTypeOf<[VectorOfRankAndType<[1], [I1]>, I1]>;
-def XeGPU_ValueType: AnyTypeOf<[VectorOfRankAndType<[1,2,3,4], [XeGPU_ScalarType]>, XeGPU_ScalarType]>;
-def XeGPU_Vector2DType: VectorOfRankAndType<[2], [XeGPU_ScalarType]>;
+def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
+def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
+def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
+def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
+def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
+def XeGPU_Vector2DType: FixedVectorOfRankAndType<[2], [XeGPU_ScalarType]>;

 // common base class for types in XeGPU dialect
 class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],

@@ -118,7 +118,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   ];

   let extraClassDeclaration = [{
-    using TensorType::clone;
     using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
     using mlir::ShapedType::Trait<TensorDescType>::getRank;
     using mlir::ShapedType::Trait<TensorDescType>::getNumElements;

@@ -157,6 +156,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return MemorySpace::Global;
    }

+    // get the ArrayLength for blocked TensorDesc
    int getArrayLength() {
      auto attr = getEncoding();
      auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);

@@ -181,13 +181,12 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
      return bool(getEncodingAsScatterTensorDescAttr());
    }

-    int getChunkSize() {
+    // get the ChunkSize for scattered TensorDesc
+    int getChunkSizeAsInt() {
      auto attr = getEncoding();
      auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
-      assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr.");
-      if (scatter_attr)
-        return scatter_attr.getChunkSize().getInt();
-      return 1;
+      assert(scatter_attr && "invalid on non ScatterTensorDescAttr.");
+      return scatter_attr.getChunkSizeAsInt();
    }

    /// Helper to drop all layout information from the TensorDesc type.

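The swapped predicates replace fixed-rank constraints with "fixed-length vector of any non-zero rank", so offsets, masks, and values are no longer pinned to rank 1 (or ranks 1 through 4), while scalable vectors and bare scalars no longer match. Illustrative types, inferred from the predicate names:

    vector<16xindex>     // rank-1 offsets: accepted before and after
    vector<4x16xindex>   // rank-2 offsets: newly accepted by XeGPU_OffsetType
    vector<4x16xi1>      // rank-2 mask: newly accepted by XeGPU_MaskType
    vector<[16]xindex>   // scalable vector: rejected by the Fixed* predicates
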
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
Lines changed: 18 additions & 27 deletions

@@ -129,9 +129,7 @@ LogicalResult ScatterTensorDescAttr::verify(
     llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
     MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
   int64_t chunkSize = chunk_size.getInt();
-  SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8,
-                                              16, 32, 64, 128, 256};
-  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
+  if (chunkSize <= 0)
     return emitError() << "invalid chunk size";

   return success();

@@ -310,15 +308,16 @@ LogicalResult TensorDescType::verify(
     llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
     mlir::Attribute encoding, mlir::Attribute layout) {
   size_t rank = shape.size();
-  if (rank != 1 && rank != 2)
-    return emitError() << "expected 1D or 2D tensor";
+
+  if (rank == 0)
+    return emitError() << "expected non-zero rank tensor";

   auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
   if (blockAttr) {
     MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
-    if (rank == 2 && memorySpaceAttr &&
+    if (rank > 1 && memorySpaceAttr &&
         memorySpaceAttr.getValue() == MemorySpace::SLM)
-      return emitError() << "SLM is not supported for 2D block tensor";
+      return emitError() << "SLM is only supported for 1D block tensor";
   }

   // for gather and scatter ops, Low-precision types are packed in 32-bit units.

@@ -329,22 +328,18 @@ LogicalResult TensorDescType::verify(
                                 : 1;
   auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
   if (scatterAttr) {
-    // Expected tensor ranks for scattered data:
-    // - 1D tensor for fully non-contiguous elements (chunk size == 1)
-    // - 2D tensor for scattered blocks (chunk size > 1)
-    unsigned chunkSize = scatterAttr.getChunkSize().getInt();
+    int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
     if (rank == 1 && chunkSize != 1)
       return emitError() << "expected non-contiguous elements for 1D tensor";
-    if (rank == 2 && chunkSize < 2)
-      return emitError() << "expected chunk blocks for 2D tensor";
+
     // If chunk size > 1, the second dimension of the tensor shape must be
-    // equal to chunk size and it must be a multiple of the packing factor.
+    // equal to chunk size and it must be a multiple of the
+    // chunkAlignmentFactor.
     if (chunkSize > 1) {
       if (shape.back() != chunkSize)
-        return emitError() << "expected tensor shape[1] to match chunk size";
+        return emitError() << "expected last dim of tensor to match chunk size";
       if (shape.back() % chunkAlignmentFactor != 0)
-        return emitError() << "expected tensor shape[1] to be a multiple of "
-                              "chunk alignment factor "
+        return emitError() << "expected last dim of tensor to be a multiple of "
                            << chunkAlignmentFactor;
     }
   }

@@ -357,17 +352,13 @@ LogicalResult TensorDescType::verify(
     auto laneData = layoutAttr.getLaneData();
     if (scatterAttr && laneData) {
       // Validate subgroup mapping rules for scattered tensors.
-      // A work-item's slice of the tensor with shape [sg_size] or
-      // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
-      // respectively, the mapping should reflect that. This is because each
-      // work item access data in 32 bit granularity.
-
-      if (rank > 1 && laneData[0] != 1)
+      // if chunkSize > 1, the last dimension of the tensor should
+      // be distributed in the units divisible by chunkAlignmentFactor.
+      int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
+      if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
         return emitError()
-            << "cannot map over non-contiguous scattered row elements";
-      if (laneData[rank - 1] != chunkAlignmentFactor)
-        return emitError() << "work item data mapping must match the number of "
-                              "contiguous elements";
+            << "expected last dim of lane_data to be a multiple of: "
+            << chunkAlignmentFactor;
     }

     if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {

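Net effect of the verifier changes: only rank-0 descriptors are rejected outright, SLM block descriptors must be 1D, and a scattered descriptor with chunk_size > 1 needs its last dimension to equal chunk_size and be a multiple of the 32-bit packing factor (32 / 16 = 2 for f16). A hedged sketch of types under the new rules (printed syntax assumed):

    // Rank 3: now verifies; previously failed "expected 1D or 2D tensor".
    !xegpu.tensor_desc<2x8x16xf16>

    // Scattered, chunk_size = 8: last dim equals chunk_size, and 8 % 2 == 0.
    !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>

    // Would still fail: "SLM is only supported for 1D block tensor".
    // !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_space = slm>>
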
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
Lines changed: 38 additions & 40 deletions

@@ -81,15 +81,18 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   auto maskShape = getShapeOf(maskTy);
   auto valueShape = getShapeOf(valueTy);
   auto tdescShape = getShapeOf(tdescTy);
-  auto chunkSize = tdescTy.getChunkSize();
+  auto chunkSize = tdescTy.getChunkSizeAsInt();

   if (valueTy.getElementType() != tdescTy.getElementType())
     return emitError()
           << "Value should have the same element type as TensorDesc.";

-  if (tdescShape[0] != maskShape[0])
+  llvm::SmallVector<int64_t> expectedMaskShape(tdescShape);
+  if (chunkSize > 1)
+    expectedMaskShape.pop_back();
+  if (expectedMaskShape != maskShape)
     return emitError()
-          << "dim-0 of the Mask and TensorDesc should be the same.";
+          << "Mask should match TensorDesc except the chunk size dim.";

   // a valid shape for SIMT case
   if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {

@@ -203,11 +206,9 @@ LogicalResult CreateNdDescOp::verify() {
                      "is a memref) should match with each other.");

   // check result TensorDesc rank
-  invalidRank = (getType().getRank() > 2 || getType().getRank() > rank);
-
-  if (invalidRank)
+  if (getType().getRank() > rank)
     return emitOpError(
-        "Expecting the TensorDesc rank is up to 2 and not greater than the "
+        "Expecting the TensorDesc rank is not greater than the "
        "ranks of shape, strides, offsets or the memref source.");

   if (invalidElemTy)

@@ -247,12 +248,12 @@ LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();

-  if (tdescTy.getRank() > 2)
-    return emitOpError("Expecting a 1D/2D TensorDesc.\n");
-
   if (tdescTy.isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");

+  if (tdescTy.getRank() > 2)
+    return emitOpError("Expects a 1D or 2D TensorDesc.\n");
+
   if (!valueTy)
     return emitOpError("Invalid result, it should be a VectorType.\n");

@@ -316,15 +317,13 @@ LogicalResult LoadNdOp::verify() {
   }

   auto array_len = tdescTy.getArrayLength();
-  if (array_len > 1) {
+  if (array_len > 1)
     tdescShape.insert(tdescShape.begin(), array_len);
-  }

-  if (tdescShape != valueShape) {
+  if (tdescShape != valueShape)
     return emitOpError() << "Result shape " << makeString(valueShape)
                          << " is not consistent with tensor descriptor "
                          << tdescTy;
-  }

   return success();
 }

@@ -336,12 +335,12 @@ LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector

-  if (dstTy.getRank() > 2)
-    return emitOpError("Expecting a 1D/2D TensorDesc.\n");
-
   if (dstTy.isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");

+  if (dstTy.getRank() > 2)
+    return emitOpError("Expects a 1D or 2D TensorDesc.\n");
+
   if (!valTy)
     return emitOpError("Expecting a VectorType result.\n");

@@ -370,22 +369,21 @@
       return emitOpError()
             << "TensorDesc doesn't need LayoutAttr for SIMT code";

-    if (tdescElems % valueElems) {
+    if (tdescElems % valueElems)
       return emitOpError()
             << "Value shape " << makeString(getShapeOf(valTy))
             << " is not a valid distribution for tensor descriptor " << dstTy;
-    }
+
     return success();
   }

   // SIMD code should have the same shape as the tensor descriptor.
   auto tdescShape = getShapeOf(dstTy);
   auto valueShape = getShapeOf(valTy);
-  if (tdescShape != valueShape) {
+  if (tdescShape != valueShape)
     return emitOpError() << "Value shape " << makeString(valueShape)
                          << " is not consistent with tensor descriptor "
                          << dstTy;
-  }

   return success();
 }

@@ -449,25 +447,8 @@ LogicalResult CreateDescOp::verify() {
           << ", TensorDesc: " << tdescMemorySpace;

   // check total size
-  auto chunkSize = tdescTy.getChunkSize();
-  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
-  auto bitsPerLane = elemBits * chunkSize;
-  if (chunkSize > 1 && bitsPerLane % 32) {
-    // For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
-    // For 32-bit data, the hardware can support larger larger chunk size. So
-    // we can bitcast 8-bit/16-bit data to 32-bit data for better performance.
-    // But this requires the total size is 32 bit aligned to make the
-    // optimization work.
-    return emitOpError(
-        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
-  }
-
-  auto lscConstraints = 512 * 8; // each access is upto 512 bytes.
-  if (elemBits * tdescTy.getNumElements() > lscConstraints)
-    return emitOpError("total access size (simd_lanes * chunk_size * "
-                       "sizeof(elemTy)) is upto 512 bytes.");
-
-  SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
+  auto chunkSize = tdescTy.getChunkSizeAsInt();
+  SmallVector<int64_t> shape(getOffsetsType().getShape());
   if (chunkSize != 1)
     shape.push_back(chunkSize);

@@ -563,6 +544,23 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, tensorDesc, ofrs);
 }

+LogicalResult UpdateOffsetOp::verify() {
+  auto tdescTy = getTensorDescType();
+  if (!tdescTy.isScattered())
+    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  SmallVector<int64_t> expectedOffsetShape = getShapeOf(tdescTy);
+  SmallVector<int64_t> offsetShape = getShapeOf(getOffsetsType());
+  if (tdescTy.getChunkSizeAsInt() > 1)
+    expectedOffsetShape.pop_back();
+
+  if (expectedOffsetShape != offsetShape)
+    return emitOpError(
+        "Offsets should match TensorDesc except the chunk size dim.");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_DpasOp
 //===----------------------------------------------------------------------===//

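Both isValidGatherScatterParams and the new UpdateOffsetOp verifier encode the same rule: offsets and mask take the descriptor's shape with the trailing chunk-size dimension dropped. A hedged sketch with assumed op syntax and hypothetical values:

    // tensor_desc shape 16x8 with chunk_size = 8, so offsets are vector<16xindex>.
    %td = xegpu.create_tdesc %src, %offsets : memref<1024xf32>, vector<16xindex>
        -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>

    // update_offset's offsets likewise omit the chunk dimension.
    %td2 = xegpu.update_offset %td, %delta
        : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xindex>

Consistent with this, CreateDescOp now derives the expected shape from getOffsetsType().getShape() rather than from a flat offset count, so multi-dimensional offsets can verify as well.
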
mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
Lines changed: 2 additions & 4 deletions

@@ -303,9 +303,7 @@ void XeGPUBlockingPass::runOnOperation() {
       // If the encoding is a ScatterTensorDescAttr, we need to
       // potentially adjust the chunk size based on the inst_data.
       if (tdescTy.isScattered()) {
-        auto scatterAttr =
-            llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
-        int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+        int64_t chunkSize = tdescTy.getChunkSizeAsInt();

         if (chunkSize > 1) {
           int64_t blockedChunkSize = chunkSize;

@@ -315,7 +313,7 @@ void XeGPUBlockingPass::runOnOperation() {

           // To create a new attribute with a different chunk_size:
           auto newEncoding = xegpu::ScatterTensorDescAttr::get(
-              ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
+              ctx, tdescTy.getMemorySpace(), blockedChunkSize);

           encoding = newEncoding;
         }

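The pass consults the descriptor's layout to pick a blocked chunk size; this rewrite only swaps the attribute accessors for the new TensorDesc helpers. A hedged sketch of the intended effect (layout syntax and the blocked value are assumed; the inst_data handling itself lies outside this diff):

    // Before blocking: chunk_size = 8 under an instruction tile of [16, 4].
    !xegpu.tensor_desc<32x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>,
                       #xegpu.layout<inst_data = [16, 4]>>
    // After blocking, the encoding is rebuilt with the blocked chunk size.
    !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4>>
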