Skip to content

Commit d059093

Browse files
authored
Merge pull request #17 from lvapeab/master
Beam search improvements + bug fixes
2 parents 310bfe9 + 41ca382 commit d059093

File tree

1 file changed

+45
-32
lines changed

1 file changed

+45
-32
lines changed

keras_wrapper/cnn_model.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -962,48 +962,53 @@ def testOnBatch(self, X, Y, accuracy=True, out_name=None):
962962

963963

964964
def predict_cond(self, X, states_below, params, ii):
    """Score the current partial hypotheses with the model.

    Every model input other than 'state_below' that has batch size 1 is
    tiled so that a single source sample can condition all live beam
    hypotheses in one batched forward pass.

    :param X: dict mapping model input names to numpy arrays (batch-first).
    :param states_below: numpy array of partial hypotheses, one row per
                         live hypothesis.
    :param params: dict with at least 'model_inputs' (list of input names)
                   and 'model_outputs' (list of output names; the first one
                   is the distribution of interest).
    :param ii: timestep whose output distribution is extracted.
    :return: numpy array of shape (n_hypotheses, vocabulary_size) with the
             probabilities of the first model output at timestep `ii`.
    """
    n_samples = states_below.shape[0]
    x = {'state_below': states_below}
    for model_input in params['model_inputs']:
        # '!=' rather than 'is not': identity comparison on string
        # literals is implementation-dependent and a latent bug.
        if model_input == 'state_below':
            continue
        inp = X[model_input]
        if inp.shape[0] == 1:
            # TODO: More generic inputs (e.g. for 1D-2D inputs)
            # Tile the single sample so every hypothesis sees the same input.
            x[model_input] = np.repeat(inp, n_samples, axis=0).reshape(
                (n_samples, inp.shape[1], inp.shape[2]))
        else:
            # Already batched per hypothesis: pass through unchanged.
            # (Previously this input was silently dropped, leaving the
            # model with an incomplete input dict.)
            x[model_input] = inp
    data = self.model.predict_on_batch(x)
    # Keep only the distribution of the first model output at timestep ii.
    if len(params['model_outputs']) > 1:
        all_data = dict(zip(params['model_outputs'], data))
        all_data[params['model_outputs'][0]] = \
            np.array(all_data[params['model_outputs'][0]])[:, ii, :]
    else:
        all_data = {params['model_outputs'][0]: np.array(data)[:, ii, :]}
    return all_data[params['model_outputs'][0]]
983984

984985
def beam_search(self, X, params, null_sym=2):
985986

986-
k = params['beam_size']
987-
sample = []
988-
sample_score = []
987+
k = params['beam_size'] + 1
988+
samples = []
989+
sample_scores = []
989990

990991
dead_k = 0 # samples that reached eos
991992
live_k = 1 # samples that did not yet reached eos
992993
hyp_samples = [[]] * live_k
993994
hyp_scores = np.zeros(live_k).astype('float32')
994-
state_below = np.asarray([np.zeros(params['maxlen'])+null_sym]* live_k)
995+
state_below = np.asarray([null_sym] * live_k)
995996
for ii in xrange(params['maxlen']):
996997
# for every possible live sample calc prob for every possible label
997998
probs = self.predict_cond(X, state_below, params, ii)
998-
999999
# total score for every sample is sum of -log of word prb
10001000
cand_scores = np.array(hyp_scores)[:, None] - np.log(probs)
10011001
cand_flat = cand_scores.flatten()
1002+
# Find the best options by calling argsort on the flattened array
10021003
ranks_flat = cand_flat.argsort()[:(k-dead_k)]
1004+
1005+
# Decipher the flattened indices
10031006
voc_size = probs.shape[1]
10041007
trans_indices = ranks_flat / voc_size # index of row
10051008
word_indices = ranks_flat % voc_size # index of col
10061009
costs = cand_flat[ranks_flat]
1010+
1011+
# Form a beam for the next iteration
10071012
new_hyp_samples = []
10081013
new_hyp_scores = np.zeros(k-dead_k).astype('float32')
10091014
for idx, [ti, wi] in enumerate(zip(trans_indices, word_indices)):
@@ -1016,8 +1021,8 @@ def beam_search(self, X, params, null_sym=2):
10161021
hyp_scores = []
10171022
for idx in xrange(len(new_hyp_samples)):
10181023
if new_hyp_samples[idx][-1] == 0:
1019-
sample.append(new_hyp_samples[idx])
1020-
sample_score.append(new_hyp_scores[idx])
1024+
samples.append(new_hyp_samples[idx])
1025+
sample_scores.append(new_hyp_scores[idx])
10211026
dead_k += 1
10221027
else:
10231028
new_live_k += 1
@@ -1032,13 +1037,14 @@ def beam_search(self, X, params, null_sym=2):
10321037
break
10331038
state_below = np.asarray(hyp_samples, dtype='int64')
10341039
state_below = np.hstack((np.zeros((state_below.shape[0], 1), dtype='int64')+null_sym, state_below))
1035-
#np.zeros((state_below.shape[0], ii),dtype='int64')))
1040+
10361041
# dump every remaining one
10371042
if live_k > 0:
10381043
for idx in xrange(live_k):
1039-
sample.append(hyp_samples[idx])
1040-
sample_score.append(hyp_scores[idx])
1041-
return sample, sample_score
1044+
samples.append(hyp_samples[idx])
1045+
sample_scores.append(hyp_scores[idx])
1046+
1047+
return samples, sample_scores
10421048

10431049
def BeamSearchNet(self, ds, parameters):
10441050
'''
@@ -1061,6 +1067,7 @@ def BeamSearchNet(self, ds, parameters):
10611067
'model_outputs': ['description'],
10621068
'dataset_inputs': ['source_text', 'state_below'],
10631069
'dataset_outputs': ['description'],
1070+
'normalize': False,
10641071
'sampling_type': 'max_likelihood'
10651072
}
10661073
params = self.checkParameters(parameters, default_params)
@@ -1124,14 +1131,20 @@ def BeamSearchNet(self, ds, parameters):
11241131
for input_id in params['dataset_inputs']:
11251132
x[input_id] = np.asarray([X[input_id][i]])
11261133
samples, scores = self.beam_search(x, params, null_sym=ds.extra_words['<null>'])
1127-
out.append(samples[0])
1128-
total_cost += scores[0]
1134+
if params['normalize']:
1135+
counts = [len(sample) for sample in samples]
1136+
scores = [co / cn for co, cn in zip(scores, counts)]
1137+
best_score = np.argmin(scores)
1138+
best_sample = samples[best_score]
1139+
out.append(best_sample)
1140+
total_cost += scores[best_score]
11291141
eta = (n_samples - sampled) * (time.time() - start_time) / sampled
11301142
if params['n_samples'] > 0:
11311143
for output_id in params['dataset_outputs']:
11321144
references.append(Y[output_id][i])
1133-
sys.stdout.write('Cost of the translations: %f\n'%scores[0])
1145+
sys.stdout.write('Cost of the translations: %f\n'%total_cost)
11341146
sys.stdout.flush()
1147+
11351148
predictions[s] = np.asarray(out)
11361149

11371150
if params['n_samples'] < 1:

0 commit comments

Comments
 (0)