@@ -308,3 +308,57 @@ def stop_gradient(variable):
308
308
def unstack(x, num=None, axis=0):
    """Split `x` along `axis` and drop that axis from every piece.

    Args:
        x: Array-like object exposing `.shape` and `.split(...)`; each
            piece returned by `.split` must expose `.squeeze(axis)`.
        num: First argument handed to `x.split`. Any falsy value (`None`
            or `0`) falls back to `x.shape[axis]`.
        axis: Axis passed to both `x.split` and `squeeze`.

    Returns:
        A list holding the squeezed result of every piece that
        `x.split` produced, in the same order.
    """
    if not num:
        # No explicit count given: use the size of the split axis, which
        # presumably yields one piece per entry along that axis.
        num = x.shape[axis]
    return [piece.squeeze(axis) for piece in x.split(num, axis=axis)]
311
+
312
+
313
def reverse_sequence(xs):
    """Return `xs` with the entries along its first axis in reverse order.

    Args:
        xs: Array whose `shape[0]` gives the length of the sequence axis.

    Returns:
        The result of `mx.take` gathering `xs` along axis 0 with the
        indices `shape[0] - 1, ..., 1, 0`.
    """
    last = xs.shape[0] - 1
    # Descending index vector from the last position down to 0; the stop
    # value -1 is exclusive (assuming the usual start/stop/step `arange`).
    descending = mx.arange(last, -1, -1)
    return mx.take(xs, descending, axis=0)
316
+
317
+
318
def scan(f, init, xs, reverse=False, mask=None):
    """Thread a state through `f` over a sequence and collect its outputs.

    `f` is called once per step as `f(state, x)`. When it returns a tuple,
    that tuple is unpacked as `(new_state, output)`; any other return
    value is taken as the new state and no output is recorded for that
    step. Outputs equal to `None` are skipped.

    Args:
        f: Step function, called as `f(state, x)`.
        init: Initial state handed to the first call of `f`.
        xs: The inputs to iterate over.
            * If `mask` is not `None`, `xs` is unpacked as a pair
              `(inputs, mask)` and each step receives one
              `(input, mask)` pair.
            * Otherwise, a tuple `xs` is zipped so each step receives one
              element from every member; any other `xs` is iterated
              directly.
        reverse: If truthy, every input sequence is passed through
            `reverse_sequence` before iteration, and the stacked outputs
            are passed through it again afterwards (presumably so they
            line up with the original input order).
        mask: Only tested against `None` to select the masked branch. The
            mask that is iterated comes from `xs`, not this argument.

    Returns:
        A `(final_state, outputs)` pair. `outputs` is `None` when no step
        produced an output. If the first collected output is a tuple,
        `outputs` is a tuple with one `mx.stack` result per position;
        otherwise it is a single `mx.stack` of all collected outputs.
    """
    # Build the per-step iterator for the three supported input layouts.
    if mask is not None:
        inputs, step_mask = xs
        if reverse:
            inputs = reverse_sequence(inputs)
            step_mask = reverse_sequence(step_mask)
        steps = zip(inputs, step_mask)
    elif isinstance(xs, tuple):
        if reverse:
            xs = tuple(reverse_sequence(seq) for seq in xs)
        steps = zip(*xs)
    else:
        steps = reverse_sequence(xs) if reverse else xs

    carry = init
    collected = []
    for step in steps:
        returned = f(carry, step)
        if not isinstance(returned, tuple):
            # Bare return value: it is the new state, nothing to record.
            carry = returned
            continue
        carry, emitted = returned
        if emitted is not None:
            collected.append(emitted)

    if not collected:
        return carry, None

    first = collected[0]
    if isinstance(first, tuple):
        # Several outputs per step: stack each position separately.
        stacked = tuple(
            mx.stack([emitted[pos] for emitted in collected])
            for pos in range(len(first))
        )
    else:
        stacked = mx.stack(collected)

    if reverse:
        if isinstance(stacked, tuple):
            stacked = tuple(reverse_sequence(part) for part in stacked)
        else:
            stacked = reverse_sequence(stacked)

    return carry, stacked
0 commit comments