Commit 0ad4c78

mlx - rnn (#20786)
* ++
* rnn initial impl, all tests without convs passing
1 parent 9b75b86 commit 0ad4c78

File tree

3 files changed: +248 −9 lines changed


keras/src/backend/mlx/linalg.py

Lines changed: 6 additions & 1 deletion
@@ -1,5 +1,7 @@
+import jax.numpy as jnp
 import mlx.core as mx
 
+from keras.src.backend.common import dtypes
 from keras.src.backend.common import standardize_dtype
 from keras.src.backend.mlx.core import convert_to_tensor
 
@@ -29,7 +31,10 @@ def solve_triangular(a, b, lower=False):
 
 
 def qr(x, mode="reduced"):
-    return mx.linalg.qr(x)
+    # TODO: Swap to mlx.linalg.qr when it supports non-square matrices
+    x = jnp.array(x)
+    output = jnp.linalg.qr(x, mode=mode)
+    return mx.array(output[0]), mx.array(output[1])
 
 
 def svd(x, full_matrices=True, compute_uv=True):
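
For reference, a minimal sketch of how the jnp-backed fallback behaves on a non-square input (hypothetical check, not part of the commit; assumes jax and mlx are both installed):

import mlx.core as mx
from keras.src.backend.mlx.linalg import qr

x = mx.random.normal((4, 3))  # non-square input, not yet supported by mx.linalg.qr
q, r = qr(x, mode="reduced")  # routed through jnp.linalg.qr, converted back to mx arrays
assert q.shape == (4, 3) and r.shape == (3, 3)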

keras/src/backend/mlx/numpy.py

Lines changed: 12 additions & 7 deletions
@@ -1,5 +1,3 @@
-import builtins
-
 import mlx.core as mx
 
 from keras.src.backend import config
@@ -409,14 +407,21 @@ def expm1(x):
 def flip(x, axis=None):
     x = convert_to_tensor(x)
     if axis is None:
-        axis = tuple(range(x.ndim))
+        indexer = tuple(slice(None, None, -1) for _ in range(x.ndim))
+        return x[indexer]
     if isinstance(axis, int):
         axis = (axis,)
-    indices = [slice(None)] * len(x.shape)
+    indexer = [slice(None)] * x.ndim
     for ax in axis:
-        indices[ax] = slice(None, None, -1)
-
-    return x[indices]
+        if ax < 0:
+            ax = x.ndim + ax
+        if not 0 <= ax < x.ndim:
+            raise ValueError(
+                f"axis {ax} is out of bounds for array of dimension {x.ndim}"
+            )
+        indexer[ax] = slice(None, None, -1)
+
+    return x[tuple(indexer)]
 
 
 def floor(x):
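
A quick illustration of the reworked flip (hypothetical usage, assuming the mlx backend is importable):

import mlx.core as mx
from keras.src.backend.mlx.numpy import flip

x = mx.array([[1, 2, 3], [4, 5, 6]])
print(flip(x, axis=-1))  # negative axis is normalized: [[3, 2, 1], [6, 5, 4]]
print(flip(x))           # axis=None reverses every axis: [[6, 5, 4], [3, 2, 1]]
# flip(x, axis=5) would raise ValueError: out of bounds for an array of dimension 2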

keras/src/backend/mlx/rnn.py

Lines changed: 230 additions & 1 deletion
@@ -1,3 +1,11 @@
+import contextlib
+
+import mlx.core as mx
+
+from keras.src import tree
+from keras.src.backend.common import stateless_scope
+
+
 def rnn(
     step_function,
     inputs,
@@ -11,7 +19,228 @@ def rnn(
     zero_output_for_mask=False,
     return_all_outputs=True,
 ):
-    raise NotImplementedError("rnn not yet implemented in mlx")
+    def swap_batch_timestep(input_t):
+        # Swap the batch and timestep dim for the incoming tensor.
+        axes = list(range(len(input_t.shape)))
+        axes[0], axes[1] = 1, 0
+        return mx.transpose(input_t, axes)
+
+    if not time_major:
+        inputs = tree.map_structure(swap_batch_timestep, inputs)
+
+    flattened_inputs = tree.flatten(inputs)
+    time_steps = flattened_inputs[0].shape[0]
+
+    if mask is not None:
+        if mask.dtype != mx.bool_:
+            mask = mask.astype(mx.bool_)
+        if len(mask.shape) == 2:
+            mask = mx.expand_dims(mask, axis=-1)
+        if not time_major:
+            mask = swap_batch_timestep(mask)
+
+    if constants is None:
+        constants = []
+
+    def _expand_mask(mask_t, input_t, fixed_dim=1):
+        if tree.is_nested(mask_t):
+            raise ValueError(
+                f"mask_t is expected to be tensor, but got {mask_t}"
+            )
+        if tree.is_nested(input_t):
+            raise ValueError(
+                f"input_t is expected to be tensor, but got {input_t}"
+            )
+        rank_diff = len(input_t.shape) - len(mask_t.shape)
+        for _ in range(rank_diff):
+            mask_t = mx.expand_dims(mask_t, axis=-1)
+        multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:])
+        return mx.tile(mask_t, multiples)
+
+    if unroll:
+        if not time_steps:
+            raise ValueError("Unrolling requires a fixed number of timesteps.")
+        states = tuple(initial_states)
+        successive_states = []
+        successive_outputs = []
+
+        # Process the input tensors. The input tensor need to be split on the
+        # time_step dim, and reverse if go_backwards is True. In the case of
+        # nested input, the input is flattened and then transformed
+        # individually. The result of this will be a tuple of lists, each of
+        # the item in tuple is list of the tensor with shape (batch, feature)
+        def _process_single_input_t(input_t):
+            input_t = unstack(input_t)  # unstack for time_step dim
+            if go_backwards:
+                input_t.reverse()
+            return input_t
+
+        if tree.is_nested(inputs):
+            processed_input = tree.map_structure(
+                _process_single_input_t, inputs
+            )
+        else:
+            processed_input = (_process_single_input_t(inputs),)
+
+        def _get_input_tensor(time):
+            inp = [t_[time] for t_ in processed_input]
+            return tree.pack_sequence_as(inputs, inp)
+
+        if mask is not None:
+            mask_list = unstack(mask)
+            if go_backwards:
+                mask_list.reverse()
+
+            for i in range(time_steps):
+                inp = _get_input_tensor(i)
+                mask_t = mask_list[i]
+                output, new_states = step_function(
+                    inp, tuple(states) + tuple(constants)
+                )
+                tiled_mask_t = _expand_mask(mask_t, output)
+
+                if not successive_outputs:
+                    prev_output = mx.zeros_like(output)
+                else:
+                    prev_output = successive_outputs[-1]
+
+                output = mx.where(tiled_mask_t, output, prev_output)
+
+                flat_states = tree.flatten(states)
+                flat_new_states = tree.flatten(new_states)
+                tiled_mask_t = tuple(
+                    _expand_mask(mask_t, s) for s in flat_states
+                )
+                flat_final_states = tuple(
+                    mx.where(m, s, ps)
+                    for m, s, ps in zip(
+                        tiled_mask_t, flat_new_states, flat_states
+                    )
+                )
+                states = tree.pack_sequence_as(states, flat_final_states)
+
+                if return_all_outputs:
+                    successive_outputs.append(output)
+                    successive_states.append(states)
+                else:
+                    successive_outputs = [output]
+                    successive_states = [states]
+            last_output = successive_outputs[-1]
+            new_states = successive_states[-1]
+            outputs = mx.stack(successive_outputs)
+
+        else:  # mask is None
+            for i in range(time_steps):
+                inp = _get_input_tensor(i)
+                output, states = step_function(
+                    inp, tuple(states) + tuple(constants)
+                )
+                if return_all_outputs:
+                    successive_outputs.append(output)
+                    successive_states.append(states)
+                else:
+                    successive_outputs = [output]
+                    successive_states = [states]
+            last_output = successive_outputs[-1]
+            new_states = successive_states[-1]
+            outputs = mx.stack(successive_outputs)
+
+    else:  # Unroll == False
+        if mask is not None:
+
+            def _step(states, current_input):
+                current_input, current_mask = current_input
+                is_masked = mx.all(
+                    mx.logical_not(current_mask), axis=-1, keepdims=True
+                )
+
+                output_t, new_states = step_function(current_input, states)
+
+                if zero_output_for_mask:
+                    masked_outs = mx.where(
+                        is_masked, mx.zeros_like(output_t), output_t
+                    )
+                else:
+                    # Assume the first state is the previous output.
+                    output_tm1 = states[0]
+                    masked_outs = mx.where(is_masked, output_tm1, output_t)
+
+                new_states = [
+                    mx.where(is_masked, s, ns)
+                    for s, ns in zip(states, new_states)
+                ]
+                return (new_states, masked_outs)
+
+            scan_xs = (inputs, mask)
+
+        else:
+
+            def _step(states, current_input):
+                output_t, new_states = step_function(current_input, states)
+                return new_states, output_t
+
+            scan_xs = inputs
+        if stateless_scope.in_stateless_scope():
+            # Reuse the existing parent stateless scope.
+            scope = contextlib.nullcontext()
+        else:
+            scope = stateless_scope.StatelessScope()
+        with scope:
+            new_states, outputs = mlx_scan(
+                f=_step,
+                init=initial_states,
+                xs=scan_xs,
+                reverse=go_backwards,
+                mask=mask,
+            )
+
+        if go_backwards:
+            outputs = reverse_sequence(outputs)
+
+        last_output = outputs[-1]
+
+    if not time_major:
+        outputs = tree.map_structure(swap_batch_timestep, outputs)
+
+    return last_output, outputs, new_states
+
+
+def reverse_sequence(xs):
+    indices = mx.arange(xs.shape[0] - 1, -1, -1)
+    return mx.take(xs, indices, axis=0)
+
+
+def unstack(x, axis=0):
+    return [mx.take(x, i, axis=axis) for i in range(x.shape[axis])]
+
+
+def mlx_scan(f, init, xs, reverse=False, mask=None):
+    states = init
+    outputs = []
+
+    if mask is not None:
+        x, mask = xs
+        if reverse:
+            x = reverse_sequence(x)
+            mask = reverse_sequence(mask)
+
+        for each_x, each_mask in zip(x, mask):
+            states, output = f(states, (each_x, each_mask))
+            outputs.append(output)
+    else:
+        if reverse:
+            xs = reverse_sequence(xs)
+
+        for x in xs:
+            states, output = f(states, x)
+            outputs.append(output)
+
+    outputs = mx.array(outputs)
+
+    if reverse:
+        outputs = reverse_sequence(outputs)
+
+    return states, outputs
 
 
 def cudnn_ok(*args, **kwargs):
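
To exercise the new loop end to end, a hypothetical smoke test (the cumulative-sum step function below is illustrative only, not part of the commit):

import mlx.core as mx
from keras.src.backend.mlx.rnn import rnn

def step(inputs, states):
    # Accumulate inputs over time; the first state doubles as the previous output.
    new_state = states[0] + inputs
    return new_state, [new_state]

inputs = mx.ones((2, 5, 3))  # (batch, time, features)
init = [mx.zeros((2, 3))]
last_output, outputs, new_states = rnn(step, inputs, init)
print(last_output.shape, outputs.shape)  # (2, 3) (2, 5, 3)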
