Skip to content

Commit 2b28d10

Browse files
[mlir][SCF][GPU] Add DeviceMaskingAttrInterface (#146943)
This revision adds DeviceMaskingAttrInterface and extends DeviceMappingArrayAttr to accept a union of DeviceMappingAttrInterface and DeviceMaskingAttrInterface. Support is added to GPUTransformOps to take advantage of this information and lower to block/warpgroup/warp/thread specialization when mapped to linear ids. The revision also connects to scf::ForallOp and uses the new attribute to implement warp specialization. The implementation is in the form of a GPUMappingMaskAttr, which can be additionally passed to the scf.forall.mapping attribute to specify a mask on compute resources that should be active. In the first implementation the masking is a bitfield that specifies for each processing unit whether it is active or not. In the future, we may want to implement this as a symbol to refer to dynamically defined values. Extending op semantics with an operand is deemed too intrusive at this time. --------- Co-authored-by: Oleksandr "Alex" Zinenko <git@ozinenko.com>
1 parent 499e656 commit 2b28d10

File tree

13 files changed

+461
-63
lines changed

13 files changed

+461
-63
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,17 @@ def GPULaneMappingAttr
252252
}];
253253
}
254254

255+
def GPUMappingMaskAttr : GPU_Attr<"GPUMappingMask", "mask", [
256+
DeclareAttrInterfaceMethods<DeviceMaskingAttrInterface> ] > {
257+
let parameters = (ins "uint64_t":$mask);
258+
let assemblyFormat = "`<` params `>`";
259+
let description = [{
260+
Attribute describing how to filter the processing units that a region is
261+
mapped to. The masking is a bitfield that specifies for each processing
262+
unit whether it is active or not.
263+
}];
264+
}
265+
255266
def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space", [
256267
DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > {
257268
let parameters = (ins

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
#include "mlir/Dialect/SCF/IR/SCF.h"
1313
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
14-
#include "mlir/IR/OpImplementation.h"
1514
#include "mlir/IR/PatternMatch.h"
1615

1716
namespace mlir {
@@ -57,7 +56,7 @@ mapForallToBlocksImpl(RewriterBase &rewriter, TransformOpInterface transformOp,
5756
DiagnosedSilenceableFailure
5857
mapOneForallToThreadsImpl(RewriterBase &rewriter,
5958
std::optional<TransformOpInterface> transformOp,
60-
scf::ForallOp forallOp, ArrayRef<int64_t> blockDims,
59+
scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
6160
int64_t warpSize, bool syncAfterDistribute);
6261

6362
/// Search `scf.forall` ops nested under `target` and map each such op to an

mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,22 @@ struct GpuIdBuilder {
7777
/// used for indexing rewrites as well as 3D sizes for predicate generation.
7878
/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
7979
/// used for indexing rewrites as well as 1D sizes for predicate generation.
80+
/// If `mask` is provided, it will be used to filter the active blocks.
8081
struct GpuBlockIdBuilder : public GpuIdBuilder {
81-
GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
82+
GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
83+
DeviceMaskingAttrInterface mask = nullptr);
8284
};
8385

8486
/// Builder for warpgroup ids used to map scf.forall to reindexed warpgroups.
8587
/// If `useLinearMapping` is false, the `idBuilder` method returns 3D values
8688
/// used for indexing rewrites as well as 3D sizes for predicate generation.
8789
/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
8890
/// used for indexing rewrites as well as 1D sizes for predicate generation.
91+
/// If `mask` is provided, it will be used to filter the active warpgroups.
8992
struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
9093
GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
91-
bool useLinearMapping = false);
94+
bool useLinearMapping = false,
95+
DeviceMaskingAttrInterface mask = nullptr);
9296
int64_t warpSize = 32;
9397
/// In the future this may be configured by the transformation.
9498
static constexpr int64_t kNumWarpsPerGroup = 4;
@@ -99,9 +103,11 @@ struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
99103
/// used for indexing rewrites as well as 3D sizes for predicate generation.
100104
/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
101105
/// used for indexing rewrites as well as 1D sizes for predicate generation.
106+
/// If `mask` is provided, it will be used to filter the active warps.
102107
struct GpuWarpIdBuilder : public GpuIdBuilder {
103108
GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
104-
bool useLinearMapping = false);
109+
bool useLinearMapping = false,
110+
DeviceMaskingAttrInterface mask = nullptr);
105111
int64_t warpSize = 32;
106112
};
107113

