Skip to content

Commit 881e59a

Browse files
sjw36 authored and njriasan committed
[PIPELINER] Use AttrHelper (mostly NFC) (triton-lang#6437)
* Converted ad-hoc attribute handling to AttrHelpers to ensure consistency
* Changed `tt_latency` to `tt.latency` in AssignLatencies
1 parent fcef033 commit 881e59a

File tree

7 files changed

+33
-22
lines changed

7 files changed

+33
-22
lines changed

include/triton/Dialect/Triton/IR/TritonDialect.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,20 @@ def Triton_Dialect : Dialect {
3434

3535
let extraClassDeclaration = [{
3636
void registerTypes();
37+
38+
static TritonDialect *getLoaded(MLIRContext *ctx) {
39+
return ctx->getLoadedDialect<TritonDialect>();
40+
}
41+
static TritonDialect *getLoaded(Operation *op) {
42+
return getLoaded(op->getContext());
43+
}
3744
}];
3845

46+
let discardableAttrs = (ins
47+
"::mlir::IntegerAttr":$num_stages,
48+
"::mlir::IntegerAttr":$latency
49+
);
50+
3951
let hasConstantMaterializer = 1;
4052
let useDefaultTypePrinterParser = 1;
4153
let usePropertiesForAttributes = 1;

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ static const char *kWarpSpecializeAttrName = "tt.warp_specialize";
1919
static const char *kLoopStageAttrName = "loop.stage";
2020
static const char *kLoopClusterAttrName = "loop.cluster";
2121
static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
22-
static const char *kLatencyAttrName = "tt.latency";
2322

2423
//===----------------------------------------------------------------------===//
2524
// Hoisting Utilities

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,20 @@ bool preCondition(scf::ForOp forOp) {
3737
}
3838

3939
bool hasLatenciesAssigned(scf::ForOp forOp) {
40+
auto helper = TritonDialect::getLoaded(forOp)->getLatencyAttrHelper();
4041
for (auto &op : forOp.getBody()->without_terminator()) {
41-
if (op.hasAttr("tt_latency"))
42+
if (helper.getAttr(&op))
4243
return true;
4344
}
4445
return false;
4546
}
4647

4748
void assignUserProvidedLatencies(scf::ForOp forOp,
4849
DenseMap<Operation *, int> &opLatency) {
50+
auto helper = TritonDialect::getLoaded(forOp)->getLatencyAttrHelper();
4951
for (auto &op : forOp.getBody()->without_terminator()) {
50-
if (auto latencyAttr = op.getAttr("tt_latency")) {
51-
opLatency[&op] = mlir::cast<IntegerAttr>(latencyAttr).getInt();
52+
if (auto latencyAttr = helper.getAttr(&op)) {
53+
opLatency[&op] = latencyAttr.getInt();
5254
}
5355
}
5456
}

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -325,19 +325,20 @@ int mlir::triton::getCopyVecBytes(RankedTensorType registerTy,
325325

326326
void mlir::triton::serializeLatencies(ModuleOp module,
327327
DenseMap<Operation *, int> &opLatency) {
328+
auto helper = TritonDialect::getLoaded(module)->getLatencyAttrHelper();
329+
auto builder = Builder(module);
328330
for (auto &[op, latency] : opLatency) {
329-
op->setAttr(
330-
kLatencyAttrName,
331-
IntegerAttr::get(IntegerType::get(module.getContext(), 32), latency));
331+
helper.setAttr(op, builder.getI32IntegerAttr(latency));
332332
}
333333
}
334334

335335
DenseMap<Operation *, int> mlir::triton::deserializeLatencies(Operation *op) {
336+
auto helper = TritonDialect::getLoaded(op)->getLatencyAttrHelper();
336337
DenseMap<Operation *, int> opLatency;
337338
op->walk([&](Operation *op) {
338-
if (op->hasAttr(kLatencyAttrName)) {
339-
opLatency[op] = op->getAttrOfType<IntegerAttr>(kLatencyAttrName).getInt();
340-
op->removeAttr(kLatencyAttrName);
339+
if (auto attr = helper.getAttr(op)) {
340+
opLatency[op] = attr.getInt();
341+
helper.removeAttr(op);
341342
}
342343
});
343344
return opLatency;
@@ -519,9 +520,8 @@ int mlir::triton::getNumStagesOrDefault(scf::ForOp forOp,
519520
int defaultNumStages) {
520521
// Use the attribute attached to the loop if it exists otherwise use the
521522
// global control.
522-
if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName))
523-
return defaultNumStages;
524-
return mlir::cast<IntegerAttr>(
525-
forOp->getAttr(mlir::triton::kNumStagesAttrName))
526-
.getInt();
523+
auto helper = TritonDialect::getLoaded(forOp)->getNumStagesAttrHelper();
524+
if (auto attr = helper.getAttr(forOp))
525+
return attr.getInt();
526+
return defaultNumStages;
527527
}

lib/Dialect/TritonGPU/Transforms/Pipeliner/TestPipelineScheduleLoop.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ namespace gpu {
1515
#define GEN_PASS_DEF_TRITONGPUTESTPIPELINESCHEDULELOOP
1616
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
1717

18-
static const char *kLatencyAttrName = "tt.latency";
19-
2018
struct TestPipelineScheduleLoop
2119
: public impl::TritonGPUTestPipelineScheduleLoopBase<
2220
TestPipelineScheduleLoop> {

test/TritonGPU/loop-pipeline-async-latencies.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ tt.func public @matmul_kernel_tma_persistent(%arg0: !tt.ptr<i8, 0> {tt.nv_tma_de
101101
// CHECK: [[RHS_MBAR:%.*]] = ttg.memdesc_subview [[RHS_BARS]][[[RHS_BUF_IDX]]]
102102
// CHECK-NEXT: ttng.wait_barrier [[RHS_MBAR]], [[RHS_PHASE]]
103103

104-
%4 = tt.descriptor_load %1[%c0_i32, %arg6] {tt_latency = 1 : i32} : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #blocked>
104+
%4 = tt.descriptor_load %1[%c0_i32, %arg6] {tt.latency = 1 : i32} : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #blocked>
105105
%5 = ttg.local_alloc %4 : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
106-
%6 = tt.descriptor_load %2[%c0_i32, %arg6] {tt_latency = 3 : i32} : !tt.tensordesc<tensor<256x64xf16, #shared>> -> tensor<256x64xf16, #blocked>
106+
%6 = tt.descriptor_load %2[%c0_i32, %arg6] {tt.latency = 3 : i32} : !tt.tensordesc<tensor<256x64xf16, #shared>> -> tensor<256x64xf16, #blocked>
107107
%7 = ttg.local_alloc %6 : (tensor<256x64xf16, #blocked>) -> !ttg.memdesc<256x64xf16, #shared, #smem>
108108
%8 = ttg.memdesc_trans %7 {order = array<i32: 1, 0>} : !ttg.memdesc<256x64xf16, #shared, #smem> -> !ttg.memdesc<64x256xf16, #shared1, #smem>
109109
%9 = ttng.warp_group_dot %5, %8, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x256xf16, #shared1, #smem> -> tensor<128x256xf32, #mma>

test/TritonGPU/loop-schedule.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ tt.func @prologue_backward_slice(%ub: i32, %cond: i1) {
168168
// CHECK: op.with_region
169169
"op.with_region"() ({
170170
"use"(%1) : (i32) -> ()
171-
}) {tt_latency = 2 : i32} : () -> ()
171+
}) {tt.latency = 2 : i32} : () -> ()
172172
// CHECK: loop.cluster = 1 : i32, loop.stage = 0 : i32
173173

174174
} {tt.num_stages = 3 : i32}
@@ -186,7 +186,7 @@ tt.func @epilogue_forward_slice(%ub: i32, %cond: i1) {
186186
// CHECK: scf.for
187187
scf.for %i = %c0_i32 to %ub step %c1_i32 : i32 {
188188
// CHECK: "latency.op"() {loop.cluster = 3 : i32, loop.stage = 0 : i32
189-
%0 = "latency.op"() {tt_latency = 2 : i32} : () -> i32
189+
%0 = "latency.op"() {tt.latency = 2 : i32} : () -> i32
190190
// CHECK: scf.if
191191
%1 = scf.if %cond -> i32 {
192192
scf.yield %0 : i32
@@ -219,7 +219,7 @@ tt.func @prologue_latency(%ub: i32, %cond: i1) {
219219
scf.yield %0 : i32
220220
} else {
221221
scf.yield %c0_i32 : i32
222-
} {tt_latency = 2 : i32}
222+
} {tt.latency = 2 : i32}
223223
// CHECK: loop.cluster = 0 : i32, loop.stage = 0 : i32
224224

225225
} {tt.num_stages = 3 : i32}

0 commit comments

Comments (0)