-
Notifications
You must be signed in to change notification settings - Fork 8
GITC-7208: Fixing Colormap Rasterization issues #50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
867b2f8
0491fd8
bc0a938
bab4146
14b050e
424c417
e22ed3c
9e2c4c2
7ad99be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
2.3.0 | ||
2.4.0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,13 +11,11 @@ | |
from affine import dumpsw | ||
from harmony_service_lib.message import Message as HarmonyMessage | ||
from harmony_service_lib.message import Source as HarmonySource | ||
from matplotlib.cm import ScalarMappable | ||
from matplotlib.colors import Normalize | ||
from numpy import ndarray, uint8 | ||
from osgeo_utils.auxiliary.color_palette import ColorPalette | ||
from PIL import Image | ||
from rasterio.io import DatasetReader | ||
from rasterio.plot import reshape_as_image, reshape_as_raster | ||
from rasterio.warp import Resampling, reproject | ||
from rioxarray import open_rasterio | ||
from xarray import DataArray | ||
|
@@ -28,10 +26,12 @@ | |
NODATA_RGBA, | ||
OPAQUE, | ||
TRANSPARENT, | ||
ColorMap, | ||
all_black_color_map, | ||
colormap_from_colors, | ||
get_color_palette, | ||
greyscale_colormap, | ||
palette_from_remote_colortable, | ||
remove_alpha, | ||
) | ||
from hybig.exceptions import HyBIGError | ||
from hybig.sizes import ( | ||
|
@@ -171,18 +171,19 @@ def create_browse_imagery( | |
color_palette = get_color_palette( | ||
in_dataset, source, item_color_palette | ||
) | ||
raster = convert_singleband_to_raster(rio_in_array, color_palette) | ||
raster, color_map = convert_singleband_to_raster( | ||
rio_in_array, color_palette | ||
) | ||
elif rio_in_array.rio.count in (3, 4): | ||
raster = convert_mulitband_to_raster(rio_in_array) | ||
color_map = None | ||
if output_driver == 'JPEG': | ||
raster = raster[0:3, :, :] | ||
else: | ||
raise HyBIGError( | ||
f'incorrect number of bands for image: {rio_in_array.rio.count}' | ||
) | ||
|
||
raster, color_map = standardize_raster_for_writing( | ||
raster, output_driver, rio_in_array.rio.count | ||
) | ||
|
||
grid_parameters = get_target_grid_parameters(message, rio_in_array) | ||
grid_parameter_list, tile_locators = create_tiled_output_parameters( | ||
grid_parameters | ||
|
@@ -283,69 +284,65 @@ def original_dtype(data_array: DataArray) -> str | None: | |
def convert_singleband_to_raster( | ||
data_array: DataArray, | ||
color_palette: ColorPalette | None = None, | ||
) -> ndarray: | ||
"""Convert input dataset to a 4 band raster image. | ||
) -> tuple[ndarray, ColorMap]: | ||
"""Convert input dataset to a 1-band palettized image with colormap. | ||
|
||
Use a palette if provided otherwise return a greyscale image. | ||
Uses a palette if provided otherwise returns a greyscale image. | ||
""" | ||
if color_palette is None: | ||
return convert_gray_1band_to_raster(data_array) | ||
return convert_paletted_1band_to_raster(data_array, color_palette) | ||
return scale_grey_1band(data_array) | ||
return scale_paletted_1band(data_array, color_palette) | ||
|
||
|
||
def convert_gray_1band_to_raster(data_array: DataArray) -> ndarray: | ||
"""Convert a 1-band raster without a color association.""" | ||
def scale_grey_1band(data_array: DataArray) -> tuple[ndarray, ColorMap]: | ||
"""Normalize input array and return scaled data with greyscale ColorMap.""" | ||
band = data_array[0, :, :] | ||
cmap = matplotlib.colormaps['Greys_r'] | ||
cmap.set_bad(NODATA_RGBA) | ||
norm = Normalize(vmin=np.nanmin(band), vmax=np.nanmax(band)) | ||
scalar_map = ScalarMappable(cmap=cmap, norm=norm) | ||
|
||
rgba_image = np.zeros((*band.shape, 4), dtype='uint8') | ||
for row_no in range(band.shape[0]): | ||
rgba_image_slice = scalar_map.to_rgba(band[row_no, :], bytes=True) | ||
rgba_image[row_no, :, :] = rgba_image_slice | ||
# Scale input data from 0 to 254; index 255 (NODATA_IDX) is reserved for nodata | ||
normalized_data = norm(band) * 254.0 | ||
|
||
return reshape_as_raster(rgba_image) | ||
# Set any missing (NaN) values to the nodata index | ||
normalized_data[np.isnan(normalized_data)] = NODATA_IDX | ||
|
||
grey_colormap = greyscale_colormap() | ||
raster_data = np.expand_dims(np.round(normalized_data).data, 0) | ||
return np.array(raster_data, dtype='uint8'), grey_colormap | ||
|
||
def convert_paletted_1band_to_raster( | ||
|
||
def scale_paletted_1band( | ||
data_array: DataArray, palette: ColorPalette | ||
) -> ndarray: | ||
"""Convert a 1 band image with palette into a rgba raster image.""" | ||
) -> tuple[ndarray, ColorMap]: | ||
"""Scale a 1-band image with palette into modified image and associated color_map. | ||
|
||
Use the palette's levels and values to transform the input data_array into | ||
the correct levels indexed from 0-255, and return the scaled array alongside | ||
a colormap corresponding to the new levels. | ||
""" | ||
band = data_array[0, :, :] | ||
levels = list(palette.pal.keys()) | ||
colors = [ | ||
palette.color_to_color_entry(value, with_alpha=True) | ||
for value in palette.pal.values() | ||
] | ||
scaled_colors = [ | ||
(r / 255.0, g / 255.0, b / 255.0, a / 255.0) for r, g, b, a in colors | ||
] | ||
|
||
cmap, norm = matplotlib.colors.from_levels_and_colors( | ||
levels, scaled_colors, extend='max' | ||
) | ||
norm = matplotlib.colors.BoundaryNorm(levels, len(levels) - 1) | ||
|
||
# handle palette no data value | ||
nodata_color = (0, 0, 0, 0) | ||
if palette.ndv is not None: | ||
nodata_colors = palette.color_to_color_entry(palette.ndv, with_alpha=True) | ||
cmap.set_bad( | ||
( | ||
nodata_colors[0] / 255.0, | ||
nodata_colors[1] / 255.0, | ||
nodata_colors[2] / 255.0, | ||
nodata_colors[3] / 255.0, | ||
) | ||
) | ||
nodata_color = palette.color_to_color_entry(palette.ndv, with_alpha=True) | ||
|
||
scalar_map = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) | ||
rgba_image = np.zeros((*band.shape, 4), dtype='uint8') | ||
for row_no in range(band.shape[0]): | ||
rgba_image[row_no, :, :] = scalar_map.to_rgba( | ||
np.ma.masked_invalid(band[row_no, :]), bytes=True | ||
) | ||
return reshape_as_raster(rgba_image) | ||
colors = [*colors, nodata_color] | ||
|
||
scaled_band = norm(band) | ||
|
||
# Set underflow and nan values to nodata | ||
scaled_band[scaled_band == -1] = len(colors) - 1 | ||
scaled_band[np.isnan(band)] = len(colors) - 1 | ||
|
||
color_map = colormap_from_colors(colors) | ||
raster_data = np.expand_dims(scaled_band.data, 0) | ||
return np.array(raster_data, dtype='uint8'), color_map | ||
|
||
|
||
def image_driver(mime: str) -> str: | ||
|
@@ -355,81 +352,6 @@ def image_driver(mime: str) -> str: | |
return 'PNG' | ||
|
||
|
||
def standardize_raster_for_writing( | ||
raster: ndarray, | ||
driver: str, | ||
band_count: int, | ||
) -> tuple[ndarray, dict | None]: | ||
"""Standardize raster data for writing to browse image. | ||
|
||
Args: | ||
raster: Input raster data array | ||
driver: Output image format ('JPEG' or 'PNG') | ||
band_count: Number of bands in original input data | ||
|
||
The function handles two special cases: | ||
- JPEG output with 4-band data -> Drop alpha channel and return 3-band RGB | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are we dropping handling of this JPEG special case? I think we need to support it for products like PACE. |
||
- PNG output with single-band data -> Convert to paletted format | ||
|
||
Returns: | ||
tuple: (prepared_raster, color_map) where: | ||
- prepared_raster is the processed ndarray | ||
- color_map is either None or a dict mapping palette indices to RGBA values | ||
|
||
|
||
""" | ||
if driver == 'JPEG' and raster.shape[0] == 4: | ||
return raster[0:3, :, :], None | ||
|
||
if driver == 'PNG' and band_count == 1: | ||
# Only palettize single band input data that has been converted to an | ||
# RGBA raster. | ||
return palettize_raster(raster) | ||
|
||
return raster, None | ||
|
||
|
||
def palettize_raster(raster: ndarray) -> tuple[ndarray, dict]: | ||
"""Convert an RGB or RGBA image into a 1band image and palette. | ||
|
||
Converts a 3 or 4 band np raster into a PIL image. | ||
Quantizes the image into a 1band raster with palette | ||
|
||
Transparency is handled by first removing the Alpha layer and creating | ||
quantized raster from just the RGB layers. Next the Alpha layer values are | ||
treated as either transparent or opaque and any transparent values are | ||
written to the final raster as 254 and add the mapped RGBA value to the | ||
color palette. | ||
""" | ||
# reserves index 255 for transparent and off grid fill values | ||
# 0 to 254 | ||
max_colors = 255 | ||
rgb_raster, alpha = remove_alpha(raster) | ||
|
||
multiband_image = Image.fromarray(reshape_as_image(rgb_raster)) | ||
quantized_image = multiband_image.quantize(colors=max_colors) | ||
|
||
color_map = get_color_map_from_image(quantized_image) | ||
|
||
quantized_array, color_map = add_alpha(alpha, np.array(quantized_image), color_map) | ||
|
||
one_band_raster = np.expand_dims(quantized_array, 0) | ||
return one_band_raster, color_map | ||
|
||
|
||
def add_alpha( | ||
alpha: ndarray | None, quantized_array: ndarray, color_map: dict | ||
) -> tuple[ndarray, dict]: | ||
"""If the input data had alpha values, manually set the quantized_image | ||
index to the transparent index in those places. | ||
""" | ||
if alpha is not None and np.any(alpha != OPAQUE): | ||
# Set any alpha to the transparent index value | ||
quantized_array = np.where(alpha != OPAQUE, NODATA_IDX, quantized_array) | ||
color_map[NODATA_IDX] = NODATA_RGBA | ||
return quantized_array, color_map | ||
|
||
|
||
def get_color_map_from_image(image: Image) -> dict: | ||
"""Get a writable color map | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
import numpy as np | ||
import requests | ||
from harmony_service_lib.message import Source as HarmonySource | ||
from numpy import uint8 | ||
from osgeo_utils.auxiliary.color_palette import ColorPalette | ||
from pystac import Item | ||
from rasterio.io import DatasetReader | ||
|
@@ -17,16 +18,18 @@ | |
HyBIGNoColorInformation, | ||
) | ||
|
||
ColorMap = dict[uint8, tuple[uint8, uint8, uint8, uint8]] | ||
|
||
# Constants for output PNG images | ||
# Applied to transparent pixels where alpha < 255 | ||
TRANSPARENT = np.uint8(0) | ||
OPAQUE = np.uint8(255) | ||
TRANSPARENT = uint8(0) | ||
OPAQUE = uint8(255) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is unrelated to this PR, but the typing below for remove_alpha is incorrect. def remove_alpha(raster: np.ndarray) -> tuple[np.ndarray, np.ndarray, None]: should be: def remove_alpha(raster: np.ndarray) -> tuple[np.ndarray, np.ndarray | None]: |
||
# Applied to off grid areas during reprojection | ||
NODATA_RGBA = (0, 0, 0, 0) | ||
NODATA_IDX = 255 | ||
|
||
|
||
def remove_alpha(raster: np.ndarray) -> tuple[np.ndarray, np.ndarray, None]: | ||
def remove_alpha(raster: np.ndarray) -> tuple[np.ndarray, np.ndarray | None]: | ||
"""Remove alpha layer when it exists.""" | ||
if raster.shape[0] == 4: | ||
return raster[0:3, :, :], raster[3, :, :] | ||
|
@@ -87,7 +90,16 @@ def get_color_palette( | |
return get_remote_palette_from_source(source) | ||
except HyBIGNoColorInformation: | ||
try: | ||
return convert_colormap_to_palette(dataset.colormap(1)) | ||
ds_cmap = dataset.colormap(1) | ||
# very defensive since this function is not documented in rasterio | ||
ndv_tuple: tuple[float, ...] = dataset.get_nodatavals() | ||
if ndv_tuple is not None and len(ndv_tuple) > 0: | ||
# this service only supports one ndv, so just use the first one | ||
# (usually the only one) | ||
ds_cmap['nv'] = ds_cmap[ndv_tuple[0]] | ||
# then remove the value associated with the ndv key | ||
ds_cmap.pop(ndv_tuple[0]) | ||
return convert_colormap_to_palette(ds_cmap) | ||
Comment on lines +95 to +102
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It looks like you know more about where color tables and nodata values might be, but if there were a test for it, it would be easier for me to understand. |
||
except ValueError: | ||
return None | ||
|
||
|
@@ -120,11 +132,28 @@ def get_remote_palette_from_source(source: HarmonySource) -> dict: | |
raise HyBIGNoColorInformation('No color in source') from exc | ||
|
||
|
||
def all_black_color_map(): | ||
def all_black_color_map() -> ColorMap: | ||
"""Return a full length rgba color map with all black values.""" | ||
return {idx: (0, 0, 0, 255) for idx in range(256)} | ||
|
||
|
||
def colormap_from_colors( | ||
colors: list[tuple[uint8, uint8, uint8, uint8]], | ||
) -> ColorMap: | ||
color_map = {} | ||
for idx, rgba in enumerate(colors): | ||
color_map[idx] = rgba | ||
return color_map | ||
|
||
|
||
def greyscale_colormap() -> ColorMap: | ||
color_map = {} | ||
for idx in range(255): | ||
color_map[idx] = (idx, idx, idx, 255) | ||
color_map[NODATA_IDX] = NODATA_RGBA | ||
return color_map | ||
|
||
|
||
def convert_colormap_to_palette(colormap: dict) -> ColorPalette: | ||
"""Convert a GeoTIFF palette to GDAL ColorPalette. | ||
|
||
|
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably should add a comment to explain why this uses 254 and not 255.