Skip to content

Commit edbb0fb

Browse files
Iterators: Create iterator state type and extract/insert/undef ops. (#606)
1 parent 9bf20dc commit edbb0fb

File tree

4 files changed

+161
-0
lines changed

4 files changed

+161
-0
lines changed

experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,4 +303,94 @@ def Iterators_SinkOp : Iterators_Base_Op<"sink"> {
303303
let arguments = (ins Iterators_StreamOfLLVMStructOfNumerics:$input);
304304
}
305305

306+
//===----------------------------------------------------------------------===//
307+
// Ops related to Iterator bodies.
308+
//===----------------------------------------------------------------------===//
309+
310+
def Iterators_UndefStateOp : Iterators_Base_Op<"undefstate", [NoSideEffect]> {
311+
let summary = "Create an undefined iterator state";
312+
let results = (outs Iterators_State:$result);
313+
let assemblyFormat = "attr-dict `:` qualified(type($result))";
314+
let description = [{
315+
Creates an iterator state of the given type with undefined field values.
316+
All fields have to be set individually with `insertvalue` before the whole
317+
state is fully defined.
318+
319+
This is similar to `llvm.undef` for `llvm.struct`.
320+
321+
Example:
322+
323+
```
324+
%undef_state = iterators.undefstate : !iterators.state<i32, tensor<?xi32>>
325+
```
326+
}];
327+
}
328+
329+
def Iterators_ExtractValueOp : Iterators_Base_Op<"extractvalue", [
330+
NoSideEffect,
331+
PredOpTrait<"index must exist in state",
332+
CPred<"static_cast<uint64_t>($index.cast<IntegerAttr>().getInt())"
333+
" < $state.getType().cast<StateType>().getFieldTypes().size()">>,
334+
AllMatch<["$state.getType().cast<StateType>().getFieldTypes()"
335+
" [$index.cast<IntegerAttr>().getInt()]",
336+
"$result.getType()"],
337+
"the return type must match the field type at the given index">,
338+
DeclareOpInterfaceMethods<InferTypeOpInterface>
339+
]> {
340+
let summary = "Extract the field value of the state";
341+
let arguments = (ins Iterators_State:$state, IndexAttr:$index);
342+
let results = (outs AnyType:$result);
343+
let assemblyFormat = [{
344+
$state `[` $index `]` attr-dict `:` qualified(type($state))
345+
}];
346+
let description = [{
347+
Extracts the value of the given iterator state at the given index.
348+
349+
This is similar to `llvm.extractvalue` for `llvm.struct`.
350+
351+
Example:
352+
353+
```
354+
%state = ...
355+
%value = iterators.extractvalue %state[0] :
356+
!iterators.state<i32, tensor<?xi32>>
357+
```
358+
}];
359+
}
360+
361+
def Iterators_InsertValueOp : Iterators_Base_Op<"insertvalue", [
362+
NoSideEffect,
363+
PredOpTrait<
364+
"index must exist in state",
365+
CPred<"static_cast<uint64_t>($index.cast<IntegerAttr>().getInt())"
366+
" < $state.getType().cast<StateType>().getFieldTypes().size()">>,
367+
AllMatch<["$state.getType().cast<StateType>().getFieldTypes()"
368+
" [$index.cast<IntegerAttr>().getInt()]",
369+
"$value.getType()"],
370+
"the value type must match the field type at the given index">,
371+
AllTypesMatch<["state", "result"]>
372+
]> {
373+
let summary = "Insert a field value into the state";
374+
let arguments = (ins Iterators_State:$state, IndexAttr:$index, AnyType:$value);
375+
let results = (outs Iterators_State:$result);
376+
let assemblyFormat = [{
377+
$value `into` $state `[` $index `]` attr-dict `:` qualified(type($state))
378+
custom<InsertValueType>(type($value), ref(type($state)), ref($index))
379+
}];
380+
let description = [{
381+
Inserts the given value into the given iterator state at the given index.
382+
383+
This is similar to `llvm.insertvalue` for `llvm.struct`.
384+
385+
Example:
386+
387+
```
388+
%state = ...
389+
%value = ...
390+
%updated_state = iterators.insertvalue %state[0] (%value : i32) :
391+
!iterators.state<i32, tensor<?xi32>>
392+
```
393+
}];
394+
}
395+
306396
#endif // ITERATORS_DIALECT_ITERATORS_IR_ITERATORSOPS

experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsTypes.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,32 @@ class Iterators_IsStreamOf<string name, Type type>
154154
class Iterators_IsStreamOfLLVMStructOfNumericsPred<string name>
155155
: Iterators_IsStreamOf<name, Iterators_StreamOfLLVMStructOfNumerics>;
156156

157+
//===----------------------------------------------------------------------===//
158+
// Types related to Iterator bodies.
159+
//===----------------------------------------------------------------------===//
160+
161+
def Iterators_State : Iterators_Type<"State", "state"> {
162+
let summary = "State of an iterator used by its body";
163+
let parameters = (ins ArrayRefParameter<"Type", "list of types">:$fieldTypes);
164+
let assemblyFormat = "`<` qualified($fieldTypes) `>`";
165+
let description = [{
166+
An iterator state is a collection of values identified by ordinal numbers,
167+
i.e., an (unnamed but typed) tuple. The values are referred to as "fields";
168+
their types are referred to as "field types". An iterator state is used by
169+
iterator bodies, i.e., by the open, next, and close functions that implement
170+
the logic that iterator ops get lowered to, and holds the state that is
171+
required during the iteration (which gets passed around different calls to
172+
open, next, and close).
173+
174+
This is similar to (anonymous) `llvm.struct` but allows for storing values
175+
of arbitrary types.
176+
177+
Example:
178+
179+
```
180+
%undef_state = iterators.undefstate : !iterators.state<i32, tensor<?xi32>>
181+
```
182+
}];
183+
}
184+
157185
#endif // ITERATORS_DIALECT_ITERATORS_IR_ITERATORSTYPES

experimental/iterators/lib/Dialects/Iterators/IR/Iterators.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,33 @@ void IteratorsDialect::initialize() {
4545
// Iterators operations
4646
//===----------------------------------------------------------------------===//
4747

48+
static ParseResult parseInsertValueType(AsmParser & /*parser*/, Type &valueType,
49+
Type stateType, IntegerAttr indexAttr) {
50+
int64_t index = indexAttr.getValue().getSExtValue();
51+
auto castedStateType = stateType.cast<StateType>();
52+
valueType = castedStateType.getFieldTypes()[index];
53+
return success();
54+
}
55+
56+
static void printInsertValueType(AsmPrinter & /*printer*/, Operation * /*op*/,
57+
Type /*valueType*/, Type /*stateType*/,
58+
IntegerAttr /*indexAttr*/) {}
59+
4860
#define GET_OP_CLASSES
4961
#include "iterators/Dialect/Iterators/IR/IteratorsOps.cpp.inc"
5062

63+
LogicalResult ExtractValueOp::inferReturnTypes(
64+
MLIRContext * /*context*/, Optional<Location> location, ValueRange operands,
65+
DictionaryAttr attributes, RegionRange regions,
66+
SmallVectorImpl<Type> &inferredReturnTypes) {
67+
auto stateType = operands[0].getType().cast<StateType>();
68+
auto indexAttr = attributes.getAs<IntegerAttr>("index");
69+
int64_t index = indexAttr.getValue().getSExtValue();
70+
Type fieldType = stateType.getFieldTypes()[index];
71+
inferredReturnTypes.assign({fieldType});
72+
return success();
73+
}
74+
5175
//===----------------------------------------------------------------------===//
5276
// Iterators types
5377
//===----------------------------------------------------------------------===//
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Test that we can parse and verify ops on iterator state correctly, and that
2+
// they round-trip through assembly.
3+
// RUN: mlir-proto-opt %s \
4+
// RUN: | FileCheck %s
5+
6+
func.func @testUndefInsertExtract() {
7+
// CHECK-LABEL: func.func @testUndefInsertExtract() {
8+
%initial_state = iterators.undefstate : !iterators.state<i32>
9+
// CHECK-NEXT: %[[V0:.*]] = iterators.undefstate : !iterators.state<i32>
10+
%value = arith.constant 0 : i32
11+
// CHECK-NEXT: %[[V1:.*]] = arith.constant 0 : i32
12+
%inserted_state = iterators.insertvalue %value into %initial_state[0] : !iterators.state<i32>
13+
// CHECK-NEXT: %[[V2:.*]] = iterators.insertvalue %[[V1]] into %[[V0]][0] : !iterators.state<i32>
14+
%extracted_value = iterators.extractvalue %inserted_state[0] : !iterators.state<i32>
15+
// CHECK-NEXT: %[[V3:.*]] = iterators.extractvalue %[[V2]][0] : !iterators.state<i32>
16+
return
17+
// CHECK-NEXT: return
18+
}
19+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)