|
11 | 11 | from PIL import Image |
12 | 12 | from PIL import ImageEnhance |
13 | 13 | from pillow_heif import register_heif_opener |
| 14 | +from typing import Union |
14 | 15 |
|
15 | 16 | import DECIMER.Efficient_Net_encoder as Efficient_Net_encoder |
16 | 17 | import DECIMER.Transformer_decoder as Transformer_decoder |
@@ -95,26 +96,68 @@ def HEIF_to_pillow(image_path: str): |
95 | 96 | return heif_file |
96 | 97 |
|
97 | 98 |
|
98 | | -def remove_transparent(image_path: str): |
| 99 | +def remove_transparent(image: Union[str, np.ndarray]) -> Image.Image: |
99 | 100 | """ |
100 | | - Removes the transparent layer from a PNG image with an alpha channel |
101 | | - Args: image_path (str): path of input image |
102 | | - Returns: PIL.Image |
| 101 | + Removes the transparent layer from a PNG image with an alpha channel. |
| 102 | +
|
| 103 | + Args: |
| 104 | + image (Union[str, np.ndarray]): Path of the input image or a numpy array representing the image. |
| 105 | +
|
| 106 | + Returns: |
| 107 | + PIL.Image.Image: The image with transparency removed. |
103 | 108 | """ |
104 | | - try: |
105 | | - png = Image.open(image_path).convert("RGBA") |
106 | | - except Exception as e: |
107 | | - if type(e).__name__ == "UnidentifiedImageError": |
108 | | - png = HEIF_to_pillow(image_path) |
109 | | - else: |
110 | | - print(e) |
111 | | - raise Exception |
| 109 | + def process_image(png: Image.Image) -> Image.Image: |
| 110 | + """ |
| 111 | + Helper function to remove transparency from a single image. |
| 112 | +
|
| 113 | + Args: |
| 114 | + png (PIL.Image.Image): The input PIL image with transparency. |
| 115 | +
|
| 116 | + Returns: |
| 117 | + PIL.Image.Image: The image with transparency removed. |
| 118 | + """ |
| 119 | + background = Image.new("RGBA", png.size, (255, 255, 255)) |
| 120 | + alpha_composite = Image.alpha_composite(background, png) |
| 121 | + return alpha_composite |
112 | 122 |
|
113 | | - background = Image.new("RGBA", png.size, (255, 255, 255)) |
| 123 | + def handle_image_path(image_path: str) -> Image.Image: |
| 124 | + """ |
| 125 | + Helper function to handle image paths. |
| 126 | +
|
| 127 | + Args: |
| 128 | + image_path (str): The path to the input image. |
| 129 | +
|
| 130 | + Returns: |
| 131 | + PIL.Image.Image: The image with transparency removed. |
| 132 | + """ |
| 133 | + try: |
| 134 | + png = Image.open(image_path).convert("RGBA") |
| 135 | + except Exception as e: |
| 136 | + if type(e).__name__ == "UnidentifiedImageError": |
| 137 | + png = HEIF_to_pillow(image_path) |
| 138 | + else: |
| 139 | + print(e) |
| 140 | + raise Exception |
| 141 | + return process_image(png) |
| 142 | + |
| 143 | + def handle_numpy_array(array: np.ndarray) -> Image.Image: |
| 144 | + """ |
| 145 | + Helper function to handle a numpy array. |
| 146 | +
|
| 147 | + Args: |
| 148 | + array (np.ndarray): The numpy array representing the image. |
| 149 | +
|
| 150 | + Returns: |
| 151 | + PIL.Image.Image: The image with transparency removed. |
| 152 | + """ |
| 153 | + png = Image.fromarray(array).convert("RGBA") |
| 154 | + return process_image(png) |
114 | 155 |
|
115 | | - alpha_composite = Image.alpha_composite(background, png) |
| 156 | + # Check if input is a numpy array |
| 157 | + if isinstance(image, np.ndarray): |
| 158 | + return handle_numpy_array(array=image) |
116 | 159 |
|
117 | | - return alpha_composite |
| 160 | + return handle_image_path(image_path=image) |
118 | 161 |
|
119 | 162 |
|
120 | 163 | def get_bnw_image(image): |
@@ -185,12 +228,12 @@ def increase_brightness(image): |
185 | 228 | return image |
186 | 229 |
|
187 | 230 |
|
188 | | -def decode_image(image_path: str): |
| 231 | +def decode_image(image_path: Union[str, np.ndarray]): |
189 | 232 | """Loads an image and preprocesses the input image in several steps to get |
190 | 233 | the image ready for DECIMER input. |
191 | 234 |
|
192 | 235 | Args: |
193 | | - image_path (str): path of input image |
| 236 | + image_path (Union[str, np.ndarray]): path of input image or numpy array representing the image. |
194 | 237 |
|
195 | 238 | Returns: |
196 | 239 | Processed image |
@@ -237,7 +280,7 @@ def initialize_encoder_config( |
237 | 280 | backbone_fn (method): Calls Efficient-Net V2 as backbone for encoder |
238 | 281 | image_shape (int): Shape of the input image |
239 | 282 | do_permute (bool, optional): Defaults to False. |
240 | | - pretrained_weights (keras weights, optional): Use pretrainined efficient net weights or not. Defaults to None. |
| 283 | + pretrained_weights (keras weights, optional): Use pretrained efficient net weights or not. Defaults to None. |
241 | 284 | """ |
242 | 285 | self.encoder_config = dict( |
243 | 286 | image_embedding_dim=image_embedding_dim, |
|
0 commit comments