import argparse
import json
import os
import re
import sys

import pandas as pd
# !pip install python-Levenshtein
from Levenshtein import distance
from utilities import *


def get_most_similar(prediction, choices):
    """
    Use the Levenshtein distance (or edit distance) to determine which of the choices is most similar to the given prediction.
    """
    distances = [distance(prediction, choice) for choice in choices]
    ind = distances.index(min(distances))
    return choices[ind]
    # return min(choices, key=lambda choice: distance(prediction, choice))
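# Illustrative behavior of get_most_similar (inputs here are made up, not from the original script):
#   get_most_similar('aple', ['apple', 'orange'])  -> 'apple'  (edit distance 1 vs. 4)
#   get_most_similar('B', ['A', 'B', 'C'])         -> 'B'      (exact match, distance 0)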


def normalize_extracted_answer(extraction, choices, question_type, answer_type, precision):
    """
    Normalize the extracted answer to match the answer type.
    """
    if question_type == 'multi_choice':
        # make sure the extraction is a string
        if isinstance(extraction, str):
            extraction = extraction.strip()
        else:
            try:
                extraction = str(extraction)
            except Exception:
                extraction = ''

        # extract "A" from "(A) text"
        letter = re.findall(r'\(([a-zA-Z])\)', extraction)
        if len(letter) > 0:
            extraction = letter[0].upper()

        # candidate option letters: "A", "B", ... for as many choices as there are
        options = [chr(ord('A') + i) for i in range(len(choices))]

        if extraction in options:
            # convert the option letter to its text, e.g. "A" -> "text"
            ind = options.index(extraction)
            extraction = choices[ind]
        else:
            # otherwise, fall back to the most similar option
            extraction = get_most_similar(extraction, choices)
        assert extraction in choices

    elif answer_type == 'integer':
        try:
            extraction = str(int(float(extraction)))
        except Exception:
            extraction = None

    elif answer_type == 'float':
        try:
            extraction = str(round(float(extraction), precision))
        except Exception:
            extraction = None

    elif answer_type == 'list':
        try:
            extraction = str(extraction)
        except Exception:
            extraction = None

    return extraction
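# Illustrative behavior of normalize_extracted_answer (made-up inputs; 'free_form' here just means "not multi_choice"):
#   normalize_extracted_answer('(B) 5', ['3', '5', '7'], 'multi_choice', 'text', None)  -> '5'
#   normalize_extracted_answer('7.0', [], 'free_form', 'integer', None)                 -> '7'
#   normalize_extracted_answer('3.14159', [], 'free_form', 'float', 2)                  -> '3.14'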


def safe_equal(prediction, answer):
    """
    Check if the prediction is equal to the answer, even if they are of different types.
    """
    try:
        if prediction == answer:
            return True
        return False
    except Exception as e:
        print(e)
        return False
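# Illustrative behavior of safe_equal (made-up values):
#   safe_equal('5', '5')  -> True
#   safe_equal('5', 5)    -> False   (a type mismatch compares unequal rather than raising)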


def get_acc_with_condition(res_pd, key, value):
    """
    Count how many problems match the given key/value condition and how many of those were answered correctly.
    """
    if key == 'skills':
        # the 'skills' field holds a list, so test membership instead of equality
        total_pd = res_pd[res_pd[key].apply(lambda x: value in x)]
    else:
        total_pd = res_pd[res_pd[key] == value]

    if len(total_pd) == 0:
        return 0, 0, '0.00'

    correct_pd = total_pd[total_pd['true_false'] == True]  # noqa: E712
    acc = '{:.2f}'.format(len(correct_pd) / len(total_pd) * 100)
    return len(correct_pd), len(total_pd), acc
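# A minimal sketch of how get_acc_with_condition behaves (the toy frame below is made up, not real data):
#   toy = pd.DataFrame({'task': ['algebra', 'geometry'],
#                       'skills': [['logic'], ['logic', 'algebra']],
#                       'true_false': [True, False]})
#   get_acc_with_condition(toy, 'task', 'algebra')  -> (1, 1, '100.00')
#   get_acc_with_condition(toy, 'skills', 'logic')  -> (1, 2, '50.00')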


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, default='./results')
    parser.add_argument('--output_file', type=str, default='output.json')
    parser.add_argument('--score_file', type=str, default='scores.json')
    parser.add_argument('--gt_file', type=str, default='./data/MathVista/annot_testmini.json', help='ground truth file')
    parser.add_argument('--number', type=int, default=-1, help='number of problems to run')
    parser.add_argument('--rerun', action='store_true', help='rerun the evaluation')
    parser.add_argument('--calculate_gain', action='store_true', help='calculate the score gains over random guess')
    parser.add_argument('--random_file', type=str, default='score_random_guess.json')
    args = parser.parse_args()
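    # Example invocation (the script and directory names below are placeholders, not taken from this file):
    #   python calculate_score.py --output_dir ./results/my_model --output_file output.json --score_file scores.json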

    # args
    output_file = os.path.join(args.output_dir, args.output_file)

    # # quick test
    # output_file = '../results/llava-llama-2-13b/output_llava_llama_2_13b.json'

    # read json
    print(f'Reading {output_file}...')
    results = read_json(output_file)

    # read ground truth
    print(f'Reading {args.gt_file}...')
    gts = read_json(args.gt_file)

    # full pids
    full_pids = list(results.keys())
    if args.number > 0:
        full_pids = full_pids[:min(args.number, len(full_pids))]
    print('Number of testing problems:', len(full_pids))

    ## [1] Evaluate if the prediction is true or false
    print('\nEvaluating the predictions...')
    update_json_flag = False
    for pid in full_pids:
        problem = results[pid]
        # print(problem)

        if args.rerun:
            if 'prediction' in problem:
                del problem['prediction']
            if 'true_false' in problem:
                del problem['true_false']

        choices = problem['choices']
        question_type = problem['question_type']
        answer_type = problem['answer_type']
        precision = problem['precision']
        extraction = problem['extraction']

        if 'answer' in problem:
            answer = problem['answer']
        else:
            answer = gts[pid]['answer']
            problem['answer'] = answer

        # normalize the extracted answer to match the answer type
        prediction = normalize_extracted_answer(extraction, choices, question_type, answer_type, precision)

        # verify the prediction is true or false
        true_false = safe_equal(prediction, answer)

        # update the problem
        if 'true_false' not in problem:
            update_json_flag = True

        elif true_false != problem['true_false']:
            update_json_flag = True

        if 'prediction' not in problem:
            update_json_flag = True

        elif prediction != problem['prediction']:
            update_json_flag = True

        problem['prediction'] = prediction
        problem['true_false'] = true_false

    # save the updated json
    if update_json_flag:
        print('\n!!!Some problems are updated.!!!')
        print(f'\nSaving {output_file}...')
        save_json(results, output_file)

    ## [2] Calculate the average accuracy
    total = len(full_pids)
    correct = 0
    for pid in full_pids:
        if results[pid]['true_false']:
            correct += 1
    accuracy = str(round(correct / total * 100, 2))
    print(f'\nCorrect: {correct}, Total: {total}, Accuracy: {accuracy}%')

    scores = {'average': {'accuracy': accuracy, 'correct': correct, 'total': total}}
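    # Rough shape of the scores dict once the fine-grained section below has run
    # (illustrative keys and numbers only, not real results):
    # {
    #   "average": {"accuracy": "52.4", "correct": 524, "total": 1000},
    #   "task": {"geometry problem solving": {"accuracy": "45.00", "correct": 90, "total": 200}, ...},
    #   ...
    # }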

    ## [3] Calculate the fine-grained accuracy scores

    # merge the 'metadata' attribute into the data
    for pid in results:
        results[pid].update(results[pid].pop('metadata'))

    # convert the data to a pandas DataFrame
    df = pd.DataFrame(results).T

    print('Number of test problems:', len(df))
    # assert len(df) == 1000 # Important!!!

    # assign the target keys for evaluation
    target_keys = ['question_type', 'answer_type', 'language', 'source', 'category', 'task', 'context', 'grade',
                   'skills']

    for key in target_keys:
        print(f'\nType: [{key}]')
        # get the unique values of the key
        if key == 'skills':
            # each 'skills' entry is a list, so collect the unique skill names across all rows
            values = []
            for skill_list in df[key]:
                values += skill_list
            values = list(set(values))
        else:
            values = df[key].unique()
        # print(values)

        # calculate the accuracy for each value
        scores[key] = {}
        for value in values:
            correct, total, acc = get_acc_with_condition(df, key, value)
            if total > 0:
                print(f'[{value}]: {acc}% ({correct}/{total})')
                scores[key][value] = {'accuracy': acc, 'correct': correct, 'total': total}

        # sort the scores by accuracy
        scores[key] = dict(sorted(scores[key].items(), key=lambda item: float(item[1]['accuracy']), reverse=True))

    # save the scores
    scores_file = os.path.join(args.output_dir, args.score_file)
    print(f'\nSaving {scores_file}...')
    save_json(scores, scores_file)
    print('\nDone!')

    ## [4] Calculate the score gains over random guess
    if args.calculate_gain:
        random_file = os.path.join(args.output_dir, args.random_file)
        with open(random_file) as f:
            random_scores = json.load(f)

        print('\nCalculating the score gains...')
        for key in scores:
            if key == 'average':
                gain = round(float(scores[key]['accuracy']) - float(random_scores[key]['accuracy']), 2)
                scores[key]['acc_gain'] = gain
            else:
                for sub_key in scores[key]:
                    gain = round(
                        float(scores[key][sub_key]['accuracy']) - float(random_scores[key][sub_key]['accuracy']), 2)
                    scores[key][sub_key]['acc_gain'] = str(gain)

        # save the score gains
        print(f'\nSaving {scores_file}...')
        save_json(scores, scores_file)
        print('\nDone!')