11import os
22import sys
33import tensorflow as tf
4-
4+ from typing import List , Tuple
55import pickle
66import pystow
7- from selfies import decoder
87import Transformer_decoder
98
109if int (tf .__version__ .split ("." )[1 ]) <= 10 :
@@ -116,6 +115,9 @@ def main():
116115 else :
117116 SMILES = predict_SMILES (sys .argv [1 ])
118117 print (SMILES )
118+ SMILES_with_confidence = predict_SMILES_with_confidence (sys .argv [1 ])
119+ for tup in SMILES_with_confidence :
120+ print (tup )
119121
120122
121123class DECIMER_Predictor (tf .Module ):
@@ -126,6 +128,16 @@ def __init__(self, encoder, tokenizer, transformer, max_length):
126128 self .max_length = max_length
127129
128130 def __call__ (self , Decoded_image ):
131+ """
132+ Run the DECIMER predictor model when called.
133+ Usage of predict_SMILES or predict_SMILES_with_confidence is recommended instead
134+
135+ Args:
136+ Decoded_image (_type_): output of config.decode_image
137+
138+ Returns:
139+ Tuple[tf.Tensor, tf.Tensor]: predicted tokens, confidence values
140+ """
129141 assert isinstance (Decoded_image , tf .Tensor )
130142
131143 _image_batch = tf .expand_dims (Decoded_image , 0 )
@@ -140,6 +152,7 @@ def __call__(self, Decoded_image):
140152
141153 output_array = tf .TensorArray (dtype = tf .int32 , size = 0 , dynamic_size = True )
142154 output_array = output_array .write (0 , start_token )
155+ confidence_array = tf .TensorArray (dtype = tf .float32 , size = 0 , dynamic_size = True )
143156
144157 for t in tf .range (max_length ):
145158 output = tf .transpose (output_array .stack ())
@@ -154,31 +167,92 @@ def __call__(self, Decoded_image):
154167 predictions = prediction_batch [:, - 1 :, :] # (batch_size, 1, vocab_size)
155168
156169 predicted_id = tf .cast (tf .argmax (predictions , axis = - 1 ), tf .int32 )
157-
170+ confidence = predictions [ - 1 ][ - 1 ][ int ( predicted_id )]
158171 output_array = output_array .write (t + 1 , predicted_id [0 ])
159-
172+ confidence_array = confidence_array . write ( t + 1 , confidence )
160173 if predicted_id == end_token :
161174 break
162175 output = tf .transpose (output_array .stack ())
163- return output
164176
177+ return output , confidence_array .stack ()
178+
179+
def detokenize_output(predicted_array: tf.Tensor) -> str:
    """
    Turn the predicted token ids into one concatenated token string.

    Every id in the first row of ``predicted_array`` is looked up in the
    module-level ``tokenizer.index_word`` mapping. The resulting tokens are
    joined without a separator and the "<start>" and "<end>" markers are
    removed. ``predict_SMILES`` passes the returned string on to
    ``utils.decoder``.

    Args:
        predicted_array (tf.Tensor): Transformer Decoder output array (predicted tokens)

    Returns:
        str: concatenated tokens with the start/end markers stripped
    """
    token_ids = predicted_array[0].numpy()
    joined_tokens = "".join(
        str(tokenizer.index_word[token_id]) for token_id in token_ids
    )
    without_start = joined_tokens.replace("<start>", "")
    return without_start.replace("<end>", "")
175198
176199
def detokenize_output_add_confidence(
    predicted_array: tf.Tensor,
    confidence_array: tf.Tensor,
) -> List[Tuple[str, float]]:
    """
    Pair every predicted token with its confidence value.

    Each token id in the first row of ``predicted_array`` is looked up in the
    module-level ``tokenizer.index_word`` mapping and paired with the
    confidence value stored at the same position. The first pair is dropped,
    the token of every pair between the first and the last one is passed
    through ``utils.decoder``, and the last pair is appended with its token
    left unchanged.

    Args:
        predicted_array (tf.Tensor): Transformer Decoder output array (predicted tokens)
        confidence_array (tf.Tensor): confidence values returned by the
            Transformer Decoder, indexed by the same positions as the
            predicted tokens

    Returns:
        List[Tuple[str, float]]: (token, confidence) pairs in prediction order

    Raises:
        IndexError: if ``confidence_array`` is empty or longer than the first
            row of ``predicted_array``
    """
    # Convert both tensors to NumPy once instead of once per position.
    token_ids = predicted_array[0].numpy()
    confidences = confidence_array.numpy()

    prediction_with_confidence = [
        (tokenizer.index_word[token_ids[position]], confidence)
        for position, confidence in enumerate(confidences)
    ]

    # Position 0 belongs to the start token written by the predictor before
    # decoding begins, so it is skipped. The last pair is kept undecoded.
    # NOTE(review): presumably the last token is "<end>" -- confirm for
    # predictions that stop at max_length instead.
    decoded_prediction_with_confidence = [
        (utils.decoder(token), confidence)
        for token, confidence in prediction_with_confidence[1:-1]
    ]
    decoded_prediction_with_confidence.append(prediction_with_confidence[-1])
    return decoded_prediction_with_confidence
228+
229+
# Create the module-level DECIMER predictor instance that the prediction
# functions in this file call.
DECIMER = DECIMER_Predictor(
    encoder=encoder,
    tokenizer=tokenizer,
    transformer=transformer,
    max_length=MAX_LEN,
)
179232
180233
181- def predict_SMILES (image_path : str ):
def predict_SMILES_with_confidence(image_path: str) -> List[Tuple[str, float]]:
    """
    Predict the tokens of a depicted molecule together with their confidence.

    The image at ``image_path`` is decoded with ``config.decode_image``, run
    through the module-level ``DECIMER`` predictor, and the resulting token
    and confidence arrays are combined by
    ``detokenize_output_add_confidence``.

    Args:
        image_path (str): Path of chemical structure depiction image

    Returns:
        (List[Tuple[str, float]]): Tuples that contain the tokens and the
                                   confidence values of the predicted SMILES
    """
    image_tensor = tf.constant(config.decode_image(image_path))
    token_ids, confidences = DECIMER(image_tensor)
    return detokenize_output_add_confidence(token_ids, confidences)
253+
254+
255+ def predict_SMILES (image_path : str ) -> str :
182256 """
183257 This function takes an image path (str) and returns the SMILES
184258 representation of the depicted molecule (str).
@@ -190,9 +264,8 @@ def predict_SMILES(image_path: str):
190264 (str): SMILES representation of the molecule in the input image
191265 """
192266 decodedImage = config .decode_image (image_path )
193- predicted_tokens = DECIMER (tf .constant (decodedImage ))
267+ predicted_tokens , _ = DECIMER (tf .constant (decodedImage ))
194268 predicted_SMILES = utils .decoder (detokenize_output (predicted_tokens ))
195-
196269 return predicted_SMILES
197270
198271
0 commit comments