@@ -110,16 +116,20 @@ struct GpuWarpIdBuilder : public GpuIdBuilder {
110116
/// used for indexing rewrites as well as 3D sizes for predicate generation.
111117
/// If `useLinearMapping` is true, the `idBuilder` method returns nD values
112118
/// used for indexing rewrites as well as 1D sizes for predicate generation.
119+
/// If `mask` is provided, it will be used to filter the active threads.
113120
struct GpuThreadIdBuilder : public GpuIdBuilder {
114-
GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
121+
GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
122+
DeviceMaskingAttrInterface mask = nullptr);
115123
};
116124

117125
/// Builder for lane id.
118126
/// The `idBuilder` method returns nD values used for indexing rewrites as well
119127
/// as 1D sizes for predicate generation.
120128
/// This `useLinearMapping` case is the only supported case.
129+
/// If `mask` is provided, it will be used to filter the active lanes.
121130
struct GpuLaneIdBuilder : public GpuIdBuilder {
122-
GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused);
131+
GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused,
132+
DeviceMaskingAttrInterface mask = nullptr);
123133
int64_t warpSize = 32;
124134
};
125135

mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,68 @@ def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> {
6060
];
6161
}
6262

63+
def DeviceMaskingAttrInterface : AttrInterface<"DeviceMaskingAttrInterface"> {
64+
let cppNamespace = "::mlir";
65+
let description = [{
66+
Attribute interface describing how to filter the processing units that a
67+
region is mapped to.
68+
69+
For instance, consider the following example mask which specifies processing
70+
units 2, 4 and 5 are active:
71+
```
72+
8 4 0
73+
mask : 0 0 0 1 1 0 1 0 0
74+
```
75+
The logical ID for an active processing unit is defined as its position
76+
relative to the other active processing units. In this example, we have:
77+
```
78+
Processing Unit LogicalID
79+
0 N/A
80+
1 N/A
81+
2 0
82+
3 N/A
83+
4 1
84+
5 2
85+
6 N/A
86+
7 N/A
87+
```
88+
}];
89+
90+
let methods = [
91+
InterfaceMethod<
92+
/*desc=*/[{
93+
Create the logical active id for a given physical id.
94+
Expects a physicalLinearMappingId of I64Type.
95+
}],
96+
/*retTy=*/"Value",
97+
/*methodName=*/"createLogicalLinearMappingId",
98+
/*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
99+
>,
100+
InterfaceMethod<
101+
/*desc=*/[{
102+
Return the dynamic condition determining whether a given physical id is
103+
active under the mask.
104+
Expects a physicalLinearMappingId of I64Type.
105+
}],
106+
/*retTy=*/"Value",
107+
/*methodName=*/"createIsActiveIdPredicate",
108+
/*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
109+
>,
110+
InterfaceMethod<
111+
/*desc=*/[{
112+
Return the maximal number of pysical ids supported.
113+
This is to account for temporary implementation limitations (e.g. i64)
114+
and fail gracefully with actionnable error messages.
115+
}],
116+
/*retTy=*/"int64_t",
117+
/*methodName=*/"getMaxNumPhysicalIds",
118+
/*args=*/(ins)
119+
>,
120+
];
121+
}
122+
63123
def DeviceMappingArrayAttr :
64-
TypedArrayAttrBase<DeviceMappingAttrInterface,
124+
TypedArrayAttrBase<AnyAttrOf<[DeviceMappingAttrInterface, DeviceMaskingAttrInterface]>,
65125
"Device Mapping array attribute"> { }
66126

67127
#endif // MLIR_DEVICEMAPPINGINTERFACE

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,18 @@ def ForallOp : SCF_Op<"forall", [
611611
/// Returns operations within scf.forall.in_parallel whose destination
612612
/// operand is the block argument `bbArg`.
613613
SmallVector<Operation*> getCombiningOps(BlockArgument bbArg);
614+
615+
/// Returns the subset of DeviceMappingArrayAttrs of type
616+
/// DeviceMappingAttrInterface.
617+
SmallVector<DeviceMappingAttrInterface> getDeviceMappingAttrs();
618+
619+
/// Returns the at most one DeviceMaskingAttrInterface in the mapping.
620+
/// If more than one DeviceMaskingAttrInterface is specified, returns
621+
/// failure. If no mapping is present, returns nullptr.
622+
FailureOr<DeviceMaskingAttrInterface> getDeviceMaskingAttr();
623+
624+
/// Returns true if the mapping specified for this forall op is linear.
625+
bool usesLinearMapping();
614626
}];
615627
}
616628

