Skip to content

Commit 77e97d4

Browse files
authored
#fix benchmark GLUE-QNLI fix read_csv error and predict funciton and modify readme description (#1868)
1 parent c2371b1 commit 77e97d4

File tree

2 files changed

+30
-25
lines changed

2 files changed

+30
-25
lines changed

benchmark/GLUE-QNLI/README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,18 @@ Please note that mindnlp is in the Ascend environment, while transformers is in
3232
Once the installation is complete, you can choose use differnet models to start inference. Here's how to run the inference:
3333
```bash
3434
# Evaluate specific model using default dataset (dev.tsv)
35-
python model_QNLI.py --model albert
35+
python model_QNLI.py --model bart
3636
3737
# Evaluate with custom dataset
38-
python model_QNLI.py --model bert --data ./QNLI/dev.tsv
38+
python model_QNLI.py --model bart --data ./QNLI/test.tsv
3939
```
40-
Supported model options: `albert`, `bert`, `roberta`, `xlm-roberta`, `distilbert`, `t5`, `gpt2`, `llama`, `opt`, `bart`
40+
Supported model options: `bart`, `bert`, `roberta`, `xlm-roberta`, `gpt2`, `t5`, `distilbert`, `albert`, `llama`, `opt`
4141
4242
## Accuracy Comparsion
43+
Our reproduced model performance on QNLI/dev.tsv is reported as follows.
44+
Experiments are tested on ascend 910* with mindspore 2.3.1 graph mode.
45+
All fine-tuned models are derived from open-source models provided by huggingface.
46+
4347
| Model Name | bart | bert | roberta | xlm-roberta | gpt2 | t5 | distilbert | albert | opt | llama |
4448
|---|---|---|---|---|---|---|---|---|---|---|
4549
| Base Model | facebook/bart-base | google-bert/bert-base-uncased | FacebookAI/roberta-large | FacebookAI/xlm-roberta-large | openai-community/gpt2 | google-t5/t5-small | distilbert/distilbert-base-uncased | albert/albert-base-v2 | facebook/opt-125m | JackFram/llama-160m |

benchmark/GLUE-QNLI/model_QNLI.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import argparse
88

99
MODEL_CONFIGS = {
10-
"albert": {
11-
"model_name": "orafandina/albert-base-v2-finetuned-qnli",
12-
"tokenizer_name": "albert/albert-base-v2"
10+
"bart": {
11+
"model_name": "ModelTC/bart-base-qnli",
12+
"tokenizer_name": "facebook/bart-base"
1313
},
1414
"bert": {
1515
"model_name": "Li/bert-base-uncased-qnli",
@@ -23,30 +23,34 @@
2323
"model_name": "tmnam20/xlm-roberta-large-qnli-1",
2424
"tokenizer_name": "FacebookAI/xlm-roberta-large"
2525
},
26-
"distilbert": {
27-
"model_name": "anirudh21/distilbert-base-uncased-finetuned-qnli",
28-
"tokenizer_name": "distilbert/distilbert-base-uncased"
26+
"gpt2": {
27+
"model_name": "tanganke/gpt2_qnli",
28+
"tokenizer_name": "openai-community/gpt2"
2929
},
3030
"t5": {
3131
"model_name": "lightsout19/t5-small-qnli",
3232
"tokenizer_name": "google-t5/t5-small"
3333
},
34-
"gpt2": {
35-
"model_name": "tanganke/gpt2_qnli",
36-
"tokenizer_name": "openai-community/gpt2"
34+
"distilbert": {
35+
"model_name": "anirudh21/distilbert-base-uncased-finetuned-qnli",
36+
"tokenizer_name": "distilbert/distilbert-base-uncased"
3737
},
3838
"llama": {
3939
"model_name": "Cheng98/llama-160m-qnli",
4040
"tokenizer_name": "JackFram/llama-160m"
4141
},
42+
"albert": {
43+
"model_name": "orafandina/albert-base-v2-finetuned-qnli",
44+
"tokenizer_name": "albert/albert-base-v2"
45+
},
4246
"opt": {
43-
"model_name": "facebook/opt-125m",
44-
"tokenizer_name": "utahnlp/qnli_facebook_opt-125m_seed-1"
47+
"model_name": "utahnlp/qnli_facebook_opt-125m_seed-1",
48+
"tokenizer_name": "facebook/opt-125m"
49+
},
50+
"llama": {
51+
"model_name": "Cheng98/llama-160m-qnli",
52+
"tokenizer_name": "JackFram/llama-160m"
4553
},
46-
"bart": {
47-
"model_name": "facebook/bart-large-qnli",
48-
"tokenizer_name": "ModelTC/bart-base-qnli"
49-
}
5054
}
5155

5256
def get_model_and_tokenizer(model_type):
@@ -63,31 +67,29 @@ def get_model_and_tokenizer(model_type):
6367
def predict_qnli(model, tokenizer, question, sentence):
6468
"""预测QNLI任务"""
6569
inputs = tokenizer(question, sentence, return_tensors="ms", truncation=True, max_length=512)
66-
outputs = model(**inputs)
70+
input_ids = inputs["input_ids"]
71+
attention_mask = inputs["attention_mask"]
72+
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
6773
logits = outputs.logits
6874
return logits.argmax(axis=1).asnumpy()[0]
6975

7076
def evaluate_model(model_type, data_path):
7177
"""评估模型在QNLI数据集上的表现"""
7278
print(f"正在评估模型: {model_type}")
7379

74-
# 加载模型和分词器
7580
model, tokenizer = get_model_and_tokenizer(model_type)
7681
print(f"模型类型: {model.config.model_type}")
7782

78-
# 加载数据
79-
df = pd.read_csv(data_path, sep='\t', header=0, names=['idx', 'question', 'sentence', 'label'])
83+
df = pd.read_csv(data_path, sep='\t', header=0, names=['idx', 'question', 'sentence', 'label'], on_bad_lines='skip')
8084
df = df.dropna(subset=['label'])
8185

82-
# 标签映射
8386
label_map = {'entailment': 0, 'not_entailment': 1}
8487
valid_data = df[df['label'].isin(label_map.keys())]
8588

8689
questions = valid_data['question'].tolist()
8790
sentences = valid_data['sentence'].tolist()
8891
labels = [label_map[label] for label in valid_data['label']]
8992

90-
# 预测和评估
9193
predict_true = 0
9294
for question, sentence, true_label in tqdm(zip(questions, sentences, labels),
9395
total=len(questions),
@@ -96,7 +98,6 @@ def evaluate_model(model_type, data_path):
9698
if pred_label == true_label:
9799
predict_true += 1
98100

99-
# 输出结果
100101
accuracy = float(predict_true / len(questions) * 100)
101102
print(f"测试集总样本数: {len(questions)}")
102103
print(f"预测正确的数量: {predict_true}")

0 commit comments

Comments
 (0)