Skip to content

Commit 7f58ffd

Browse files
gryppmakslevental
andauthored
[mlir][python] Yield results of scf.for_ (#93610)
Using `for_` is very hand with python bindings. Currently, it doesn't support results, we had to fallback to two lines scf.for. This PR yields results of scf.for in `for_` --------- Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
1 parent 476a6d8 commit 7f58ffd

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

mlir/python/mlir/dialects/scf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ def for_(
132132
iter_args = tuple(for_op.inner_iter_args)
133133
with InsertionPoint(for_op.body):
134134
if len(iter_args) > 1:
135-
yield iv, iter_args
135+
yield iv, iter_args, for_op.results
136136
elif len(iter_args) == 1:
137-
yield iv, iter_args[0]
137+
yield iv, iter_args[0], for_op.results[0]
138138
else:
139139
yield iv

mlir/test/python/dialects/scf.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,56 @@ def range_loop_7(lb, ub, step, memref_v):
176176
memref.store(add, memref_v, [i])
177177
scf.yield_([])
178178

179+
# CHECK: func.func @loop_yield_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
180+
# CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
181+
# CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
182+
# CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
183+
# CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
184+
# CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
185+
# CHECK: %[[VAL_10:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER:.*]] = %[[VAL_4]]) -> (index) {
186+
# CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER]], %[[IV]] : index
187+
# CHECK: scf.yield %[[VAL_9]] : index
188+
# CHECK: }
189+
# CHECK: memref.store %[[VAL_10]], %[[VAL_3]]{{\[}}%[[VAL_5]]] : memref<10xindex>
190+
# CHECK: return
191+
# CHECK: }
192+
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
193+
def loop_yield_1(lb, ub, step, memref_v):
194+
sum = arith.ConstantOp.create_index(0)
195+
c0 = arith.ConstantOp.create_index(0)
196+
for i, loc_sum, sum in scf.for_(0, 100, 1, [sum]):
197+
loc_sum = arith.addi(loc_sum, i)
198+
scf.yield_([loc_sum])
199+
memref.store(sum, memref_v, [c0])
200+
201+
# CHECK: func.func @loop_yield_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
202+
# CHECK: %[[c0:.*]] = arith.constant 0 : index
203+
# CHECK: %[[c2:.*]] = arith.constant 2 : index
204+
# CHECK: %[[REF1:.*]] = arith.constant 0 : index
205+
# CHECK: %[[REF2:.*]] = arith.constant 1 : index
206+
# CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
207+
# CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
208+
# CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
209+
# CHECK: %[[RES:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER1:.*]] = %[[c0]], %[[ITER2:.*]] = %[[c2]]) -> (index, index) {
210+
# CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER1]], %[[IV]] : index
211+
# CHECK: %[[VAL_10:.*]] = arith.addi %[[ITER2]], %[[IV]] : index
212+
# CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : index, index
213+
# CHECK: }
214+
# CHECK: return
215+
# CHECK: }
216+
@func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
217+
def loop_yield_2(lb, ub, step, memref_v):
218+
sum1 = arith.ConstantOp.create_index(0)
219+
sum2 = arith.ConstantOp.create_index(2)
220+
c0 = arith.ConstantOp.create_index(0)
221+
c1 = arith.ConstantOp.create_index(1)
222+
for i, [loc_sum1, loc_sum2], [sum1, sum2] in scf.for_(0, 100, 1, [sum1, sum2]):
223+
loc_sum1 = arith.addi(loc_sum1, i)
224+
loc_sum2 = arith.addi(loc_sum2, i)
225+
scf.yield_([loc_sum1, loc_sum2])
226+
memref.store(sum1, memref_v, [c0])
227+
memref.store(sum2, memref_v, [c1])
228+
179229

180230
@constructAndPrintInModule
181231
def testOpsAsArguments():

0 commit comments

Comments
 (0)