+import contextlib
+
+import mlx.core as mx
+
+from keras.src import tree
+from keras.src.backend.common import stateless_scope
+
+
def rnn(
    step_function,
    inputs,
@@ -11,7 +19,228 @@ def rnn(
    zero_output_for_mask=False,
    return_all_outputs=True,
):
-    raise NotImplementedError("rnn not yet implemented in mlx")
+    def swap_batch_timestep(input_t):
+        # Swap the batch and timestep dim for the incoming tensor.
+        axes = list(range(len(input_t.shape)))
+        axes[0], axes[1] = 1, 0
+        return mx.transpose(input_t, axes)
+
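+    # Both code paths below iterate over axis 0, so bring the inputs into
+    # time-major (timesteps, batch, ...) layout first.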
+    if not time_major:
+        inputs = tree.map_structure(swap_batch_timestep, inputs)
+
+    flattened_inputs = tree.flatten(inputs)
+    time_steps = flattened_inputs[0].shape[0]
+
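+    # Normalize the mask: boolean dtype, a trailing feature axis of size 1,
+    # and the same time-major layout as the inputs.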
+    if mask is not None:
+        if mask.dtype != mx.bool_:
+            mask = mask.astype(mx.bool_)
+        if len(mask.shape) == 2:
+            mask = mx.expand_dims(mask, axis=-1)
+        if not time_major:
+            mask = swap_batch_timestep(mask)
+
+    if constants is None:
+        constants = []
+
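+    # Broadcast a lower-rank mask against `input_t` by appending axes and
+    # tiling across the non-fixed dimensions.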
+    def _expand_mask(mask_t, input_t, fixed_dim=1):
+        if tree.is_nested(mask_t):
+            raise ValueError(
+                f"mask_t is expected to be tensor, but got {mask_t}"
+            )
+        if tree.is_nested(input_t):
+            raise ValueError(
+                f"input_t is expected to be tensor, but got {input_t}"
+            )
+        rank_diff = len(input_t.shape) - len(mask_t.shape)
+        for _ in range(rank_diff):
+            mask_t = mx.expand_dims(mask_t, axis=-1)
+        multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:])
+        return mx.tile(mask_t, multiples)
+
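+    # Unrolled mode: run the step function in a Python loop, one stacked
+    # slice per timestep. Requires a statically known number of timesteps.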
+    if unroll:
+        if not time_steps:
+            raise ValueError("Unrolling requires a fixed number of timesteps.")
+        states = tuple(initial_states)
+        successive_states = []
+        successive_outputs = []
+
+        # Process the input tensors. The input tensor needs to be split on
+        # the time_step dim and reversed if go_backwards is True. In the
+        # case of nested input, the input is flattened and then transformed
+        # individually. The result will be a tuple of lists; each item in
+        # the tuple is a list of tensors with shape (batch, feature).
+        def _process_single_input_t(input_t):
+            input_t = unstack(input_t)  # unstack for time_step dim
+            if go_backwards:
+                input_t.reverse()
+            return input_t
+
+        if tree.is_nested(inputs):
+            processed_input = tree.map_structure(
+                _process_single_input_t, inputs
+            )
+        else:
+            processed_input = (_process_single_input_t(inputs),)
+
+        def _get_input_tensor(time):
+            inp = [t_[time] for t_ in processed_input]
+            return tree.pack_sequence_as(inputs, inp)
+
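+        # Masked case: at fully masked steps, carry the previous output and
+        # states forward instead of the step function's results.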
+        if mask is not None:
+            mask_list = unstack(mask)
+            if go_backwards:
+                mask_list.reverse()
+
+            for i in range(time_steps):
+                inp = _get_input_tensor(i)
+                mask_t = mask_list[i]
+                output, new_states = step_function(
+                    inp, tuple(states) + tuple(constants)
+                )
+                tiled_mask_t = _expand_mask(mask_t, output)
+
+                if not successive_outputs:
+                    prev_output = mx.zeros_like(output)
+                else:
+                    prev_output = successive_outputs[-1]
+
+                output = mx.where(tiled_mask_t, output, prev_output)
+
+                flat_states = tree.flatten(states)
+                flat_new_states = tree.flatten(new_states)
+                tiled_mask_t = tuple(
+                    _expand_mask(mask_t, s) for s in flat_states
+                )
+                flat_final_states = tuple(
+                    mx.where(m, s, ps)
+                    for m, s, ps in zip(
+                        tiled_mask_t, flat_new_states, flat_states
+                    )
+                )
+                states = tree.pack_sequence_as(states, flat_final_states)
+
+                if return_all_outputs:
+                    successive_outputs.append(output)
+                    successive_states.append(states)
+                else:
+                    successive_outputs = [output]
+                    successive_states = [states]
+            last_output = successive_outputs[-1]
+            new_states = successive_states[-1]
+            outputs = mx.stack(successive_outputs)
+
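+        # Unmasked case: feed every timestep straight through the step
+        # function and collect the outputs.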
+        else:  # mask is None
+            for i in range(time_steps):
+                inp = _get_input_tensor(i)
+                output, states = step_function(
+                    inp, tuple(states) + tuple(constants)
+                )
+                if return_all_outputs:
+                    successive_outputs.append(output)
+                    successive_states.append(states)
+                else:
+                    successive_outputs = [output]
+                    successive_states = [states]
+            last_output = successive_outputs[-1]
+            new_states = successive_states[-1]
+            outputs = mx.stack(successive_outputs)
+
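+    # Non-unrolled mode: express one timestep as a step callback and drive
+    # it with the mlx_scan helper defined below.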
+    else:  # Unroll == False
+        if mask is not None:
+
+            def _step(states, current_input):
+                current_input, current_mask = current_input
+                is_masked = mx.all(
+                    mx.logical_not(current_mask), axis=-1, keepdims=True
+                )
+
+                output_t, new_states = step_function(current_input, states)
+
+                if zero_output_for_mask:
+                    masked_outs = mx.where(
+                        is_masked, mx.zeros_like(output_t), output_t
+                    )
+                else:
+                    # Assume the first state is the previous output.
+                    output_tm1 = states[0]
+                    masked_outs = mx.where(is_masked, output_tm1, output_t)
+
+                new_states = [
+                    mx.where(is_masked, s, ns)
+                    for s, ns in zip(states, new_states)
+                ]
+                return (new_states, masked_outs)
+
+            scan_xs = (inputs, mask)
+
+        else:
+
+            def _step(states, current_input):
+                output_t, new_states = step_function(current_input, states)
+                return new_states, output_t
+
+            scan_xs = inputs
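+
+        # Run the scan under a stateless scope (reusing the parent scope if
+        # one is already active).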
+        if stateless_scope.in_stateless_scope():
+            # Reuse the existing parent stateless scope.
+            scope = contextlib.nullcontext()
+        else:
+            scope = stateless_scope.StatelessScope()
+        with scope:
+            new_states, outputs = mlx_scan(
+                f=_step,
+                init=initial_states,
+                xs=scan_xs,
+                reverse=go_backwards,
+                mask=mask,
+            )
+
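+        # mlx_scan returns outputs in the original time order; flip them so
+        # that outputs[-1] is the last step actually computed.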
+        if go_backwards:
+            outputs = reverse_sequence(outputs)
+
+        last_output = outputs[-1]
+
+    if not time_major:
+        outputs = tree.map_structure(swap_batch_timestep, outputs)
+
+    return last_output, outputs, new_states
+
+
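+# Flip a tensor along its leading (time) axis.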
+def reverse_sequence(xs):
+    indices = mx.arange(xs.shape[0] - 1, -1, -1)
+    return mx.take(xs, indices, axis=0)
+
+
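+# Split a tensor into a list of slices along `axis`.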
+def unstack(x, axis=0):
+    return [mx.take(x, i, axis=axis) for i in range(x.shape[axis])]
+
+
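+# Python-level scan: fold `f` over the leading axis of `xs`, threading the
+# carried `states` and stacking the per-step outputs. With reverse=True, the
+# inputs are flipped before the loop and the stacked outputs flipped back,
+# so the returned outputs follow the original time order.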
+def mlx_scan(f, init, xs, reverse=False, mask=None):
+    states = init
+    outputs = []
+
+    if mask is not None:
+        x, mask = xs
+        if reverse:
+            x = reverse_sequence(x)
+            mask = reverse_sequence(mask)
+
+        for each_x, each_mask in zip(x, mask):
+            states, output = f(states, (each_x, each_mask))
+            outputs.append(output)
+    else:
+        if reverse:
+            xs = reverse_sequence(xs)
+
+        for x in xs:
+            states, output = f(states, x)
+            outputs.append(output)
+
+    outputs = mx.array(outputs)
+
+    if reverse:
+        outputs = reverse_sequence(outputs)
+
+    return states, outputs


def cudnn_ok(*args, **kwargs):