
Commit c2371b1

#benchmark: add GLUE-QNLI benchmark, including a 10-model inference accuracy comparison (#1865)
1 parent 1a66f9f commit c2371b1

3 files changed: 169 additions & 0 deletions

benchmark/GLUE-QNLI/README.md

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
# GLUE-QNLI

A repository comparing the inference accuracy of MindNLP and Hugging Face Transformers on the GLUE QNLI dataset.

## Dataset

The QNLI (Question Natural Language Inference) dataset is part of the GLUE benchmark. It is converted from the Stanford Question Answering Dataset (SQuAD).

### Getting the Dataset

1. Visit [GLUE Benchmark Tasks](https://gluebenchmark.com/tasks/)
2. Register/Login to download the GLUE data
3. Download and extract the QNLI dataset
4. Place the following files in the `mindnlp/benchmark/GLUE-QNLI/` directory:
   - dev.tsv (Development set)
   - test.tsv (Test set)
   - train.tsv (Training set)

The QNLI task is a binary classification task derived from SQuAD, where the goal is to determine whether a given context sentence contains the answer to a given question.
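Each row of the TSV files carries an index, a question, a context sentence, and a label of either `entailment` or `not_entailment` (these are the columns `model_QNLI.py` reads). The snippet below is a minimal sketch for inspecting the development set, assuming it sits at the script's default path `./QNLI/dev.tsv`; the row shown in the comment is a hypothetical example.

```python
import pandas as pd

# Load the development set the same way model_QNLI.py does.
df = pd.read_csv("./QNLI/dev.tsv", sep="\t", header=0,
                 names=["idx", "question", "sentence", "label"])

# A row looks roughly like this (hypothetical example):
#   idx: 0
#   question: "What is the Grotto at Notre Dame?"
#   sentence: "Immediately behind the basilica is the Grotto, a place of prayer and reflection."
#   label: "entailment"
print(df.head())
print(df["label"].value_counts())
```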
## Quick Start

### Installation

To get started with this project, follow these steps:

1. **Create a conda environment (optional but recommended):**

   ```bash
   conda create -n mindnlp python==3.9
   conda activate mindnlp
   ```

2. **Install the dependencies:**

   Note that MindNLP runs in the Ascend (NPU) environment, while Transformers runs in the GPU environment; the dependencies for each are listed in the requirements file of its respective folder.

   ```bash
   pip install -r requirements.txt
   ```

3. **Usage**

   Once the installation is complete, you can choose from the supported models and run inference. Here's how:

   ```bash
   # Evaluate a specific model using the default dataset (dev.tsv)
   python model_QNLI.py --model albert

   # Evaluate with a custom dataset
   python model_QNLI.py --model bert --data ./QNLI/dev.tsv
   ```

   Supported model options: `albert`, `bert`, `roberta`, `xlm-roberta`, `distilbert`, `t5`, `gpt2`, `llama`, `opt`, `bart`. A programmatic alternative to the CLI is sketched below.
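The same evaluation can also be driven from Python. A minimal sketch, assuming `model_QNLI.py` is importable from the current directory and the dataset is at the script's default path:

```python
from model_QNLI import MODEL_CONFIGS, evaluate_model

# Run the QNLI evaluation for one model; returns accuracy as a percentage.
acc = evaluate_model("bert", "./QNLI/dev.tsv")
print(f"bert: {acc:.2f}%")

# Or sweep every supported model (slow: one full pass over dev.tsv per model).
# for name in MODEL_CONFIGS:
#     evaluate_model(name, "./QNLI/dev.tsv")
```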
## Accuracy Comparison

| Model Name | bart | bert | roberta | xlm-roberta | gpt2 | t5 | distilbert | albert | opt | llama |
|---|---|---|---|---|---|---|---|---|---|---|
| Base Model | facebook/bart-base | google-bert/bert-base-uncased | FacebookAI/roberta-large | FacebookAI/xlm-roberta-large | openai-community/gpt2 | google-t5/t5-small | distilbert/distilbert-base-uncased | albert/albert-base-v2 | facebook/opt-125m | JackFram/llama-160m |
| Fine-tuned Model (HF) | ModelTC/bart-base-qnli | Li/bert-base-uncased-qnli | howey/roberta-large-qnli | tmnam20/xlm-roberta-large-qnli-1 | tanganke/gpt2_qnli | lightsout19/t5-small-qnli | anirudh21/distilbert-base-uncased-finetuned-qnli | orafandina/albert-base-v2-finetuned-qnli | utahnlp/qnli_facebook_opt-125m_seed-1 | Cheng98/llama-160m-qnli |
| Transformers accuracy (GPU) | 92.29 | 67.43 | 94.50 | 92.50 | 88.15 | 89.71 | 59.21 | 55.14 | 86.10 | 50.97 |
| MindNLP accuracy (NPU) | 92.29 | 67.43 | 94.51 | 92.50 | 88.15 | 89.71 | 59.23 | 55.13 | 86.10 | 50.97 |

benchmark/GLUE-QNLI/model_QNLI.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
import argparse

import pandas as pd
from tqdm import tqdm

from mindnlp.transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification
)

# Fine-tuned QNLI checkpoint and matching base tokenizer for each supported model type.
MODEL_CONFIGS = {
    "albert": {
        "model_name": "orafandina/albert-base-v2-finetuned-qnli",
        "tokenizer_name": "albert/albert-base-v2"
    },
    "bert": {
        "model_name": "Li/bert-base-uncased-qnli",
        "tokenizer_name": "google-bert/bert-base-uncased"
    },
    "roberta": {
        "model_name": "howey/roberta-large-qnli",
        "tokenizer_name": "FacebookAI/roberta-large"
    },
    "xlm-roberta": {
        "model_name": "tmnam20/xlm-roberta-large-qnli-1",
        "tokenizer_name": "FacebookAI/xlm-roberta-large"
    },
    "distilbert": {
        "model_name": "anirudh21/distilbert-base-uncased-finetuned-qnli",
        "tokenizer_name": "distilbert/distilbert-base-uncased"
    },
    "t5": {
        "model_name": "lightsout19/t5-small-qnli",
        "tokenizer_name": "google-t5/t5-small"
    },
    "gpt2": {
        "model_name": "tanganke/gpt2_qnli",
        "tokenizer_name": "openai-community/gpt2"
    },
    "llama": {
        "model_name": "Cheng98/llama-160m-qnli",
        "tokenizer_name": "JackFram/llama-160m"
    },
"opt": {
43+
"model_name": "facebook/opt-125m",
44+
"tokenizer_name": "utahnlp/qnli_facebook_opt-125m_seed-1"
45+
},
46+
"bart": {
47+
"model_name": "facebook/bart-large-qnli",
48+
"tokenizer_name": "ModelTC/bart-base-qnli"
49+
}
50+
}

def get_model_and_tokenizer(model_type):
    """Load the fine-tuned model and matching tokenizer for the given model type."""
    if model_type not in MODEL_CONFIGS:
        raise ValueError(f"Unsupported model type: {model_type}")

    config = MODEL_CONFIGS[model_type]
    tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_name"])
    model = AutoModelForSequenceClassification.from_pretrained(config["model_name"], num_labels=2)

    return model, tokenizer

def predict_qnli(model, tokenizer, question, sentence):
    """Predict the QNLI label for one question/sentence pair (0 = entailment, 1 = not_entailment)."""
    inputs = tokenizer(question, sentence, return_tensors="ms", truncation=True, max_length=512)
    outputs = model(**inputs)
    logits = outputs.logits
    return logits.argmax(axis=1).asnumpy()[0]

def evaluate_model(model_type, data_path):
    """Evaluate a model on the QNLI dataset and return its accuracy (in percent)."""
    print(f"Evaluating model: {model_type}")

    # Load model and tokenizer
    model, tokenizer = get_model_and_tokenizer(model_type)
    print(f"Model type: {model.config.model_type}")

    # Load the data
    df = pd.read_csv(data_path, sep='\t', header=0, names=['idx', 'question', 'sentence', 'label'])
    df = df.dropna(subset=['label'])

    # Map string labels to integer ids
    label_map = {'entailment': 0, 'not_entailment': 1}
    valid_data = df[df['label'].isin(label_map.keys())]

    questions = valid_data['question'].tolist()
    sentences = valid_data['sentence'].tolist()
    labels = [label_map[label] for label in valid_data['label']]

    # Predict and count correct labels
    predict_true = 0
    for question, sentence, true_label in tqdm(zip(questions, sentences, labels),
                                               total=len(questions),
                                               desc="Predicting"):
        pred_label = predict_qnli(model, tokenizer, question, sentence)
        if pred_label == true_label:
            predict_true += 1

    # Report the results
    accuracy = float(predict_true / len(questions) * 100)
    print(f"Total evaluation samples: {len(questions)}")
    print(f"Correct predictions: {predict_true}")
    print(f"Accuracy: {accuracy:.2f}%")

    return accuracy

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='QNLI task evaluation script')
    parser.add_argument('--model', type=str, required=True,
                        choices=list(MODEL_CONFIGS.keys()),
                        help='Model type to evaluate')
    parser.add_argument('--data', type=str, default='./QNLI/dev.tsv',
                        help='Path to the dataset')

    args = parser.parse_args()
    evaluate_model(args.model, args.data)

benchmark/GLUE-QNLI/requirements.txt

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
mindspore==2.3.1
mindnlp==0.4.1
tqdm
pandas
numpy==1.26.4
