
Commit ba76733

Merge pull request Kohulan#69 from Kohulan/development
add token-level confidence scores
2 parents: d572ded + 40cffed

7 files changed: +208 -32 lines changed


DECIMER/Predictor_usingCheckpoints.py

Lines changed: 83 additions & 10 deletions
@@ -1,10 +1,9 @@
 import os
 import sys
 import tensorflow as tf
-
+from typing import List, Tuple
 import pickle
 import pystow
-from selfies import decoder
 import Transformer_decoder
 
 if int(tf.__version__.split(".")[1]) <= 10:
@@ -116,6 +115,9 @@ def main():
     else:
         SMILES = predict_SMILES(sys.argv[1])
         print(SMILES)
+        SMILES_with_confidence = predict_SMILES_with_confidence(sys.argv[1])
+        for tup in SMILES_with_confidence:
+            print(tup)
 
 
 class DECIMER_Predictor(tf.Module):
@@ -126,6 +128,16 @@ def __init__(self, encoder, tokenizer, transformer, max_length):
         self.max_length = max_length
 
     def __call__(self, Decoded_image):
+        """
+        Run the DECIMER predictor model when called.
+        Usage of predict_SMILES or predict_SMILES_with_confidence is recommended instead
+
+        Args:
+            Decoded_image (_type_): output of config.decode_image
+
+        Returns:
+            Tuple[tf.Tensor, tf.Tensor]: predicted tokens, confidence values
+        """
         assert isinstance(Decoded_image, tf.Tensor)
 
         _image_batch = tf.expand_dims(Decoded_image, 0)
@@ -140,6 +152,7 @@ def __call__(self, Decoded_image):
 
         output_array = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
         output_array = output_array.write(0, start_token)
+        confidence_array = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
 
         for t in tf.range(max_length):
             output = tf.transpose(output_array.stack())
@@ -154,31 +167,92 @@ def __call__(self, Decoded_image):
             predictions = prediction_batch[:, -1:, :]  # (batch_size, 1, vocab_size)
 
             predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
-
+            confidence = predictions[-1][-1][int(predicted_id)]
             output_array = output_array.write(t + 1, predicted_id[0])
-
+            confidence_array = confidence_array.write(t + 1, confidence)
             if predicted_id == end_token:
                 break
         output = tf.transpose(output_array.stack())
-        return output
 
+        return output, confidence_array.stack()
+
+
+def detokenize_output(predicted_array: tf.Tensor) -> str:
+    """
+    This function takes the predicted array of tokens and returns the predicted SMILES
+    string.
+
+    Args:
+        predicted_array (tf.Tensor): Transformer Decoder output array (predicted tokens)
 
-def detokenize_output(predicted_array):
+    Returns:
+        str: SMILES string
+    """
     outputs = [tokenizer.index_word[i] for i in predicted_array[0].numpy()]
     prediction = (
         "".join([str(elem) for elem in outputs])
         .replace("<start>", "")
         .replace("<end>", "")
     )
-
     return prediction
 
 
+def detokenize_output_add_confidence(
+    predicted_array: tf.Tensor,
+    confidence_array: tf.Tensor,
+) -> List[Tuple[str, float]]:
+    """
+    This function takes the predicted array of tokens as well as the confidence values
+    returned by the Transformer Decoder and returns a list of tuples
+    that contain each token of the predicted SMILES string and the confidence
+    value.
+
+    Args:
+        predicted_array (tf.Tensor): Transformer Decoder output array (predicted tokens)
+
+    Returns:
+        str: SMILES string
+    """
+    prediction_with_confidence = [
+        (
+            tokenizer.index_word[predicted_array[0].numpy()[i]],
+            confidence_array[i].numpy(),
+        )
+        for i in range(len(confidence_array))
+    ]
+    decoded_prediction_with_confidence = list(
+        [(utils.decoder(tok), conf) for tok, conf in prediction_with_confidence[1:-1]]
+    )
+    decoded_prediction_with_confidence.append(prediction_with_confidence[-1])
+    return decoded_prediction_with_confidence
+
+
 # Initiate the DECIMER class
 DECIMER = DECIMER_Predictor(encoder, tokenizer, transformer, MAX_LEN)
 
 
-def predict_SMILES(image_path: str):
+def predict_SMILES_with_confidence(image_path: str) -> List[Tuple[str, float]]:
+    """
+    This function takes an image path (str) and returns a list of tuples
+    that contain each token of the predicted SMILES string and the confidence
+    level from the last layer of the Transformer decoder.
+
+    Args:
+        image_path (str): Path of chemical structure depiction image
+
+    Returns:
+        (List[Tuple[str, float]]): Tuples that contain the tokens and the confidence
+        values of the predicted SMILES
+    """
+    decodedImage = config.decode_image(image_path)
+    predicted_tokens, confidence_values = DECIMER(tf.constant(decodedImage))
+    predicted_SMILES_with_confidence = detokenize_output_add_confidence(
+        predicted_tokens, confidence_values
+    )
+    return predicted_SMILES_with_confidence
+
+
+def predict_SMILES(image_path: str) -> str:
     """
     This function takes an image path (str) and returns the SMILES
     representation of the depicted molecule (str).
@@ -190,9 +264,8 @@ def predict_SMILES(image_path: str):
         (str): SMILES representation of the molecule in the input image
     """
     decodedImage = config.decode_image(image_path)
-    predicted_tokens = DECIMER(tf.constant(decodedImage))
+    predicted_tokens, _ = DECIMER(tf.constant(decodedImage))
     predicted_SMILES = utils.decoder(detokenize_output(predicted_tokens))
-
     return predicted_SMILES
 
 
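The script-level predict_SMILES_with_confidence added above returns a list of (token, confidence) tuples rather than a single string. A minimal sketch of how that output might be consumed, assuming the function is importable from this module; "structure.png" is a hypothetical example path:

    # Hedged usage sketch; the image path is a hypothetical example.
    token_confidences = predict_SMILES_with_confidence("structure.png")
    for token, confidence in token_confidences:
        print(f"{token}\t{confidence:.4f}")
    # Flag the token the decoder was least certain about.
    least_confident = min(token_confidences, key=lambda pair: pair[1])
    print("Least confident token:", least_confident)
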
DECIMER/Repack_model.py

Lines changed: 46 additions & 9 deletions
@@ -1,11 +1,10 @@
 import os
 import tensorflow as tf
-
+from typing import List, Tuple
 import pickle
 import DECIMER.Efficient_Net_encoder as Efficient_Net_encoder
 import DECIMER.Transformer_decoder as Transformer_decoder
 import DECIMER.config as config
-import DECIMER.utils as utils
 
 print(tf.__version__)
 
@@ -84,17 +83,56 @@
 start_epoch = int(ckpt_manager.latest_checkpoint.split("-")[-1])
 
 
-def detokenize_output(predicted_array):
+def detokenize_output(predicted_array: tf.Tensor) -> str:
+    """
+    This function takes the predicted array of tokens and returns the predicted SMILES
+    string.
+
+    Args:
+        predicted_array (tf.Tensor): Transformer Decoder output array (predicted tokens)
+
+    Returns:
+        str: SMILES string
+    """
     outputs = [tokenizer.index_word[i] for i in predicted_array[0].numpy()]
     prediction = (
         "".join([str(elem) for elem in outputs])
         .replace("<start>", "")
         .replace("<end>", "")
     )
-
     return prediction
 
 
+def detokenize_output_add_confidence(
+    predicted_array: tf.Tensor,
+    confidence_array: tf.Tensor,
+) -> List[Tuple[str, float]]:
+    """
+    This function takes the predicted array of tokens as well as the confidence values
+    returned by the Transformer Decoder and returns a list of tuples
+    that contain each token of the predicted SMILES string and the confidence
+    value.
+
+    Args:
+        predicted_array (tf.Tensor): Transformer Decoder output array (predicted tokens)
+
+    Returns:
+        str: SMILES string
+    """
+    prediction_with_confidence = [
+        (
+            tokenizer.index_word[predicted_array[0].numpy()[i]],
+            confidence_array[i].numpy(),
+        )
+        for i in range(len(confidence_array))
+    ]
+    decoded_prediction_with_confidence = list(
+        [(utils.decoder(tok), conf) for tok, conf in prediction_with_confidence[1:-1]]
+    )
+    decoded_prediction_with_confidence.append(prediction_with_confidence[-1])
+    return decoded_prediction_with_confidence
+
+
 class DECIMER_Predictor(tf.Module):
     """This is a class which takes care of inference. It loads the saved checkpoint and the necessary
     tokenizers. The inference begins with the start token (<start>) and ends when the end token(<end>)
@@ -128,8 +166,6 @@ def __call__(self, Decoded_image):
             output (tf.Tensor[tf.int64]): predicted output as an array.
         """
         assert isinstance(Decoded_image, tf.Tensor)
-        if len(Decoded_image.shape) == 0:
-            sentence = Decoded_image[tf.newaxis]
 
         _image_batch = tf.expand_dims(Decoded_image, 0)
         _image_embedding = encoder(_image_batch, training=False)
@@ -143,6 +179,7 @@ def __call__(self, Decoded_image):
 
         output_array = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
        output_array = output_array.write(0, start_token)
+        confidence_array = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
 
         for t in tf.range(max_length):
             output = tf.transpose(output_array.stack())
@@ -157,14 +194,14 @@ def __call__(self, Decoded_image):
             predictions = prediction_batch[:, -1:, :]  # (batch_size, 1, vocab_size)
 
             predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
-
+            confidence = predictions[0, 0, int(predicted_id[0, 0])]
             output_array = output_array.write(t + 1, predicted_id[0])
-
+            confidence_array = confidence_array.write(t + 1, confidence)
             if predicted_id == end_token:
                 break
         output = tf.transpose(output_array.stack())
 
-        return output
+        return output, confidence_array.stack()
 
 
 DECIMER = DECIMER_Predictor(encoder, tokenizer, transformer, MAX_LEN)

DECIMER/Transformer_decoder.py

Lines changed: 3 additions & 1 deletion
@@ -185,7 +185,9 @@ def __init__(
             for _ in range(num_layers)
         ]
         self.dropout = tf.keras.layers.Dropout(rate)
-        self.final_layer = tf.keras.layers.Dense(target_vocab_size)
+        self.final_layer = tf.keras.layers.Dense(
+            target_vocab_size, activation="softmax"
+        )
 
     def call(self, x, enc_output, training, look_ahead_mask=None, padding_mask=None):
         seq_len = tf.shape(x)[1]

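With the softmax activation added to the final Dense layer, the decoder now emits normalised probabilities rather than raw logits, so the value read off at the argmax index in the predictor classes above is directly interpretable as a per-token confidence in [0, 1]. A minimal TensorFlow sketch of that lookup with toy values (the shapes mirror the (batch_size, 1, vocab_size) slice used above):

    import tensorflow as tf

    # Toy example: one batch element, one decoding step, a three-token vocabulary.
    logits = tf.constant([[[2.0, 0.5, 0.1]]])                  # (batch_size, 1, vocab_size)
    predictions = tf.nn.softmax(logits, axis=-1)               # what the softmax output layer now yields
    predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
    confidence = predictions[0, 0, int(predicted_id[0, 0])]    # probability of the chosen token
    print(int(predicted_id[0, 0]), float(confidence))          # index 0 with probability ~0.73
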
DECIMER/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 
 """
-DECIMER V2.3.1 Python Package.
+DECIMER V2.4.0 Python Package.
 ============================
 
 This repository contains DECIMER-V2,Deep lEarning for Chemical ImagE Recognition) project
@@ -23,11 +23,11 @@
 
 """
 
-__version__ = "2.3.1"
+__version__ = "2.4.0"
 
 __all__ = [
     "DECIMER",
 ]
 
 
-from .decimer import predict_SMILES
+from .decimer import predict_SMILES, predict_SMILES_with_confidence
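
With the new symbol exported at package level, both entry points can be imported directly from DECIMER. A minimal usage sketch (the image path is a hypothetical example):

    from DECIMER import predict_SMILES, predict_SMILES_with_confidence

    image_path = "structure.png"  # hypothetical example path
    smiles = predict_SMILES(image_path)
    tokens_with_confidence = predict_SMILES_with_confidence(image_path)

    print(smiles)
    for token, confidence in tokens_with_confidence:
        print(token, confidence)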