mlir/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRGPUDialect
2020
MLIRFunctionInterfaces
2121
MLIRInferIntRangeInterface
2222
MLIRIR
23+
MLIRMathDialect
2324
MLIRMemRefDialect
2425
MLIRSideEffectInterfaces
2526
MLIRSupport

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
17+
#include "mlir/Dialect/Math/IR/Math.h"
1718
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1819
#include "mlir/IR/Attributes.h"
1920
#include "mlir/IR/Builders.h"
@@ -120,6 +121,50 @@ int64_t GPULaneMappingAttr::getRelativeIndex() const {
120121
: getMappingId();
121122
}
122123

124+
int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
125+
126+
/// 8 4 0
127+
/// Example mask : 0 0 0 1 1 0 1 0 0
128+
///
129+
/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
130+
/// Logical id for e.g. 5 (2) constructs filter (1 << 5 - 1).
131+
///
132+
/// Example mask : 0 0 0 1 1 0 1 0 0
133+
/// Example filter: 0 0 0 0 1 1 1 1 1
134+
/// Intersection : 0 0 0 0 1 0 1 0 0
135+
/// PopCnt : 2
136+
Value GPUMappingMaskAttr::createLogicalLinearMappingId(
137+
OpBuilder &b, Value physicalLinearMappingId) const {
138+
Location loc = physicalLinearMappingId.getLoc();
139+
Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask()));
140+
Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
141+
Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId);
142+
filter = b.create<arith::SubIOp>(loc, filter, one);
143+
Value filteredId = b.create<arith::AndIOp>(loc, mask, filter);
144+
return b.create<math::CtPopOp>(loc, filteredId);
145+
}
146+
147+
/// 8 4 0
148+
/// Example mask : 0 0 0 1 1 0 1 0 0
149+
///
150+
/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2).
151+
/// Logical id for e.g. 5 (2) constructs filter (1 << 5).
152+
///
153+
/// Example mask : 0 0 0 1 1 0 1 0 0
154+
/// Example filter: 0 0 0 1 0 0 0 0 0
155+
/// Intersection : 0 0 0 1 0 0 0 0 0
156+
/// Cmp : 1
157+
Value GPUMappingMaskAttr::createIsActiveIdPredicate(
158+
OpBuilder &b, Value physicalLinearMappingId) const {
159+
Location loc = physicalLinearMappingId.getLoc();
160+
Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask()));
161+
Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
162+
Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId);
163+
Value filtered = b.create<arith::AndIOp>(loc, mask, filter);
164+
Value zero = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(0));
165+
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, filtered, zero);
166+
}
167+
123168
int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
124169
return static_cast<int64_t>(getAddressSpace());
125170
}

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
1313
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1414
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1615
#include "mlir/Dialect/Arith/IR/Arith.h"
17-
#include "mlir/Dialect/Func/IR/FuncOps.h"
1816
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1917
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
2018
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -351,16 +349,25 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
351349
seen.insert(map);
352350
}
353351

354-
auto isLinear = [](Attribute a) {
355-
return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
352+
auto isLinear = [](DeviceMappingAttrInterface attr) {
353+
return attr.isLinearMapping();
356354
};
357-
if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
358-
!llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
355+
if (llvm::any_of(forallOp.getDeviceMappingAttrs(), isLinear) &&
356+
!llvm::all_of(forallOp.getDeviceMappingAttrs(), isLinear)) {
359357
return definiteFailureHelper(
360358
transformOp, forallOp,
361359
"cannot mix linear and non-linear mapping modes");
362360
}
363361

362+
FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
363+
forallOp.getDeviceMaskingAttr();
364+
if (succeeded(maybeMaskingAttr) && *maybeMaskingAttr &&
365+
!forallOp.usesLinearMapping()) {
366+
return definiteFailureHelper(
367+
transformOp, forallOp,
368+
"device masking is only available in linear mapping mode");
369+
}
370+
364371
return DiagnosedSilenceableFailure::success();
365372
}
366373

