Skip to content

Commit 42e50cb

Browse files
Implement lowering for AccumulateOp.
1 parent 31f10c7 commit 42e50cb

File tree

4 files changed

+421
-0
lines changed

4 files changed

+421
-0
lines changed

experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorAnalysis.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,20 @@ class StateTypeComputer {
5757
TypeConverter typeConverter;
5858
};
5959

60+
/// The state of AccumulateOp consists of the state of its upstream iterator,
61+
/// i.e., the state of the iterator that produces its input stream, and a
62+
/// Boolean indicating whether the iterator has returned a result already (which
63+
/// is initialized to false and set to true in the first call to next in order
64+
/// to ensure that only a single result is returned).
65+
template <>
66+
StateType
67+
StateTypeComputer::operator()(AccumulateOp op,
68+
llvm::SmallVector<StateType> upstreamStateTypes) {
69+
MLIRContext *context = op->getContext();
70+
Type hasReturned = IntegerType::get(context, /*width=*/1);
71+
return StateType::get(context, {upstreamStateTypes[0], hasReturned});
72+
}
73+
6074
/// The state of ConstantStreamOp consists of a single number that corresponds
6175
/// to the index of the next struct returned by the iterator.
6276
template <>
@@ -155,6 +169,7 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis(
155169
// TODO: Verify that operands do not come from bbArgs.
156170
.Case<
157171
// clang-format off
172+
AccumulateOp,
158173
ConstantStreamOp,
159174
FilterOp,
160175
MapOp,

experimental/iterators/lib/Conversion/IteratorsToLLVM/IteratorsToLLVM.cpp

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,253 @@ struct PrintOpLowering : public OpConversionPattern<PrintOp> {
251251
}
252252
};
253253

