7
7
import argparse
8
8
9
9
MODEL_CONFIGS = {
10
- "albert " : {
11
- "model_name" : "orafandina/albert -base-v2-finetuned -qnli" ,
12
- "tokenizer_name" : "albert/albert -base-v2 "
10
+ "bart " : {
11
+ "model_name" : "ModelTC/bart -base-qnli" ,
12
+ "tokenizer_name" : "facebook/bart -base"
13
13
},
14
14
"bert" : {
15
15
"model_name" : "Li/bert-base-uncased-qnli" ,
23
23
"model_name" : "tmnam20/xlm-roberta-large-qnli-1" ,
24
24
"tokenizer_name" : "FacebookAI/xlm-roberta-large"
25
25
},
26
- "distilbert " : {
27
- "model_name" : "anirudh21/distilbert-base-uncased-finetuned-qnli " ,
28
- "tokenizer_name" : "distilbert/distilbert-base-uncased "
26
+ "gpt2 " : {
27
+ "model_name" : "tanganke/gpt2_qnli " ,
28
+ "tokenizer_name" : "openai-community/gpt2 "
29
29
},
30
30
"t5" : {
31
31
"model_name" : "lightsout19/t5-small-qnli" ,
32
32
"tokenizer_name" : "google-t5/t5-small"
33
33
},
34
- "gpt2 " : {
35
- "model_name" : "tanganke/gpt2_qnli " ,
36
- "tokenizer_name" : "openai-community/gpt2 "
34
+ "distilbert " : {
35
+ "model_name" : "anirudh21/distilbert-base-uncased-finetuned-qnli " ,
36
+ "tokenizer_name" : "distilbert/distilbert-base-uncased "
37
37
},
38
38
"llama" : {
39
39
"model_name" : "Cheng98/llama-160m-qnli" ,
40
40
"tokenizer_name" : "JackFram/llama-160m"
41
41
},
42
+ "albert" : {
43
+ "model_name" : "orafandina/albert-base-v2-finetuned-qnli" ,
44
+ "tokenizer_name" : "albert/albert-base-v2"
45
+ },
42
46
"opt" : {
43
- "model_name" : "facebook/opt-125m" ,
44
- "tokenizer_name" : "utahnlp/qnli_facebook_opt-125m_seed-1"
47
+ "model_name" : "utahnlp/qnli_facebook_opt-125m_seed-1" ,
48
+ "tokenizer_name" : "facebook/opt-125m"
49
+ },
50
+ "llama" : {
51
+ "model_name" : "Cheng98/llama-160m-qnli" ,
52
+ "tokenizer_name" : "JackFram/llama-160m"
45
53
},
46
- "bart" : {
47
- "model_name" : "facebook/bart-large-qnli" ,
48
- "tokenizer_name" : "ModelTC/bart-base-qnli"
49
- }
50
54
}
51
55
52
56
def get_model_and_tokenizer (model_type ):
@@ -63,31 +67,29 @@ def get_model_and_tokenizer(model_type):
63
67
def predict_qnli(model, tokenizer, question, sentence):
    """Predict the QNLI label for a single (question, sentence) pair.

    Args:
        model: Callable that accepts ``input_ids`` and ``attention_mask``
            keyword arguments and returns an object with a ``logits``
            attribute.
        tokenizer: Callable that encodes the text pair and returns a mapping
            containing ``"input_ids"`` and ``"attention_mask"``.
        question: Question text (first segment of the pair).
        sentence: Candidate sentence text (second segment of the pair).

    Returns:
        The index of the highest logit along axis 1 for the first (only)
        row of the batch.
    """
    # "ms" is presumably the MindSpore tensor format -- confirm against the
    # tokenizer library in use. Inputs longer than 512 tokens are truncated.
    inputs = tokenizer(
        question,
        sentence,
        return_tensors="ms",
        truncation=True,
        max_length=512,
    )

    # Forward only these two fields instead of unpacking everything the
    # tokenizer produced; any extra keys (e.g. token_type_ids) are dropped.
    # NOTE(review): presumably because not every supported model's forward
    # signature accepts all tokenizer outputs -- confirm.
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    logits = outputs.logits
    return logits.argmax(axis=1).asnumpy()[0]
69
75
70
76
def evaluate_model (model_type , data_path ):
71
77
"""评估模型在QNLI数据集上的表现"""
72
78
print (f"正在评估模型: { model_type } " )
73
79
74
- # 加载模型和分词器
75
80
model , tokenizer = get_model_and_tokenizer (model_type )
76
81
print (f"模型类型: { model .config .model_type } " )
77
82
78
- # 加载数据
79
- df = pd .read_csv (data_path , sep = '\t ' , header = 0 , names = ['idx' , 'question' , 'sentence' , 'label' ])
83
+ df = pd .read_csv (data_path , sep = '\t ' , header = 0 , names = ['idx' , 'question' , 'sentence' , 'label' ], on_bad_lines = 'skip' )
80
84
df = df .dropna (subset = ['label' ])
81
85
82
- # 标签映射
83
86
label_map = {'entailment' : 0 , 'not_entailment' : 1 }
84
87
valid_data = df [df ['label' ].isin (label_map .keys ())]
85
88
86
89
questions = valid_data ['question' ].tolist ()
87
90
sentences = valid_data ['sentence' ].tolist ()
88
91
labels = [label_map [label ] for label in valid_data ['label' ]]
89
92
90
- # 预测和评估
91
93
predict_true = 0
92
94
for question , sentence , true_label in tqdm (zip (questions , sentences , labels ),
93
95
total = len (questions ),
@@ -96,7 +98,6 @@ def evaluate_model(model_type, data_path):
96
98
if pred_label == true_label :
97
99
predict_true += 1
98
100
99
- # 输出结果
100
101
accuracy = float (predict_true / len (questions ) * 100 )
101
102
print (f"测试集总样本数: { len (questions )} " )
102
103
print (f"预测正确的数量: { predict_true } " )
0 commit comments