@@ -381,9 +388,7 @@ verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
381388
if (forallOp.getNumResults() > 0)
382389
return definiteFailureHelper(transformOp, forallOp,
383390
"only bufferized scf.forall can be mapped");
384-
bool useLinearMapping = cast<DeviceMappingAttrInterface>(
385-
forallOp.getMapping()->getValue().front())
386-
.isLinearMapping();
391+
bool useLinearMapping = forallOp.usesLinearMapping();
387392
// TODO: This would be more natural with support for Optional<EnumParameter>
388393
// in GPUDeviceMappingAttr.
389394
int64_t maxNumMappingsSupported =
@@ -436,8 +441,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
436441
assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
437442
"requires statically sized, normalized forall op");
438443
SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
444+
SmallVector<DeviceMappingAttrInterface> forallMappingAttrsVec =
445+
forallOp.getDeviceMappingAttrs();
439446
SetVector<Attribute> forallMappingAttrs;
440-
forallMappingAttrs.insert_range(forallOp.getMapping()->getValue());
447+
forallMappingAttrs.insert_range(forallMappingAttrsVec);
441448
auto comparator = [](Attribute a, Attribute b) -> bool {
442449
return cast<DeviceMappingAttrInterface>(a).getMappingId() <
443450
cast<DeviceMappingAttrInterface>(b).getMappingId();
@@ -682,12 +689,17 @@ DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
682689

683690
// The BlockIdBuilder adapts to whatever is thrown at it.
684691
bool useLinearMapping = false;
685-
if (topLevelForallOp.getMapping()) {
686-
auto mappingAttr = cast<DeviceMappingAttrInterface>(
687-
topLevelForallOp.getMapping()->getValue().front());
688-
useLinearMapping = mappingAttr.isLinearMapping();
689-
}
690-
GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);
692+
if (topLevelForallOp.getMapping())
693+
useLinearMapping = topLevelForallOp.usesLinearMapping();
694+
695+
FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
696+
topLevelForallOp.getDeviceMaskingAttr();
697+
assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
698+
assert((!*maybeMaskingAttr || useLinearMapping) &&
699+
"masking requires linear mapping");
700+
701+
GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping,
702+
*maybeMaskingAttr);
691703

692704
diag = mlir::transform::gpu::mapForallToBlocksImpl(
693705
rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
@@ -744,8 +756,8 @@ static DiagnosedSilenceableFailure
744756
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
745757
scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
746758
int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
747-
auto mappingAttr = cast<DeviceMappingAttrInterface>(
748-
forallOp.getMapping()->getValue().front());
759+
DeviceMappingAttrInterface mappingAttr =
760+
forallOp.getDeviceMappingAttrs().front();
749761
bool useLinearMapping = mappingAttr.isLinearMapping();
750762

751763
// Sanity checks that may result in runtime verification errors.
@@ -768,21 +780,30 @@ getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
768780
if (!diag.succeeded())
769781
return diag;
770782

783+
FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
784+
forallOp.getDeviceMaskingAttr();
785+
assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
786+
assert((!*maybeMaskingAttr || useLinearMapping) &&
787+
"masking requires linear mapping");
788+
771789
// Start mapping.
772790
MLIRContext *ctx = forallOp.getContext();
773791
gpuIdBuilder =
774792
TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
775793
.Case([&](GPUWarpgroupMappingAttr) {
776-
return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
794+
return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping,
795+
*maybeMaskingAttr);
777796
})
778797
.Case([&](GPUWarpMappingAttr) {
779-
return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
798+
return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping,
799+
*maybeMaskingAttr);
780800
})
781801
.Case([&](GPUThreadMappingAttr) {
782-
return GpuThreadIdBuilder(ctx, useLinearMapping);
802+
return GpuThreadIdBuilder(ctx, useLinearMapping, *maybeMaskingAttr);
783803
})
784804
.Case([&](GPULaneMappingAttr) {
785-
return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping);
805+
return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping,
806+
*maybeMaskingAttr);
786807
})
787808
.Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
788809
llvm_unreachable("unknown mapping attribute");

0 commit comments

Comments
 (0)