
Commit ac6e5c9

Support MathVista
1 parent 8fff25b commit ac6e5c9

File tree

7 files changed (+977, -66 lines)


internvl_chat/README.md

Lines changed: 35 additions & 3 deletions
@@ -110,9 +110,9 @@ Coming Soon

**MultiModal Benchmark**

-| model | MME | MMB<sub>dev/test</sub> | MMB-CN<sub>dev/test</sub> | POPE | MMVP |
-| --------------------------------------------------------------------------------- | -------------- | ---------------------- | ------------------------- | ---- | ---- |
-| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 1672.3 / 341.1 | 76.6 / 75.4 | 71.5 / 70.1 | 87.2 | 44.7 |
+| model | MME | MMB<sub>dev/test</sub> | MMB-CN<sub>dev/test</sub> | POPE | MMVP | MathVista |
+| --------------------------------------------------------------------------------- | -------------- | ---------------------- | ------------------------- | ---- | ---- | --------- |
+| [InternVL-Chat-V1.1](https://huggingface.co/OpenGVLab/InternVL-Chat-Chinese-V1-1) | 1672.3 / 341.1 | 76.6 / 75.4 | 71.5 / 70.1 | 87.2 | 44.7 | 34.5 |

| model | MMMU<sub>val/test</sub> | CMMMU<sub>val/test</sub> | Tiny<sub>LVLM</sub> | LLaVA<sub>bench</sub> | MM-Vet |
| --------------------------------------------------------------------------------- | ----------------------- | ------------------------ | ------------------- | --------------------- | ------ |
@@ -291,6 +291,9 @@ data
├── MMVP_VLM
│   ├── MLLM_VLM Images/
│   └── Questions.csv
+├── MathVista
+│   ├── annot_testmini.json
+│   └── AI4Math___math_vista/
```

</details>
@@ -1005,3 +1008,32 @@ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh evaluate.sh <checkpoint> mmvp
```

</details>
+
+#### [MathVista](https://github.com/lupantech/MathVista)
+
+<details>
+<summary>Data Preparation</summary>
+
+```bash
+mkdir -p data/MathVista && cd data/MathVista
+# Execute the following python code
+# from datasets import load_dataset
+# dataset = load_dataset("AI4Math/MathVista")
+# dataset.save_to_disk('./MathVista')
+wget https://huggingface.co/datasets/AI4Math/MathVista/raw/main/annot_testmini.json
+cd ../..
+```
+
+</details>
+
+<details>
+<summary>Evaluation</summary>
+
+```bash
+# testmini set
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh evaluate.sh <checkpoint> mathvista-testmini
+# test set
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 sh evaluate.sh <checkpoint> mathvista-test
+```
+
+</details>
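
The commented-out download step in the Data Preparation block above can also be run as a short standalone script. A minimal sketch, assuming it is executed from inside `data/MathVista` as the surrounding shell commands are:

```python
# Sketch of the commented-out step in the Data Preparation block above:
# download the AI4Math/MathVista dataset from the Hugging Face Hub and
# write it to disk next to annot_testmini.json.
from datasets import load_dataset

dataset = load_dataset("AI4Math/MathVista")  # fetches all available splits
dataset.save_to_disk("./MathVista")          # run from inside data/MathVista
```

Saving a local copy avoids re-downloading the dataset on every evaluation run.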
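
Before launching `evaluate.sh`, a quick check that the prepared files are in place can save a long failed run. A small sketch, assuming the directory layout shown in the data tree above; the scoring script added below reads `annot_testmini.json` as its ground-truth file, and the testmini split is expected to hold 1,000 problems:

```python
# Hypothetical sanity check for the MathVista data layout described above.
import json
import os

root = "data/MathVista"

# Ground-truth annotations used by the scoring script (keyed by problem id).
with open(os.path.join(root, "annot_testmini.json")) as f:
    annots = json.load(f)
print(f"annot_testmini.json: {len(annots)} problems")  # testmini should have 1,000

# Directory produced by the dataset download/save step above.
print("dataset directory present:",
      os.path.isdir(os.path.join(root, "AI4Math___math_vista")))
```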
Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
import argparse
import sys

import pandas as pd
# !pip install python-Levenshtein
from Levenshtein import distance
from utilities import *


def get_most_similar(prediction, choices):
    """
    Use the Levenshtein distance (or edit distance) to determine which of the choices is most similar to the given prediction
    """
    distances = [distance(prediction, choice) for choice in choices]
    ind = distances.index(min(distances))
    return choices[ind]
    # return min(choices, key=lambda choice: distance(prediction, choice))


def normalize_extracted_answer(extraction, choices, question_type, answer_type, precision):
    """
    Normalize the extracted answer to match the answer type
    """
    if question_type == 'multi_choice':
        # make sure the extraction is a string
        if isinstance(extraction, str):
            extraction = extraction.strip()
        else:
            try:
                extraction = str(extraction)
            except:
                extraction = ''

        # extract "A" from "(A) text"
        letter = re.findall(r'\(([a-zA-Z])\)', extraction)
        if len(letter) > 0:
            extraction = letter[0].upper()

        options = [chr(ord('A') + i) for i in range(len(choices))]

        if extraction in options:
            # convert option letter to text, e.g. "A" -> "text"
            ind = options.index(extraction)
            extraction = choices[ind]
        else:
            # select the most similar option
            extraction = get_most_similar(extraction, choices)
        assert extraction in choices

    elif answer_type == 'integer':
        try:
            extraction = str(int(float(extraction)))
        except:
            extraction = None

    elif answer_type == 'float':
        try:
            extraction = str(round(float(extraction), precision))
        except:
            extraction = None

    elif answer_type == 'list':
        try:
            extraction = str(extraction)
        except:
            extraction = None

    return extraction


def safe_equal(prediction, answer):
    """
    Check if the prediction is equal to the answer, even if they are of different types
    """
    try:
        if prediction == answer:
            return True
        return False
    except Exception as e:
        print(e)
        return False


def get_acc_with_contion(res_pd, key, value):
    if key == 'skills':
        # if value in res_pd[key]:
        total_pd = res_pd[res_pd[key].apply(lambda x: value in x)]
    else:
        total_pd = res_pd[res_pd[key] == value]

    correct_pd = total_pd[total_pd['true_false'] == True]  # noqa: E712
    acc = '{:.2f}'.format(len(correct_pd) / len(total_pd) * 100)
    return len(correct_pd), len(total_pd), acc


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, default='./results')
    parser.add_argument('--output_file', type=str, default='output.json')
    parser.add_argument('--score_file', type=str, default='scores.json')
    parser.add_argument('--gt_file', type=str, default='./data/MathVista/annot_testmini.json', help='ground truth file')
    parser.add_argument('--number', type=int, default=-1, help='number of problems to run')
    parser.add_argument('--rerun', action='store_true', help='rerun the evaluation')
    parser.add_argument('--caculate_gain', action='store_true', help='calculate the score gains over random guess')
    parser.add_argument('--random_file', type=str, default='score_random_guess.json')
    args = parser.parse_args()

    # args
    output_file = os.path.join(args.output_dir, args.output_file)

    # # quick test
    # output_file = '../results/llava-llama-2-13b/output_llava_llama_2_13b.json'

    # read json
    print(f'Reading {output_file}...')
    results = read_json(output_file)

    # read ground truth
    print(f'Reading {args.gt_file}...')
    gts = read_json(args.gt_file)

    # full pids
    full_pids = list(results.keys())
    if args.number > 0:
        full_pids = full_pids[:min(args.number, len(full_pids))]
    print('Number of testing problems:', len(full_pids))

    ## [1] Evaluate if the prediction is true or false
    print('\nEvaluating the predictions...')
    update_json_flag = False
    for pid in full_pids:
        problem = results[pid]
        # print(problem)

        if args.rerun:
            if 'prediction' in problem:
                del problem['prediction']
            if 'true_false' in problem:
                del problem['true_false']

        choices = problem['choices']
        question_type = problem['question_type']
        answer_type = problem['answer_type']
        precision = problem['precision']
        extraction = problem['extraction']

        if 'answer' in problem:
            answer = problem['answer']
        else:
            answer = gts[pid]['answer']
            problem['answer'] = answer

        # normalize the extracted answer to match the answer type
        prediction = normalize_extracted_answer(extraction, choices, question_type, answer_type, precision)

        # verify the prediction is true or false
        true_false = safe_equal(prediction, answer)

        # update the problem
        if 'true_false' not in problem:
            update_json_flag = True

        elif true_false != problem['true_false']:
            update_json_flag = True

        if 'prediction' not in problem:
            update_json_flag = True

        elif prediction != problem['prediction']:
            update_json_flag = True

        problem['prediction'] = prediction
        problem['true_false'] = true_false

    # save the updated json
    if update_json_flag:
        print('\n!!!Some problems are updated.!!!')
        print(f'\nSaving {output_file}...')
        save_json(results, output_file)

    ## [2] Calculate the average accuracy
    total = len(full_pids)
    correct = 0
    for pid in full_pids:
        if results[pid]['true_false']:
            correct += 1
    accuracy = str(round(correct / total * 100, 2))
    print(f'\nCorrect: {correct}, Total: {total}, Accuracy: {accuracy}%')

    scores = {'average': {'accuracy': accuracy, 'correct': correct, 'total': total}}

    ## [3] Calculate the fine-grained accuracy scores

    # merge the 'metadata' attribute into the data
    for pid in results:
        results[pid].update(results[pid].pop('metadata'))

    # convert the data to a pandas DataFrame
    df = pd.DataFrame(results).T

    print(len(df))
    print('Number of test problems:', len(df))
    # assert len(df) == 1000 # Important!!!

    # assign the target keys for evaluation
    target_keys = ['question_type', 'answer_type', 'language', 'source', 'category', 'task', 'context', 'grade',
                   'skills']

    for key in target_keys:
        print(f'\nType: [{key}]')
        # get the unique values of the key
        if key == 'skills':
            # the value is a list
            values = []
            for i in range(len(df)):
                values += df[key][i]
            values = list(set(values))
        else:
            values = df[key].unique()
        # print(values)

        # calculate the accuracy for each value
        scores[key] = {}
        for value in values:
            correct, total, acc = get_acc_with_contion(df, key, value)
            if total > 0:
                print(f'[{value}]: {acc}% ({correct}/{total})')
                scores[key][value] = {'accuracy': acc, 'correct': correct, 'total': total}

        # sort the scores by accuracy
        scores[key] = dict(sorted(scores[key].items(), key=lambda item: float(item[1]['accuracy']), reverse=True))

    # save the scores
    scores_file = os.path.join(args.output_dir, args.score_file)
    print(f'\nSaving {scores_file}...')
    save_json(scores, scores_file)
    print('\nDone!')

    # [4] Calculate the score gains over random guess
    if args.caculate_gain:
        random_file = os.path.join(args.output_dir, args.random_file)
        random_scores = json.load(open(random_file))

        print('\nCalculating the score gains...')
        for key in scores:
            if key == 'average':
                gain = round(float(scores[key]['accuracy']) - float(random_scores[key]['accuracy']), 2)
                scores[key]['acc_gain'] = gain
            else:
                for sub_key in scores[key]:
                    gain = round(
                        float(scores[key][sub_key]['accuracy']) - float(random_scores[key][sub_key]['accuracy']), 2)
                    scores[key][sub_key]['acc_gain'] = str(gain)

        # save the score gains
        print(f'\nSaving {scores_file}...')
        save_json(scores, scores_file)
        print('\nDone!')
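
To make the answer-normalization logic above concrete, here is a small usage sketch. The module name `calculate_score` is an assumption (the commit view does not show the new file's path), and the inputs are made-up examples of MathVista's `multi_choice` and `free_form` question types:

```python
# Hypothetical usage of the helpers defined in the scoring script above.
# The module name "calculate_score" is assumed; adjust the import to the
# file's actual location in the repository.
from calculate_score import get_most_similar, normalize_extracted_answer, safe_equal

choices = ['apple', 'banana', 'cherry']

# "(B) banana" is reduced to the option letter and mapped back to the choice text.
pred = normalize_extracted_answer('(B) banana', choices,
                                  question_type='multi_choice',
                                  answer_type='text', precision=None)
print(pred)                                # banana

# An extraction that is not a clean option letter falls back to the closest
# choice by Levenshtein distance.
print(get_most_similar('banan', choices))  # banana

# Free-form numeric answers are normalized to strings before comparison.
pred = normalize_extracted_answer('3.14159', [], question_type='free_form',
                                  answer_type='float', precision=2)
print(safe_equal(pred, '3.14'))            # True
```

In the script itself, `normalize_extracted_answer` is applied to each problem's `extraction` field and the result is compared against the ground-truth answer with `safe_equal` before the per-category accuracies are aggregated.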
