
pred_scores #8

@WangSheng21s

The `run_marco.py` script in rerank contains the following code:
```python
pred_scores = output.predictions
# pred_scores = trainer.predict(test_dataset=test_dataset).predictions

if trainer.is_world_process_zero():
    # assert len(pred_qids) == len(pred_scores)
    with open(data_args.rank_score_path, "w") as writer:
        for i, (qid, pid, score) in enumerate(zip(pred_qids, pred_pids, pred_scores)):
            writer.write(f'{qid} Q0 {pid} {i} {score} {data_args.run_id}\n')
```
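When the predictions come back as a tuple, the loop above produces wrong rows silently instead of raising an error. The minimal reproduction below uses plain Python lists in place of the real prediction arrays. The three (qid, pid) pairs, the one-score-per-pair layout and the 4-dimensional embeddings are made-up values for illustration only.

```python
# Hypothetical stand-ins for three (qid, pid) pairs: one score per pair and a
# 4-dimensional embedding per pair (plain lists here instead of NumPy arrays).
pred_qids = ["q1", "q1", "q2"]
pred_pids = ["p1", "p2", "p3"]
scores = [[0.9], [0.1], [0.5]]
embeddings = [[0.0] * 4 for _ in range(3)]

# A model with two outputs yields a tuple: (scores, embeddings).
pred_scores = (scores, embeddings)

rows = list(zip(pred_qids, pred_pids, pred_scores))
# zip stops at its shortest input. The tuple has only two items, so only two
# rows come out, and each "score" is an entire array rather than one value.
print(len(rows))                                        # 2
print(rows[0][2] is scores, rows[1][2] is embeddings)   # True True
```

Because `zip` truncates without complaint, the run file ends up with two lines whose "score" column is a whole array. The commented-out `assert` in the original code would have caught this.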

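However, `pred_scores` is not a single per-pair array, as this loop assumes. `pred_scores[0]` holds the scores, and only `pred_scores[1]` holds the embeddings, so the code should be changed as follows: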
```python
pred_scores = output.predictions
# pred_scores = trainer.predict(test_dataset=test_dataset).predictions

if trainer.is_world_process_zero():
    # assert len(pred_qids) == len(pred_scores[0])
    with open(data_args.rank_score_path, "w") as writer:
        # predictions is a tuple here: [0] holds the scores, [1] the embeddings
        for i, (qid, pid, score, embeds) in enumerate(zip(pred_qids, pred_pids, pred_scores[0], pred_scores[1])):
            # separate multiple values with spaces instead of concatenating them
            score_str = ' '.join(map(str, score.flatten()))
            embeds_str = ', '.join(map(str, embeds.flatten()))
            # note: the appended embedding makes each line longer than the six TREC run columns
            writer.write(f'{qid} Q0 {pid} {i} {score_str} {embeds_str} {data_args.run_id}\n')
            # writer.write(f'{qid}\t{pid}\t{score}\n')
```
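A more defensive variant is sketched below. It unpacks the tuple explicitly and checks lengths before writing. The sketch assumes that `trainer` is a Hugging Face `Trainer`, whose `predict()` returns `predictions` either as a single array or as a tuple of arrays. The helper names `flat_values`, `split_predictions` and `write_run` are hypothetical. The layout (`[0]` = scores, `[1]` = embeddings) comes from this issue and is not guaranteed for every model.

Unlike the change above, this sketch writes only the score. Each line therefore keeps the six whitespace-separated TREC run columns (`qid Q0 docid rank score run_id`), and the embeddings could be saved to a separate file if they are needed.

```python
from typing import Any, Iterator, Optional, Sequence, Tuple


def flat_values(x: Any) -> Iterator[Any]:
    """Yield the scalars in x. Accepts nested lists/tuples or anything with
    .tolist() (e.g. NumPy arrays or NumPy scalars)."""
    if hasattr(x, "tolist"):
        x = x.tolist()
    if isinstance(x, (list, tuple)):
        for item in x:
            yield from flat_values(item)
    else:
        yield x


def split_predictions(predictions: Any) -> Tuple[Any, Optional[Any]]:
    """Return (scores, embeddings); embeddings is None for a single output."""
    if isinstance(predictions, tuple):
        # Assumed layout, as described in this issue: [0] = scores, [1] = embeddings.
        return predictions[0], (predictions[1] if len(predictions) > 1 else None)
    return predictions, None


def write_run(path: str, qids: Sequence[str], pids: Sequence[str],
              predictions: Any, run_id: str) -> None:
    """Write one TREC-style run line per (qid, pid) pair."""
    scores, _embeddings = split_predictions(predictions)
    # Fail loudly instead of letting zip() silently truncate.
    if len(scores) != len(qids) or len(pids) != len(qids):
        raise ValueError(f"{len(scores)} scores for {len(qids)} qids / {len(pids)} pids")
    with open(path, "w") as writer:
        for i, (qid, pid, score) in enumerate(zip(qids, pids, scores)):
            # One score per pair is expected; join with spaces just in case.
            score_str = " ".join(map(str, flat_values(score)))
            # Columns: qid Q0 docid rank score run_id (i as rank, as in the original)
            writer.write(f"{qid} Q0 {pid} {i} {score_str} {run_id}\n")
```

Inside `run_marco.py`, this could be called from the `trainer.is_world_process_zero()` branch as `write_run(data_args.rank_score_path, pred_qids, pred_pids, output.predictions, data_args.run_id)`.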