254+
//===----------------------------------------------------------------------===//
255+
// AccumulateOp.
256+
//===----------------------------------------------------------------------===//
257+
258+
/// Builds IR that opens the nested upstream iterator and sets `hasReturned` to
259+
/// false. Possible output:
260+
///
261+
/// %0 = iterators.extractvalue %arg0[0] :
262+
/// <!upstream_state, i1> -> !upstream_state
263+
/// %1 = call @iterators.upstream.open.0(%0) :
264+
/// (!upstream_state) -> !upstream_state
265+
/// %2 = iterators.insertvalue %arg0[0] (%1 : !upstream_state) :
266+
/// <!upstream_state, i1>
267+
/// %false = arith.constant false
268+
/// %3 = iterators.insertvalue %false into %2[1] :
269+
/// !iterators.state<!upstream_state, i1>
270+
static Value buildOpenBody(AccumulateOp op, OpBuilder &builder,
271+
Value initialState,
272+
ArrayRef<IteratorInfo> upstreamInfos) {
273+
Location loc = op.getLoc();
274+
ImplicitLocOpBuilder b(loc, builder);
275+
276+
Type upstreamStateType = upstreamInfos[0].stateType;
277+
278+
// Extract upstream state.
279+
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
280+
upstreamStateType, initialState, b.getIndexAttr(0));
281+
282+
// Call Open on upstream.
283+
SymbolRefAttr openFunc = upstreamInfos[0].openFunc;
284+
auto openCallOp =
285+
b.create<func::CallOp>(openFunc, upstreamStateType, initialUpstreamState);
286+
287+
// Update upstream state.
288+
Value updatedUpstreamState = openCallOp->getResult(0);
289+
Value updatedState = b.create<iterators::InsertValueOp>(
290+
initialState, b.getIndexAttr(0), updatedUpstreamState);
291+
292+
// Reset hasReturned to false.
293+
Value constFalse = b.create<arith::ConstantIntOp>(/*value=*/0, /*width=*/1);
294+
updatedState = b.create<iterators::InsertValueOp>(
295+
updatedState, b.getIndexAttr(1), constFalse);
296+
297+
return updatedState;
298+
}
299+
300+
/// Builds IR that consumes all elements of the upstream iterator and combines
301+
/// them into a single one using the given accumulate function. Pseudo-code:
302+
///
303+
/// if hasReturned: return {}
304+
/// hasReturned = True
305+
/// accumulator = initFuncRef()
306+
/// while (next = upstream->Next()):
307+
/// accumulator = accumulate(accumulator, next)
308+
/// return accumulator
309+
///
310+
/// Possible output:
311+
///
312+
/// %0 = iterators.extractvalue %arg0[0] :
313+
/// <!upstream_state, i1> -> !upstream_state
314+
/// %1 = iterators.extractvalue %arg0[1] : !iterators.state<!upstream_state, i1>
315+
/// %2:2 = scf.if %1 -> (!upstream_state, !element_type) {
316+
/// %6 = llvm.mlir.undef : !element_type
317+
/// scf.yield %0, %6 : !upstream_state, !element_type
318+
/// } else {
319+
/// %6 = func.call @zero_struct() : () -> !element_type
320+
/// %7:3 = scf.while (%arg1 = %0, %arg2 = %6) :
321+
/// (!upstream_state, !element_type) ->
322+
/// (!upstream_state, !element_type, !element_type) {
323+
/// %8:3 = func.call @iterators.upstream.next.0(%arg1) :
324+
/// (!upstream_state) -> (!upstream_state, i1, !element_type)
325+
/// scf.condition(%8#1) %8#0, %arg2, %8#2 :
326+
/// !upstream_state, !element_type, !element_type
327+
//// } do {
328+
/// ^bb0(%arg1: !upstream_state, %arg2: !element_type, %arg3: !element_type):
329+
/// %8 = func.call @accumulate_func(%arg2, %arg3) :
330+
/// (!element_type, !element_type) -> !element_type
331+
/// scf.yield %arg1, %8 : !upstream_state, !element_type
332+
/// }
333+
/// scf.yield %7#0, %7#1 : !upstream_state, !element_type
334+
/// }
335+
/// %3 = iterators.insertvalue %arg0[0] (%2#0 : !upstream_state) :
336+
/// <!upstream_state, i1>
337+
/// %true = arith.constant true
338+
/// %4 = arith.xori %true, %1 : i1
339+
/// %5 = iterators.insertvalue %true into %3[1] :
340+
/// !iterators.state<!upstream_state, i1>
341+
static llvm::SmallVector<Value, 4>
342+
buildNextBody(AccumulateOp op, OpBuilder &builder, Value initialState,
343+
ArrayRef<IteratorInfo> upstreamInfos, Type elementType) {
344+
Location loc = op.getLoc();
345+
ImplicitLocOpBuilder b(loc, builder);
346+
Type i1 = b.getI1Type();
347+
348+
// Extract input element type.
349+
StreamType inputStreamType = op.input().getType().cast<StreamType>();
350+
Type inputElementType = inputStreamType.getElementType();
351+
352+
// Extract upstream state.
353+
Type upstreamStateType = upstreamInfos[0].stateType;
354+
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
355+
upstreamStateType, initialState, b.getIndexAttr(0));
356+
357+
// Check if the iterator has returned an element already (since it should
358+
// return one only in the first call to next).
359+
Value hasReturned =
360+
b.create<iterators::ExtractValueOp>(i1, initialState, b.getIndexAttr(1));
361+
TypeRange ifReturnTypes{upstreamStateType, elementType};
362+
auto ifOp = b.create<scf::IfOp>(
363+
ifReturnTypes, hasReturned,
364+
/*thenBuilder=*/
365+
[&](OpBuilder &builder, Location loc) {
366+
ImplicitLocOpBuilder b(loc, builder);
367+
368+
// Don't modify state; return undef element.
369+
Value nextElement = b.create<UndefOp>(elementType);
370+
b.create<scf::YieldOp>(ValueRange{initialUpstreamState, nextElement});
371+
},
372+
/*elseBuilder=*/
373+
[&](OpBuilder &builder, Location loc) {
374+
ImplicitLocOpBuilder b(loc, builder);
375+
376+
// Initialize accumulator with init value.
377+
FuncOp initFunc = op.getInitFunc();
378+
Value initValue = b.create<func::CallOp>(initFunc)->getResult(0);
379+
380+
// Create while loop.
381+
SmallVector<Value> whileInputs = {initialUpstreamState, initValue};
382+
SmallVector<Type> whileResultTypes = {
383+
upstreamStateType, // Updated upstream state.
384+
elementType, // Accumulator.
385+
inputElementType // Element from last next call.
386+
};
387+
scf::WhileOp whileOp = scf::createWhileOp(
388+
b, whileResultTypes, whileInputs,
389+
/*beforeBuilder=*/
390+
[&](OpBuilder &builder, Location loc,
391+
Block::BlockArgListType args) {
392+
ImplicitLocOpBuilder b(loc, builder);
393+
394+
Value upstreamState = args[0];
395+
Value accumulator = args[1];
396+
397+
// Call next function.
398+
SmallVector<Type> nextResultTypes = {upstreamStateType, i1,
399+
inputElementType};
400+
SymbolRefAttr nextFunc = upstreamInfos[0].nextFunc;
401+
auto nextCall = b.create<func::CallOp>(nextFunc, nextResultTypes,
402+
upstreamState);
403+
404+
Value updatedUpstreamState = nextCall->getResult(0);
405+
Value hasNext = nextCall->getResult(1);
406+
Value maybeNextElement = nextCall->getResult(2);
407+
b.create<scf::ConditionOp>(
408+
hasNext, ValueRange{updatedUpstreamState, accumulator,
409+
maybeNextElement});
410+
},
411+
/*afterBuilder=*/
412+
[&](OpBuilder &builder, Location loc,
413+
Block::BlockArgListType args) {
414+
ImplicitLocOpBuilder b(loc, builder);
415+
416+
Value upstreamState = args[0];
417+
Value accumulator = args[1];
418+
Value nextElement = args[2];
419+
420+
// Call accumulate function.
421+
auto accumulateCall =
422+
b.create<func::CallOp>(elementType, op.accumulateFuncRef(),
423+
ValueRange{accumulator, nextElement});
424+
Value newAccumulator = accumulateCall->getResult(0);
425+
426+
b.create<scf::YieldOp>(ValueRange{upstreamState, newAccumulator});
427+
});
428+
429+
Value updatedState = whileOp->getResult(0);
430+
Value accumulator = whileOp->getResult(1);
431+
432+
b.create<scf::YieldOp>(ValueRange{updatedState, accumulator});
433+
});
434+
435+
// Compute hasNext: we have an element iff we have not returned before, i.e.,
436+
// iff "not hasReturend". We simulate "not" with "xor true".
437+
Value constTrue = b.create<arith::ConstantIntOp>(/*value=*/1, /*width=*/1);
438+
Value hasNext = b.create<arith::XOrIOp>(constTrue, hasReturned);
439+
440+
// Update state.
441+
Value finalUpstreamState = ifOp->getResult(0);
442+
Value finalState = b.create<iterators::InsertValueOp>(
443+
initialState, b.getIndexAttr(0), finalUpstreamState);
444+
finalState = b.create<iterators::InsertValueOp>(finalState, b.getIndexAttr(1),
445+
constTrue);
446+
Value nextElement = ifOp->getResult(1);
447+
448+
return {finalState, hasNext, nextElement};
449+
}
450+
451+
/// Builds IR that closes the nested upstream iterator. Possible output:
452+
///
453+
/// %0 = iterators.extractvalue %arg0[0] :
454+
/// !iterators.state<!upstream_state, i1> -> !upstream_state
455+
/// %1 = call @iterators.upstream.close.0(%0) :
456+
/// (!upstream_state) -> !upstream_state
457+
/// %2 = iterators.insertvalue %arg0[0] (%1 : !upstream_state) :
458+
/// !iterators.state<!upstream_state, i1>
459+
static Value buildCloseBody(AccumulateOp op, OpBuilder &builder,
460+
Value initialState,
461+
ArrayRef<IteratorInfo> upstreamInfos) {
462+
Location loc = op.getLoc();
463+
ImplicitLocOpBuilder b(loc, builder);
464+
465+
Type upstreamStateType = upstreamInfos[0].stateType;
466+
467+
// Extract upstream state.
468+
Value initialUpstreamState = b.create<iterators::ExtractValueOp>(
469+
upstreamStateType, initialState, b.getIndexAttr(0));
470+
471+
// Call Close on upstream.
472+
SymbolRefAttr closeFunc = upstreamInfos[0].closeFunc;
473+
auto closeCallOp = b.create<func::CallOp>(closeFunc, upstreamStateType,
474+
initialUpstreamState);
475+
476+
// Update upstream state.
477+
Value updatedUpstreamState = closeCallOp->getResult(0);
478+
return b
479+
.create<iterators::InsertValueOp>(initialState, b.getIndexAttr(0),
480+
updatedUpstreamState)
481+
.getResult();
482+
}
483+
484+
/// Builds IR that initializes the iterator state with the state of the upstream
485+
/// iterator. Possible output:
486+
///
487+
/// %0 = ...
488+
/// %1 = iterators.undefstate : <!upstream_state, i1>
489+
/// %2 = iterators.insertvalue %1[0] (%0 : !upstream_state) :
490+
/// !iterators.state<!upstream_state, i1>
491+
static Value buildStateCreation(AccumulateOp op, AccumulateOp::Adaptor adaptor,
492+
OpBuilder &builder, StateType stateType) {
493+
Location loc = op.getLoc();
494+
ImplicitLocOpBuilder b(loc, builder);
495+
Value undefState = b.create<UndefStateOp>(loc, stateType);
496+
Value upstreamState = adaptor.input();
497+
return b.create<iterators::InsertValueOp>(undefState, b.getIndexAttr(0),
498+
upstreamState);
499+
}
500+
254501
//===----------------------------------------------------------------------===//
255502
// ConstantStreamOp.
256503
//===----------------------------------------------------------------------===//
@@ -1212,6 +1459,7 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder,
12121459
return llvm::TypeSwitch<Operation *, Value>(op)
12131460
.Case<
12141461
// clang-format off
1462+
AccumulateOp,
12151463
ConstantStreamOp,
12161464
FilterOp,
12171465
MapOp,
@@ -1230,6 +1478,7 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState,
12301478
return llvm::TypeSwitch<Operation *, llvm::SmallVector<Value, 4>>(op)
12311479
.Case<
12321480
// clang-format off
1481+
AccumulateOp,
12331482
ConstantStreamOp,
12341483
FilterOp,
12351484
MapOp,
@@ -1249,6 +1498,7 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder,
12491498
return llvm::TypeSwitch<Operation *, Value>(op)
12501499
.Case<
12511500
// clang-format off
1501+
AccumulateOp,
12521502
ConstantStreamOp,
12531503
FilterOp,
12541504
MapOp,
@@ -1266,6 +1516,7 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder,
12661516
return llvm::TypeSwitch<Operation *, Value>(op)
12671517
.Case<
12681518
// clang-format off
1519+
AccumulateOp,
12691520
ConstantStreamOp,
12701521
FilterOp,
12711522
MapOp,
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// RUN: mlir-proto-opt %s -convert-iterators-to-llvm \
2+
// RUN: | FileCheck --enable-var-scope %s
3+
4+
!element_type = !llvm.struct<(i32)>
5+
6+
func.func private @zero_struct() -> !element_type {
7+
%zero = arith.constant 0 : i32
8+
%undef = llvm.mlir.undef : !element_type
9+
%result = llvm.insertvalue %zero, %undef[0 : index] : !element_type
10+
return %result : !element_type
11+
}
12+
13+
func.func private @sum_struct(%lhs : !element_type, %rhs : !element_type) -> !element_type {
14+
%lhsi = llvm.extractvalue %lhs[0 : index] : !element_type
15+
%rhsi = llvm.extractvalue %rhs[0 : index] : !element_type
16+
%i = arith.addi %lhsi, %rhsi : i32
17+
%result = llvm.insertvalue %i, %lhs[0 : index] : !element_type
18+
return %result : !element_type
19+
}
20+
21+
// CHECK-LABEL: func.func private @iterators.accumulate.next.{{[0-9]+}}(%{{.*}}: !iterators.state<!iterators.state<i32>, i1>) -> (!iterators.state<!iterators.state<i32>, i1>, i1, !llvm.struct<(i32)>) {
22+
// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[arg0:.*]][0] : !iterators.state<!iterators.state<i32>, i1>
23+
// CHECK-NEXT: %[[V2:.*]] = iterators.extractvalue %[[arg0]][1] : !iterators.state<!iterators.state<i32>, i1>
24+
// CHECK-NEXT: %[[V3:.*]]:2 = scf.if %[[V2]] -> (!iterators.state<i32>, !llvm.struct<(i32)>) {
25+
// CHECK-NEXT: %[[V4:.*]] = llvm.mlir.undef : !llvm.struct<(i32)>
26+
// CHECK-NEXT: scf.yield %[[V1]], %[[V4]] : !iterators.state<i32>, !llvm.struct<(i32)>
27+
// CHECK-NEXT: } else {
28+
// CHECK-NEXT: %[[V4:.*]] = func.call @zero_struct() : () -> !llvm.struct<(i32)>
29+
// CHECK-NEXT: %[[V5:.*]]:3 = scf.while (%[[arg1:.*]] = %[[V1]], %[[arg2:.*]] = %[[V4]]) : (!iterators.state<i32>, !llvm.struct<(i32)>) -> (!iterators.state<i32>, !llvm.struct<(i32)>, !llvm.struct<(i32)>) {
30+
// CHECK-NEXT: %[[V6:.*]]:3 = func.call @iterators.constantstream.next.0(%[[arg1]]) : (!iterators.state<i32>) -> (!iterators.state<i32>, i1, !llvm.struct<(i32)>)
31+
// CHECK-NEXT: scf.condition(%[[V6]]#1) %[[V6]]#0, %[[arg2]], %[[V6]]#2 : !iterators.state<i32>, !llvm.struct<(i32)>, !llvm.struct<(i32)>
32+
// CHECK-NEXT: } do {
33+
// CHECK-NEXT: ^bb0(%[[arg1:.*]]: !iterators.state<i32>, %[[arg2:.*]]: !llvm.struct<(i32)>, %[[arg3:.*]]: !llvm.struct<(i32)>):
34+
// CHECK-NEXT: %[[V7:.*]] = func.call @sum_struct(%[[arg2]], %[[arg3]]) : (!llvm.struct<(i32)>, !llvm.struct<(i32)>) -> !llvm.struct<(i32)>
35+
// CHECK-NEXT: scf.yield %[[arg1]], %[[V7]] : !iterators.state<i32>, !llvm.struct<(i32)>
36+
// CHECK-NEXT: }
37+
// CHECK-NEXT: scf.yield %[[V5]]#0, %[[V5]]#1 : !iterators.state<i32>, !llvm.struct<(i32)>
38+
// CHECK-NEXT: }
39+
// CHECK-NEXT: %[[V8:.*]] = arith.constant true
40+
// CHECK-NEXT: %[[V9:.*]] = arith.xori %[[V8]], %[[V2]] : i1
41+
// CHECK-NEXT: %[[Va:.*]] = iterators.insertvalue %[[V3]]#0 into %[[arg0]][0] : !iterators.state<!iterators.state<i32>, i1>
42+
// CHECK-NEXT: %[[Vb:.*]] = iterators.insertvalue %[[V8]] into %[[Va]][1] : !iterators.state<!iterators.state<i32>, i1>
43+
// CHECK-NEXT: return %[[Vb]], %[[V9]], %[[V3]]#1 : !iterators.state<!iterators.state<i32>, i1>, i1, !llvm.struct<(i32)>
44+
// CHECK-NEXT: }
45+
46+
// CHECK-LABEL: func.func private @iterators.accumulate.open.{{[0-9]+}}(%{{.*}}: !iterators.state<!iterators.state<i32>, i1>) -> !iterators.state<!iterators.state<i32>, i1> {
47+
// CHECK-NEXT: %[[V1:.*]] = iterators.extractvalue %[[arg0:.*]][0] : !iterators.state<!iterators.state<i32>, i1>
48+
// CHECK-NEXT: %[[V2:.*]] = call @iterators.constantstream.open.0(%[[V1]]) : (!iterators.state<i32>) -> !iterators.state<i32>
49+
// CHECK-NEXT: %[[V3:.*]] = iterators.insertvalue %[[V2]] into %[[arg0]][0] : !iterators.state<!iterators.state<i32>, i1>
50+
// CHECK-NEXT: %[[V4:.*]] = arith.constant false
51+
// CHECK-NEXT: %[[V5:.*]] = iterators.insertvalue %[[V4]] into %[[V3]][1] : !iterators.state<!iterators.state<i32>, i1>
52+
// CHECK-NEXT: return %[[V5]] : !iterators.state<!iterators.state<i32>, i1>
53+
// CHECK-NEXT: }
54+
55+
func.func @main() {
56+
// CHECK-LABEL: func.func @main()
57+
%input = "iterators.constantstream"() { value = [] } : () -> (!iterators.stream<!element_type>)
58+
%accumulated = iterators.accumulate(%input, @zero_struct, @sum_struct)
59+
: (!iterators.stream<!element_type>) -> !iterators.stream<!element_type>
60+
// CHECK: %[[V1:.*]] = iterators.undefstate : !iterators.state<!iterators.state<i32>, i1>
61+
// CHECK-NEXT: %[[V2:.*]] = iterators.insertvalue %[[V0:.*]] into %[[V1]][0] : !iterators.state<!iterators.state<i32>, i1>
62+
return
63+
// CHECK-NEXT: return
64+
}

0 commit comments

Comments
 (0)