@@ -962,48 +962,53 @@ def testOnBatch(self, X, Y, accuracy=True, out_name=None):
    def predict_cond(self, X, states_below, params, ii):
-        p = []
-        for state_below in states_below:
-            X['state_below'] = state_below.reshape(1, -1)
-            #data = self.model_test.predict_on_batch(X)
-            data = self.model.predict_on_batch(X)
-            # Get probs of all words in the current timestep
-            if len(params['model_outputs']) > 1:
-                all_data = {}
-                for output_id in range(len(params['model_outputs'])):
-                    print output_id, data[output_id]
-                    all_data[params['model_outputs'][output_id]] = data[output_id]
-                all_data[params['model_outputs'][0]] = np.array(all_data[params['model_outputs'][0]])[:, ii, :]
-            else:
-                all_data = {params['model_outputs'][0]: np.array(data)[:, ii, :]}
-            # Append the prob distribution
-            p.append(all_data[params['model_outputs'][0]])
-        p = np.asarray(p)
-        return p[:, 0, :]
+        x = {}
+        n_samples = states_below.shape[0]
+        for model_input in params['model_inputs']:
+            if model_input != 'state_below':
+                if X[model_input].shape[0] == 1:
+                    # TODO: More generic inputs (e.g. for 1D-2D inputs)
+                    x[model_input] = np.repeat(X[model_input], n_samples, axis=0).reshape((n_samples,
+                                                                                           X[model_input].shape[1],
+                                                                                           X[model_input].shape[2]))
+        x['state_below'] = states_below
+        data = self.model.predict_on_batch(x)
+        if len(params['model_outputs']) > 1:
+            all_data = {}
+            for output_id in range(len(params['model_outputs'])):
+                all_data[params['model_outputs'][output_id]] = data[output_id]
+            all_data[params['model_outputs'][0]] = np.array(all_data[params['model_outputs'][0]])[:, ii, :]
+        else:
+            all_data = {params['model_outputs'][0]: np.array(data)[:, ii, :]}
+
+        return all_data[params['model_outputs'][0]]
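The rewritten `predict_cond` scores every live hypothesis in a single `predict_on_batch` call instead of looping over them, tiling any one-sample model input with `np.repeat` so its batch dimension matches the beam width. A minimal standalone sketch of that tiling step (the 3D input shape is an assumption for illustration):

```python
import numpy as np

n_samples = 3                       # current beam width
source = np.zeros((1, 30, 620))     # one encoded sample: (1, timesteps, features)

# Tile the single sample so every live hypothesis sees the same source input
tiled = np.repeat(source, n_samples, axis=0).reshape((n_samples,
                                                      source.shape[1],
                                                      source.shape[2]))
print(tiled.shape)  # (3, 30, 620)
```

Note that `np.repeat` along `axis=0` already returns a `(n_samples, timesteps, features)` array, so the trailing `reshape` only restates the expected shape.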

    def beam_search(self, X, params, null_sym=2):
-        k = params['beam_size']
-        sample = []
-        sample_score = []
+        k = params['beam_size'] + 1
+        samples = []
+        sample_scores = []

        dead_k = 0  # samples that reached eos
        live_k = 1  # samples that have not yet reached eos
        hyp_samples = [[]] * live_k
        hyp_scores = np.zeros(live_k).astype('float32')
-        state_below = np.asarray([np.zeros(params['maxlen']) + null_sym] * live_k)
+        state_below = np.asarray([null_sym] * live_k)
        for ii in xrange(params['maxlen']):
            # for every possible live sample calc prob for every possible label
            probs = self.predict_cond(X, state_below, params, ii)
-
            # total score for every sample is sum of -log of word prob
            cand_scores = np.array(hyp_scores)[:, None] - np.log(probs)
            cand_flat = cand_scores.flatten()
+            # Find the best options by calling argsort of flattened array
            ranks_flat = cand_flat.argsort()[:(k - dead_k)]
+
+            # Decipher flattened indices
            voc_size = probs.shape[1]
            trans_indices = ranks_flat / voc_size  # index of row
            word_indices = ranks_flat % voc_size  # index of col
            costs = cand_flat[ranks_flat]
+
+            # Form a beam for the next iteration
            new_hyp_samples = []
            new_hyp_scores = np.zeros(k - dead_k).astype('float32')
            for idx, [ti, wi] in enumerate(zip(trans_indices, word_indices)):
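The flattened argsort above selects the `k - dead_k` globally cheapest (hypothesis, word) continuations in one pass: the row index recovers which hypothesis a candidate extends, the column index which word extends it. A small worked example with a beam of 2 hypotheses over a 4-word toy vocabulary (`//` is the Python 3-safe spelling of the integer division the patch writes as `/`):

```python
import numpy as np

hyp_scores = np.array([1.0, 2.0], dtype='float32')  # accumulated -log costs
probs = np.array([[0.1, 0.6, 0.2, 0.1],             # next-word probs, hypothesis 0
                  [0.3, 0.3, 0.2, 0.2]])            # next-word probs, hypothesis 1

cand_scores = hyp_scores[:, None] - np.log(probs)   # (2, 4) total costs
cand_flat = cand_scores.flatten()

ranks_flat = cand_flat.argsort()[:3]                # 3 cheapest continuations
voc_size = probs.shape[1]
trans_indices = ranks_flat // voc_size              # which hypothesis each extends
word_indices = ranks_flat % voc_size                # which word extends it

print([(int(t), int(w)) for t, w in zip(trans_indices, word_indices)])
# [(0, 1), (0, 2), (1, 0)]
```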
@@ -1016,8 +1021,8 @@ def beam_search(self, X, params, null_sym=2):
            hyp_scores = []
            for idx in xrange(len(new_hyp_samples)):
                if new_hyp_samples[idx][-1] == 0:
-                    sample.append(new_hyp_samples[idx])
-                    sample_score.append(new_hyp_scores[idx])
+                    samples.append(new_hyp_samples[idx])
+                    sample_scores.append(new_hyp_scores[idx])
                    dead_k += 1
                else:
                    new_live_k += 1
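Hypotheses whose last symbol is 0 (the end-of-sentence id here) are retired into the finished `samples` list and counted in `dead_k`; the rest stay in the beam. A sketch of that bookkeeping, with the `hyp_samples`/`hyp_scores` appends assumed from the elided part of the hunk:

```python
new_hyp_samples = [[4, 7, 0], [4, 9], [5, 3]]   # candidate continuations
new_hyp_scores = [1.2, 1.5, 1.7]

samples, sample_scores = [], []                 # finished hypotheses
hyp_samples, hyp_scores = [], []                # surviving beam
dead_k, new_live_k = 0, 0
for idx in range(len(new_hyp_samples)):
    if new_hyp_samples[idx][-1] == 0:           # reached eos: retire it
        samples.append(new_hyp_samples[idx])
        sample_scores.append(new_hyp_scores[idx])
        dead_k += 1
    else:                                       # still alive: keep expanding it
        new_live_k += 1
        hyp_samples.append(new_hyp_samples[idx])
        hyp_scores.append(new_hyp_scores[idx])
```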
@@ -1032,13 +1037,14 @@ def beam_search(self, X, params, null_sym=2):
                break
            state_below = np.asarray(hyp_samples, dtype='int64')
            state_below = np.hstack((np.zeros((state_below.shape[0], 1), dtype='int64') + null_sym, state_below))
-            #np.zeros((state_below.shape[0], ii),dtype='int64')))
+
        # dump every remaining one
        if live_k > 0:
            for idx in xrange(live_k):
-                sample.append(hyp_samples[idx])
-                sample_score.append(hyp_scores[idx])
-        return sample, sample_score
+                samples.append(hyp_samples[idx])
+                sample_scores.append(hyp_scores[idx])
+
+        return samples, sample_scores
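At the end of each step the decoder input is rebuilt from the surviving hypotheses, with the null symbol prepended so the target sequence is shifted one position right. A minimal standalone reproduction of that `hstack`:

```python
import numpy as np

null_sym = 2
hyp_samples = [[5, 7], [5, 9]]   # two surviving hypotheses

state_below = np.asarray(hyp_samples, dtype='int64')
state_below = np.hstack((np.zeros((state_below.shape[0], 1), dtype='int64') + null_sym,
                         state_below))
print(state_below)
# [[2 5 7]
#  [2 5 9]]
```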

    def BeamSearchNet(self, ds, parameters):
        '''
@@ -1061,6 +1067,7 @@ def BeamSearchNet(self, ds, parameters):
            'model_outputs': ['description'],
            'dataset_inputs': ['source_text', 'state_below'],
            'dataset_outputs': ['description'],
+            'normalize': False,
            'sampling_type': 'max_likelihood'
        }
        params = self.checkParameters(parameters, default_params)
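The new `normalize` flag defaults to `False`, so existing callers are unaffected. A hypothetical invocation enabling it (the `ds` dataset and `net` model wrapper objects are assumed to exist; only keys visible in `default_params` are used):

```python
parameters = {'beam_size': 12,
              'normalize': True,   # length-normalize beam scores before ranking
              'model_inputs': ['source_text', 'state_below'],
              'model_outputs': ['description'],
              'dataset_inputs': ['source_text', 'state_below'],
              'dataset_outputs': ['description']}
predictions = net.BeamSearchNet(ds, parameters)
```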
@@ -1124,14 +1131,20 @@ def BeamSearchNet(self, ds, parameters):
            for input_id in params['dataset_inputs']:
                x[input_id] = np.asarray([X[input_id][i]])
            samples, scores = self.beam_search(x, params, null_sym=ds.extra_words['<null>'])
-            out.append(samples[0])
-            total_cost += scores[0]
+            if params['normalize']:
+                counts = [len(sample) for sample in samples]
+                scores = [co / cn for co, cn in zip(scores, counts)]
+            best_score = np.argmin(scores)
+            best_sample = samples[best_score]
+            out.append(best_sample)
+            total_cost += scores[best_score]
            eta = (n_samples - sampled) * (time.time() - start_time) / sampled
            if params['n_samples'] > 0:
                for output_id in params['dataset_outputs']:
                    references.append(Y[output_id][i])
-            sys.stdout.write('Cost of the translations: %f\n' % scores[0])
+            sys.stdout.write('Cost of the translations: %f\n' % total_cost)
            sys.stdout.flush()
+
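With `normalize` enabled, each score is divided by the hypothesis length before `np.argmin` picks the winner, so short outputs no longer win merely by accumulating less -log cost. A toy illustration of the selection:

```python
import numpy as np

samples = [[4, 7, 0], [4, 9, 6, 2, 0]]
scores = [3.0, 4.0]          # raw costs: the short hypothesis would win

counts = [len(sample) for sample in samples]
norm_scores = [co / cn for co, cn in zip(scores, counts)]   # [1.0, 0.8]
best = int(np.argmin(norm_scores))
print(samples[best], norm_scores[best])   # [4, 9, 6, 2, 0] 0.8
```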

        predictions[s] = np.asarray(out)

        if params['n_samples'] < 1: