|
8 | 8 | from keras.src.backend.mlx.core import to_mlx_dtype
|
9 | 9 |
|
10 | 10 |
|
| 11 | +def rgb_to_grayscale(image, data_format="channels_last"): |
| 12 | + image = convert_to_tensor(image) |
| 13 | + if data_format == "channels_first": |
| 14 | + if len(image.shape) == 4: |
| 15 | + image = mx.transpose(image, (0, 2, 3, 1)) |
| 16 | + elif len(image.shape) == 3: |
| 17 | + image = mx.transpose(image, (1, 2, 0)) |
| 18 | + else: |
| 19 | + raise ValueError( |
| 20 | + "Invalid input rank: expected rank 3 (single image) " |
| 21 | + "or rank 4 (batch of images). Received input with shape: " |
| 22 | + f"image.shape={image.shape}" |
| 23 | + ) |
| 24 | + red, green, blue = image[..., 0], image[..., 1], image[..., 2] |
| 25 | + grayscale_image = 0.2989 * red + 0.5870 * green + 0.1140 * blue |
| 26 | + grayscale_image = mx.expand_dims(grayscale_image, axis=-1) |
| 27 | + if data_format == "channels_first": |
| 28 | + if len(image.shape) == 4: |
| 29 | + grayscale_image = mx.transpose(grayscale_image, (0, 3, 1, 2)) |
| 30 | + elif len(image.shape) == 3: |
| 31 | + grayscale_image = mx.transpose(grayscale_image, (2, 0, 1)) |
| 32 | + return mx.array(grayscale_image) |
| 33 | + |
| 34 | + |
11 | 35 | def _mirror_index_fixer(index, size):
|
12 | 36 | s = size - 1 # Half-wavelength of triangular wave
|
13 | 37 | # Scaled, integer-valued version of the triangular wave |x - round(x)|
|
|
0 commit comments