@@ -251,6 +251,253 @@ struct PrintOpLowering : public OpConversionPattern<PrintOp> {
251
251
}
252
252
};
253
253
254
+ // ===----------------------------------------------------------------------===//
255
+ // AccumulateOp.
256
+ // ===----------------------------------------------------------------------===//
257
+
258
+ // / Builds IR that opens the nested upstream iterator and sets `hasReturned` to
259
+ // / false. Possible output:
260
+ // /
261
+ // / %0 = iterators.extractvalue %arg0[0] :
262
+ // / <!upstream_state, i1> -> !upstream_state
263
+ // / %1 = call @iterators.upstream.open.0(%0) :
264
+ // / (!upstream_state) -> !upstream_state
265
+ // / %2 = iterators.insertvalue %arg0[0] (%1 : !upstream_state) :
266
+ // / <!upstream_state, i1>
267
+ // / %false = arith.constant false
268
+ // / %3 = iterators.insertvalue %false into %2[1] :
269
+ // / !iterators.state<!upstream_state, i1>
270
+ static Value buildOpenBody (AccumulateOp op, OpBuilder &builder,
271
+ Value initialState,
272
+ ArrayRef<IteratorInfo> upstreamInfos) {
273
+ Location loc = op.getLoc ();
274
+ ImplicitLocOpBuilder b (loc, builder);
275
+
276
+ Type upstreamStateType = upstreamInfos[0 ].stateType ;
277
+
278
+ // Extract upstream state.
279
+ Value initialUpstreamState = b.create <iterators::ExtractValueOp>(
280
+ upstreamStateType, initialState, b.getIndexAttr (0 ));
281
+
282
+ // Call Open on upstream.
283
+ SymbolRefAttr openFunc = upstreamInfos[0 ].openFunc ;
284
+ auto openCallOp =
285
+ b.create <func::CallOp>(openFunc, upstreamStateType, initialUpstreamState);
286
+
287
+ // Update upstream state.
288
+ Value updatedUpstreamState = openCallOp->getResult (0 );
289
+ Value updatedState = b.create <iterators::InsertValueOp>(
290
+ initialState, b.getIndexAttr (0 ), updatedUpstreamState);
291
+
292
+ // Reset hasReturned to false.
293
+ Value constFalse = b.create <arith::ConstantIntOp>(/* value=*/ 0 , /* width=*/ 1 );
294
+ updatedState = b.create <iterators::InsertValueOp>(
295
+ updatedState, b.getIndexAttr (1 ), constFalse);
296
+
297
+ return updatedState;
298
+ }
299
+
300
+ // / Builds IR that consumes all elements of the upstream iterator and combines
301
+ // / them into a single one using the given accumulate function. Pseudo-code:
302
+ // /
303
+ // / if hasReturned: return {}
304
+ // / hasReturned = True
305
+ // / accumulator = initFuncRef()
306
+ // / while (next = upstream->Next()):
307
+ // / accumulator = accumulate(accumulator, next)
308
+ // / return accumulator
309
+ // /
310
+ // / Possible output:
311
+ // /
312
+ // / %0 = iterators.extractvalue %arg0[0] :
313
+ // / <!upstream_state, i1> -> !upstream_state
314
+ // / %1 = iterators.extractvalue %arg0[1] : !iterators.state<!upstream_state, i1>
315
+ // / %2:2 = scf.if %1 -> (!upstream_state, !element_type) {
316
+ // / %6 = llvm.mlir.undef : !element_type
317
+ // / scf.yield %0, %6 : !upstream_state, !element_type
318
+ // / } else {
319
+ // / %6 = func.call @zero_struct() : () -> !element_type
320
+ // / %7:3 = scf.while (%arg1 = %0, %arg2 = %6) :
321
+ // / (!upstream_state, !element_type) ->
322
+ // / (!upstream_state, !element_type, !element_type) {
323
+ // / %8:3 = func.call @iterators.upstream.next.0(%arg1) :
324
+ // / (!upstream_state) -> (!upstream_state, i1, !element_type)
325
+ // / scf.condition(%8#1) %8#0, %arg2, %8#2 :
326
+ // / !upstream_state, !element_type, !element_type
327
+ // // } do {
328
+ // / ^bb0(%arg1: !upstream_state, %arg2: !element_type, %arg3: !element_type):
329
+ // / %8 = func.call @accumulate_func(%arg2, %arg3) :
330
+ // / (!element_type, !element_type) -> !element_type
331
+ // / scf.yield %arg1, %8 : !upstream_state, !element_type
332
+ // / }
333
+ // / scf.yield %7#0, %7#1 : !upstream_state, !element_type
334
+ // / }
335
+ // / %3 = iterators.insertvalue %arg0[0] (%2#0 : !upstream_state) :
336
+ // / <!upstream_state, i1>
337
+ // / %true = arith.constant true
338
+ // / %4 = arith.xori %true, %1 : i1
339
+ // / %5 = iterators.insertvalue %true into %3[1] :
340
+ // / !iterators.state<!upstream_state, i1>
341
+ static llvm::SmallVector<Value, 4 >
342
+ buildNextBody (AccumulateOp op, OpBuilder &builder, Value initialState,
343
+ ArrayRef<IteratorInfo> upstreamInfos, Type elementType) {
344
+ Location loc = op.getLoc ();
345
+ ImplicitLocOpBuilder b (loc, builder);
346
+ Type i1 = b.getI1Type ();
347
+
348
+ // Extract input element type.
349
+ StreamType inputStreamType = op.input ().getType ().cast <StreamType>();
350
+ Type inputElementType = inputStreamType.getElementType ();
351
+
352
+ // Extract upstream state.
353
+ Type upstreamStateType = upstreamInfos[0 ].stateType ;
354
+ Value initialUpstreamState = b.create <iterators::ExtractValueOp>(
355
+ upstreamStateType, initialState, b.getIndexAttr (0 ));
356
+
357
+ // Check if the iterator has returned an element already (since it should
358
+ // return one only in the first call to next).
359
+ Value hasReturned =
360
+ b.create <iterators::ExtractValueOp>(i1, initialState, b.getIndexAttr (1 ));
361
+ TypeRange ifReturnTypes{upstreamStateType, elementType};
362
+ auto ifOp = b.create <scf::IfOp>(
363
+ ifReturnTypes, hasReturned,
364
+ /* thenBuilder=*/
365
+ [&](OpBuilder &builder, Location loc) {
366
+ ImplicitLocOpBuilder b (loc, builder);
367
+
368
+ // Don't modify state; return undef element.
369
+ Value nextElement = b.create <UndefOp>(elementType);
370
+ b.create <scf::YieldOp>(ValueRange{initialUpstreamState, nextElement});
371
+ },
372
+ /* elseBuilder=*/
373
+ [&](OpBuilder &builder, Location loc) {
374
+ ImplicitLocOpBuilder b (loc, builder);
375
+
376
+ // Initialize accumulator with init value.
377
+ FuncOp initFunc = op.getInitFunc ();
378
+ Value initValue = b.create <func::CallOp>(initFunc)->getResult (0 );
379
+
380
+ // Create while loop.
381
+ SmallVector<Value> whileInputs = {initialUpstreamState, initValue};
382
+ SmallVector<Type> whileResultTypes = {
383
+ upstreamStateType, // Updated upstream state.
384
+ elementType, // Accumulator.
385
+ inputElementType // Element from last next call.
386
+ };
387
+ scf::WhileOp whileOp = scf::createWhileOp (
388
+ b, whileResultTypes, whileInputs,
389
+ /* beforeBuilder=*/
390
+ [&](OpBuilder &builder, Location loc,
391
+ Block::BlockArgListType args) {
392
+ ImplicitLocOpBuilder b (loc, builder);
393
+
394
+ Value upstreamState = args[0 ];
395
+ Value accumulator = args[1 ];
396
+
397
+ // Call next function.
398
+ SmallVector<Type> nextResultTypes = {upstreamStateType, i1,
399
+ inputElementType};
400
+ SymbolRefAttr nextFunc = upstreamInfos[0 ].nextFunc ;
401
+ auto nextCall = b.create <func::CallOp>(nextFunc, nextResultTypes,
402
+ upstreamState);
403
+
404
+ Value updatedUpstreamState = nextCall->getResult (0 );
405
+ Value hasNext = nextCall->getResult (1 );
406
+ Value maybeNextElement = nextCall->getResult (2 );
407
+ b.create <scf::ConditionOp>(
408
+ hasNext, ValueRange{updatedUpstreamState, accumulator,
409
+ maybeNextElement});
410
+ },
411
+ /* afterBuilder=*/
412
+ [&](OpBuilder &builder, Location loc,
413
+ Block::BlockArgListType args) {
414
+ ImplicitLocOpBuilder b (loc, builder);
415
+
416
+ Value upstreamState = args[0 ];
417
+ Value accumulator = args[1 ];
418
+ Value nextElement = args[2 ];
419
+
420
+ // Call accumulate function.
421
+ auto accumulateCall =
422
+ b.create <func::CallOp>(elementType, op.accumulateFuncRef (),
423
+ ValueRange{accumulator, nextElement});
424
+ Value newAccumulator = accumulateCall->getResult (0 );
425
+
426
+ b.create <scf::YieldOp>(ValueRange{upstreamState, newAccumulator});
427
+ });
428
+
429
+ Value updatedState = whileOp->getResult (0 );
430
+ Value accumulator = whileOp->getResult (1 );
431
+
432
+ b.create <scf::YieldOp>(ValueRange{updatedState, accumulator});
433
+ });
434
+
435
+ // Compute hasNext: we have an element iff we have not returned before, i.e.,
436
+ // iff "not hasReturend". We simulate "not" with "xor true".
437
+ Value constTrue = b.create <arith::ConstantIntOp>(/* value=*/ 1 , /* width=*/ 1 );
438
+ Value hasNext = b.create <arith::XOrIOp>(constTrue, hasReturned);
439
+
440
+ // Update state.
441
+ Value finalUpstreamState = ifOp->getResult (0 );
442
+ Value finalState = b.create <iterators::InsertValueOp>(
443
+ initialState, b.getIndexAttr (0 ), finalUpstreamState);
444
+ finalState = b.create <iterators::InsertValueOp>(finalState, b.getIndexAttr (1 ),
445
+ constTrue);
446
+ Value nextElement = ifOp->getResult (1 );
447
+
448
+ return {finalState, hasNext, nextElement};
449
+ }
450
+
451
+ // / Builds IR that closes the nested upstream iterator. Possible output:
452
+ // /
453
+ // / %0 = iterators.extractvalue %arg0[0] :
454
+ // / !iterators.state<!upstream_state, i1> -> !upstream_state
455
+ // / %1 = call @iterators.upstream.close.0(%0) :
456
+ // / (!upstream_state) -> !upstream_state
457
+ // / %2 = iterators.insertvalue %arg0[0] (%1 : !upstream_state) :
458
+ // / !iterators.state<!upstream_state, i1>
459
+ static Value buildCloseBody (AccumulateOp op, OpBuilder &builder,
460
+ Value initialState,
461
+ ArrayRef<IteratorInfo> upstreamInfos) {
462
+ Location loc = op.getLoc ();
463
+ ImplicitLocOpBuilder b (loc, builder);
464
+
465
+ Type upstreamStateType = upstreamInfos[0 ].stateType ;
466
+
467
+ // Extract upstream state.
468
+ Value initialUpstreamState = b.create <iterators::ExtractValueOp>(
469
+ upstreamStateType, initialState, b.getIndexAttr (0 ));
470
+
471
+ // Call Close on upstream.
472
+ SymbolRefAttr closeFunc = upstreamInfos[0 ].closeFunc ;
473
+ auto closeCallOp = b.create <func::CallOp>(closeFunc, upstreamStateType,
474
+ initialUpstreamState);
475
+
476
+ // Update upstream state.
477
+ Value updatedUpstreamState = closeCallOp->getResult (0 );
478
+ return b
479
+ .create <iterators::InsertValueOp>(initialState, b.getIndexAttr (0 ),
480
+ updatedUpstreamState)
481
+ .getResult ();
482
+ }
483
+
484
+ // / Builds IR that initializes the iterator state with the state of the upstream
485
+ // / iterator. Possible output:
486
+ // /
487
+ // / %0 = ...
488
+ // / %1 = iterators.undefstate : <!upstream_state, i1>
489
+ // / %2 = iterators.insertvalue %1[0] (%0 : !upstream_state) :
490
+ // / !iterators.state<!upstream_state, i1>
491
+ static Value buildStateCreation (AccumulateOp op, AccumulateOp::Adaptor adaptor,
492
+ OpBuilder &builder, StateType stateType) {
493
+ Location loc = op.getLoc ();
494
+ ImplicitLocOpBuilder b (loc, builder);
495
+ Value undefState = b.create <UndefStateOp>(loc, stateType);
496
+ Value upstreamState = adaptor.input ();
497
+ return b.create <iterators::InsertValueOp>(undefState, b.getIndexAttr (0 ),
498
+ upstreamState);
499
+ }
500
+
254
501
// ===----------------------------------------------------------------------===//
255
502
// ConstantStreamOp.
256
503
// ===----------------------------------------------------------------------===//
@@ -1212,6 +1459,7 @@ static Value buildOpenBody(Operation *op, OpBuilder &builder,
1212
1459
return llvm::TypeSwitch<Operation *, Value>(op)
1213
1460
.Case <
1214
1461
// clang-format off
1462
+ AccumulateOp,
1215
1463
ConstantStreamOp,
1216
1464
FilterOp,
1217
1465
MapOp,
@@ -1230,6 +1478,7 @@ buildNextBody(Operation *op, OpBuilder &builder, Value initialState,
1230
1478
return llvm::TypeSwitch<Operation *, llvm::SmallVector<Value, 4 >>(op)
1231
1479
.Case <
1232
1480
// clang-format off
1481
+ AccumulateOp,
1233
1482
ConstantStreamOp,
1234
1483
FilterOp,
1235
1484
MapOp,
@@ -1249,6 +1498,7 @@ static Value buildCloseBody(Operation *op, OpBuilder &builder,
1249
1498
return llvm::TypeSwitch<Operation *, Value>(op)
1250
1499
.Case <
1251
1500
// clang-format off
1501
+ AccumulateOp,
1252
1502
ConstantStreamOp,
1253
1503
FilterOp,
1254
1504
MapOp,
@@ -1266,6 +1516,7 @@ static Value buildStateCreation(IteratorOpInterface op, OpBuilder &builder,
1266
1516
return llvm::TypeSwitch<Operation *, Value>(op)
1267
1517
.Case <
1268
1518
// clang-format off
1519
+ AccumulateOp,
1269
1520
ConstantStreamOp,
1270
1521
FilterOp,
1271
1522
MapOp,
0 commit comments