Skip to content

Commit bf15326

Browse files
authored
add rgb_to_grayscale (#19609)
1 parent 8b7997d commit bf15326

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

keras/src/backend/mlx/image.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,30 @@
88
from keras.src.backend.mlx.core import to_mlx_dtype
99

1010

11+
def rgb_to_grayscale(image, data_format="channels_last"):
12+
image = convert_to_tensor(image)
13+
if data_format == "channels_first":
14+
if len(image.shape) == 4:
15+
image = mx.transpose(image, (0, 2, 3, 1))
16+
elif len(image.shape) == 3:
17+
image = mx.transpose(image, (1, 2, 0))
18+
else:
19+
raise ValueError(
20+
"Invalid input rank: expected rank 3 (single image) "
21+
"or rank 4 (batch of images). Received input with shape: "
22+
f"image.shape={image.shape}"
23+
)
24+
red, green, blue = image[..., 0], image[..., 1], image[..., 2]
25+
grayscale_image = 0.2989 * red + 0.5870 * green + 0.1140 * blue
26+
grayscale_image = mx.expand_dims(grayscale_image, axis=-1)
27+
if data_format == "channels_first":
28+
if len(image.shape) == 4:
29+
grayscale_image = mx.transpose(grayscale_image, (0, 3, 1, 2))
30+
elif len(image.shape) == 3:
31+
grayscale_image = mx.transpose(grayscale_image, (2, 0, 1))
32+
return mx.array(grayscale_image)
33+
34+
1135
def _mirror_index_fixer(index, size):
1236
s = size - 1 # Half-wavelength of triangular wave
1337
# Scaled, integer-valued version of the triangular wave |x - round(x)|

0 commit comments

Comments
 (0)