Skip to content

Commit bc2c6a3

Browse files
authored
impl and clean remaining functions in mlx.nn.py (#20811)
1 parent e2e2288 commit bc2c6a3

File tree

4 files changed

+511
-59
lines changed

4 files changed

+511
-59
lines changed

keras/src/backend/mlx/core.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,57 @@ def stop_gradient(variable):
308308
def unstack(x, num=None, axis=0):
309309
y = x.split(num or x.shape[axis], axis=axis)
310310
return [yi.squeeze(axis) for yi in y]
311+
312+
313+
def reverse_sequence(xs):
314+
indices = mx.arange(xs.shape[0] - 1, -1, -1)
315+
return mx.take(xs, indices, axis=0)
316+
317+
318+
def scan(f, init, xs, reverse=False, mask=None):
319+
states = init
320+
outputs_list = []
321+
322+
if mask is not None:
323+
x, mask = xs
324+
if reverse:
325+
x = reverse_sequence(x)
326+
mask = reverse_sequence(mask)
327+
iterator = zip(x, mask)
328+
else:
329+
if reverse:
330+
if isinstance(xs, tuple):
331+
xs = tuple(reverse_sequence(x) for x in xs)
332+
else:
333+
xs = reverse_sequence(xs)
334+
iterator = zip(*xs) if isinstance(xs, tuple) else xs
335+
336+
for x in iterator:
337+
result = f(states, x)
338+
if isinstance(result, tuple):
339+
states, outputs = result
340+
if outputs is not None:
341+
outputs_list.append(outputs)
342+
else:
343+
states = result
344+
345+
if outputs_list:
346+
if isinstance(outputs_list[0], tuple):
347+
# Multiple outputs case
348+
outputs = tuple(
349+
mx.stack([out[i] for out in outputs_list])
350+
for i in range(len(outputs_list[0]))
351+
)
352+
else:
353+
# Single output case
354+
outputs = mx.stack(outputs_list)
355+
356+
if reverse:
357+
if isinstance(outputs, tuple):
358+
outputs = tuple(reverse_sequence(out) for out in outputs)
359+
else:
360+
outputs = reverse_sequence(outputs)
361+
362+
return states, outputs
363+
364+
return states, None

keras/src/backend/mlx/linalg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ def norm(x, ord=None, axis=None, keepdims=False):
5252
if "int" in dtype or dtype == "bool":
5353
dtype = dtypes.result_type(x.dtype, "float32")
5454
x = convert_to_tensor(x, dtype=dtype)
55-
return mx.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
55+
# TODO: swap to mlx.linalg.norm when it support singular value norms
56+
x = jnp.array(x)
57+
output = jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
58+
return mx.array(output)
5659

5760

5861
def inv(a):

0 commit comments

Comments
 (0)