Skip to content

Commit 5377ea2

Browse files
authored
Merge pull request Kohulan#106 from Kohulan/development
feat: new input format as numpy array representing the image
2 parents b09ac80 + 7034d6f commit 5377ea2

File tree

3 files changed

+200
-106
lines changed

3 files changed

+200
-106
lines changed

DECIMER/config.py

Lines changed: 61 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from PIL import Image
1212
from PIL import ImageEnhance
1313
from pillow_heif import register_heif_opener
14+
from typing import Union
1415

1516
import DECIMER.Efficient_Net_encoder as Efficient_Net_encoder
1617
import DECIMER.Transformer_decoder as Transformer_decoder
@@ -95,26 +96,68 @@ def HEIF_to_pillow(image_path: str):
9596
return heif_file
9697

9798

98-
def remove_transparent(image_path: str):
99+
def remove_transparent(image: Union[str, np.ndarray]) -> Image.Image:
99100
"""
100-
Removes the transparent layer from a PNG image with an alpha channel
101-
Args: image_path (str): path of input image
102-
Returns: PIL.Image
101+
Removes the transparent layer from a PNG image with an alpha channel.
102+
103+
Args:
104+
image (Union[str, np.ndarray]): Path of the input image or a numpy array representing the image.
105+
106+
Returns:
107+
PIL.Image.Image: The image with transparency removed.
103108
"""
104-
try:
105-
png = Image.open(image_path).convert("RGBA")
106-
except Exception as e:
107-
if type(e).__name__ == "UnidentifiedImageError":
108-
png = HEIF_to_pillow(image_path)
109-
else:
110-
print(e)
111-
raise Exception
109+
def process_image(png: Image.Image) -> Image.Image:
110+
"""
111+
Helper function to remove transparency from a single image.
112+
113+
Args:
114+
png (PIL.Image.Image): The input PIL image with transparency.
115+
116+
Returns:
117+
PIL.Image.Image: The image with transparency removed.
118+
"""
119+
background = Image.new("RGBA", png.size, (255, 255, 255))
120+
alpha_composite = Image.alpha_composite(background, png)
121+
return alpha_composite
112122

113-
background = Image.new("RGBA", png.size, (255, 255, 255))
123+
def handle_image_path(image_path: str) -> Image.Image:
124+
"""
125+
Helper function to handle image paths.
126+
127+
Args:
128+
image_path (str): The path to the input image.
129+
130+
Returns:
131+
PIL.Image.Image: The image with transparency removed.
132+
"""
133+
try:
134+
png = Image.open(image_path).convert("RGBA")
135+
except Exception as e:
136+
if type(e).__name__ == "UnidentifiedImageError":
137+
png = HEIF_to_pillow(image_path)
138+
else:
139+
print(e)
140+
raise Exception
141+
return process_image(png)
142+
143+
def handle_numpy_array(array: np.ndarray) -> Image.Image:
144+
"""
145+
Helper function to handle a numpy array.
146+
147+
Args:
148+
array (np.ndarray): The numpy array representing the image.
149+
150+
Returns:
151+
PIL.Image.Image: The image with transparency removed.
152+
"""
153+
png = Image.fromarray(array).convert("RGBA")
154+
return process_image(png)
114155

115-
alpha_composite = Image.alpha_composite(background, png)
156+
# Check if input is a numpy array
157+
if isinstance(image, np.ndarray):
158+
return handle_numpy_array(array=image)
116159

117-
return alpha_composite
160+
return handle_image_path(image_path=image)
118161

119162

120163
def get_bnw_image(image):
@@ -185,12 +228,12 @@ def increase_brightness(image):
185228
return image
186229

187230

188-
def decode_image(image_path: str):
231+
def decode_image(image_path: Union[str, np.ndarray]):
189232
"""Loads an image and preprocesses the input image in several steps to get
190233
the image ready for DECIMER input.
191234
192235
Args:
193-
image_path (str): path of input image
236+
image_path (Union[str, np.ndarray]): path of input image or numpy array representing the image.
194237
195238
Returns:
196239
Processed image
@@ -237,7 +280,7 @@ def initialize_encoder_config(
237280
backbone_fn (method): Calls Efficient-Net V2 as backbone for encoder
238281
image_shape (int): Shape of the input image
239282
do_permute (bool, optional): . Defaults to False.
240-
pretrained_weights (keras weights, optional): Use pretrainined efficient net weights or not. Defaults to None.
283+
pretrained_weights (keras weights, optional): Use pretrained efficient net weights or not. Defaults to None.
241284
"""
242285
self.encoder_config = dict(
243286
image_embedding_dim=image_embedding_dim,

DECIMER/decimer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import List
66
from typing import Tuple
77

8+
import numpy as np
89
import pystow
910
import tensorflow as tf
1011

@@ -122,19 +123,19 @@ def detokenize_output_add_confidence(
122123

123124

124125
def predict_SMILES(
125-
image_path: str, confidence: bool = False, hand_drawn: bool = False
126+
image_input: [str, np.ndarray], confidence: bool = False, hand_drawn: bool = False
126127
) -> str:
127128
"""Predicts SMILES representation of a molecule depicted in the given image.
128129
129130
Args:
130-
image_path (str): Path of chemical structure depiction image
131-
confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction
132-
hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn
131+
image_input (str or np.ndarray): Path of chemical structure depiction image or a numpy array representing the image.
132+
confidence (bool): Flag to indicate whether to return confidence values along with SMILES prediction.
133+
hand_drawn (bool): Flag to indicate whether the molecule in the image is hand-drawn.
133134
134135
Returns:
135-
str: SMILES representation of the molecule in the input image, optionally with confidence values
136+
str: SMILES representation of the molecule in the input image, optionally with confidence values.
136137
"""
137-
chemical_structure = config.decode_image(image_path)
138+
chemical_structure = config.decode_image(image_input)
138139

139140
model = DECIMER_Hand_drawn if hand_drawn else DECIMER_V2
140141
predicted_tokens, confidence_values = model(tf.constant(chemical_structure))

0 commit comments

Comments
 (0)