Open
Description
The run_marco.py script under rerank contains the following code:
pred_scores = output.predictions
# pred_scores = trainer.predict(test_dataset=test_dataset).predictions
if trainer.is_world_process_zero():
    # assert len(pred_qids) == len(pred_scores)
    with open(data_args.rank_score_path, "w") as writer:
        for i, (qid, pid, score) in enumerate(zip(pred_qids, pred_pids, pred_scores)):
            writer.write(f'{qid} Q0 {pid} {i} {score} {data_args.run_id}\n')
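However, pred_scores is not actually a single array of scores: pred_scores[0] holds the scores and pred_scores[1] holds the corresponding embeddings. The code should therefore be changed to the following: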
pred_scores = output.predictions
# pred_scores = trainer.predict(test_dataset=test_dataset).predictions
if trainer.is_world_process_zero():
    # assert len(pred_qids) == len(pred_scores[0])
    with open(data_args.rank_score_path, "w") as writer:
        # pred_scores[0] holds the scores, pred_scores[1] the embeddings
        for i, (qid, pid, score, embeds) in enumerate(zip(pred_qids, pred_pids, pred_scores[0], pred_scores[1])):
            # join with a space so multiple score values do not run together
            score_str = ' '.join(map(str, score.flatten()))
            embeds_str = ', '.join(map(str, embeds.flatten()))
            writer.write(f'{qid} Q0 {pid} {i} {score_str} {embeds_str} {data_args.run_id}\n')
            # writer.write(f'{qid}\t{pid}\t{score}\n')
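If only the scores are needed in the run file, a more defensive variant could check whether the predictions come back as a tuple before writing. The following is only a minimal sketch, not the repository's code: the helper names split_predictions, as_float and write_run are hypothetical, plain Python lists stand in for the real arrays, and it assumes (as described above) that element 0 is the scores and element 1 is the embeddings.

```python
import io


def split_predictions(predictions):
    """Return (scores, embeddings); embeddings is None if only scores exist.

    Assumption: when predictions is a tuple, element 0 holds the scores
    and element 1 holds the embeddings, as reported in this issue.
    """
    if isinstance(predictions, tuple):
        scores = predictions[0]
        embeddings = predictions[1] if len(predictions) > 1 else None
        return scores, embeddings
    return predictions, None


def as_float(value):
    """Unwrap a score stored as a nested one-element sequence."""
    try:
        return float(value)
    except TypeError:
        # Not directly convertible (e.g. a list): look at its first element.
        return as_float(value[0])


def write_run(writer, qids, pids, scores, run_id):
    """Write one line per (qid, pid, score), in the same layout as above."""
    for i, (qid, pid, score) in enumerate(zip(qids, pids, scores)):
        writer.write(f"{qid} Q0 {pid} {i} {as_float(score)} {run_id}\n")


# Toy data standing in for output.predictions (hypothetical values).
predictions = ([[0.9], [0.1]], [[0.5, 0.5], [0.2, 0.8]])
scores, embeddings = split_predictions(predictions)

buffer = io.StringIO()  # stands in for the opened rank_score_path file
write_run(buffer, ["q1", "q1"], ["p1", "p2"], scores, "demo")
run_text = buffer.getvalue()
print(run_text, end="")
```

Keeping the embeddings out of the score file leaves each line with six whitespace-separated fields; the embeddings returned by split_predictions could be saved to a separate file if they are needed.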