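"""Distributed DriveLM evaluation for InternVL chat models.

Loads a JSONL validation split of stitched multi-camera driving scenes, runs
generation on each rank's shard, remaps predicted object coordinates back to
per-camera pixel space, and writes the merged predictions to a JSON file.

Typical launch (an assumption; any launcher that sets RANK, WORLD_SIZE and
LOCAL_RANK works):
    torchrun --nproc_per_node=<num_gpus> <this_script>.py \
        --checkpoint <model_path> --datasets DriveLM_val --dynamic
"""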
import argparse
import itertools
import json
import os
import random
import re
import time
from functools import partial

import torch
from internvl.model.internvl_chat import InternVLChatModel
from internvl.train.dataset import build_transform, dynamic_preprocess
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer

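# Per-dataset config: annotation JSONL path, generation length limits, and
# the root directory of the stitched camera images.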
ds_collections = {
    'DriveLM_val': {
        'root': 'InternVL-Domain-Adaptation-Data/val/drivelm_val.jsonl',
        'max_new_tokens': 200,
        'min_new_tokens': 1,
        'split': 'validation',
        'image_root': 'InternVL-Domain-Adaptation-Data/images/drivelm/stitch',
    }
}


def post_process(pred):
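    """Remap object references in ``pred`` back to per-camera coordinates.

    Predictions reference objects as ``<c1,CAM_FRONT,[cx,cy]>`` with cx/cy
    normalized to 0-1000 over a 3x2 stitch of the six camera views. Each
    match is rewritten as ``<tag,cam,cx,cy>`` in the pixel space of the
    original camera frame, assumed below to be 1600x900 (the nuScenes
    camera resolution).
    """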
    pred = pred.strip()
    pattern = r"<c[^,]*,\s*[^,]*,\s*\[\s*-?[0-9]*\.?[0-9]+\s*,\s*-?[0-9]*\.?[0-9]+\s*\]\s*>"
    # (column, row) of each camera view inside the 3x2 stitched image.
    mapping = {
        'CAM_FRONT_LEFT': [0, 0], 'CAM_FRONT': [1, 0], 'CAM_FRONT_RIGHT': [2, 0],
        'CAM_BACK_LEFT': [0, 1], 'CAM_BACK': [1, 1], 'CAM_BACK_RIGHT': [2, 1],
    }
    patch_size = 448
    width = patch_size * 2  # width of one camera tile in the stitch
    height = patch_size  # height of one camera tile in the stitch
    whole_img_width = width * 3
    whole_img_height = height * 2
    matches = re.findall(pattern, pred)
    for object_id in matches:
        object_id_c = object_id.replace('<', '').replace('>', '')
        try:
            parts = object_id_c.split(',')
            # Strip whitespace so padded references like `<c1, CAM_FRONT, ...>`
            # still match the camera mapping.
            ctag = parts[0].strip()
            cam = parts[1].strip()
            cxcy = json.loads(','.join(parts[2:]))
            if cam in mapping:
                mx, my = mapping[cam]
                old_width, old_height = 1600, 900
                cx, cy = cxcy
                # Normalized (0-1000) stitched coords -> stitched pixels.
                cx = (cx / 1000) * whole_img_width
                cy = (cy / 1000) * whole_img_height
                # Stitched pixels -> tile-local pixels.
                cx -= mx * width
                cy -= my * height
                # Tile-local pixels -> original camera-frame pixels, clamped.
                cx = round(max(0, min(old_width, cx / width * old_width)), 1)
                cy = round(max(0, min(old_height, cy / height * old_height)), 1)
                new_object_id = f'<{ctag},{cam},{cx},{cy}>'
                pred = pred.replace(object_id, new_object_id)
        except Exception as e:
            print(e)
    return pred


def collate_fn(batches, tokenizer):
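    """Stack image tensors and gather per-sample fields into parallel lists.

    ``tokenizer`` is unused but kept so the ``partial(collate_fn,
    tokenizer=tokenizer)`` wiring below stays unchanged.
    """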
    pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0)
    questions = [_['question'] for _ in batches]
    questions_old = [_['question_old'] for _ in batches]
    answers = [_['answer'] for _ in batches]
    data_ids = [_['data_id'] for _ in batches]
    return pixel_values, questions_old, questions, answers, data_ids


class DriveLMDataset(torch.utils.data.Dataset):
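    """Dataset over a DriveLM JSONL file of stitched multi-camera samples.

    Each line holds an ``id``, an ``image`` path relative to ``image_path``,
    the original question (``question_old``), and a two-turn
    ``conversations`` list with the prompt and the ground-truth answer.
    """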

    def __init__(self, root, split, prompt, image_path, input_size=224,
                 dynamic_image_size=False, use_thumbnail=False, max_num=6):
        with open(root, 'r') as f:
            self.data = [json.loads(line) for line in f]
        self.prompt = prompt
        self.input_size = input_size
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.max_num = max_num
        self.transform = build_transform(is_train=False, input_size=input_size)
        self.image_path = image_path

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        data_id = data['id']
        question = data['conversations'][0]['value'].strip()
        question_old = data['question_old']
        image_file = os.path.join(self.image_path, data['image'])
        image = Image.open(image_file).convert('RGB')
        answer = data['conversations'][1]['value'].strip()

        if self.dynamic_image_size:
            # Tile the stitched image into up to max_num sub-images.
            images = dynamic_preprocess(image, image_size=self.input_size,
                                        use_thumbnail=self.use_thumbnail,
                                        max_num=self.max_num)
        else:
            images = [image]
        pixel_values = [self.transform(img) for img in images]
        pixel_values = torch.stack(pixel_values)

        return {
            'question_old': question_old,
            'question': question,
            'pixel_values': pixel_values,
            'answer': answer,
            'data_id': data_id,
        }


class InferenceSampler(torch.utils.data.sampler.Sampler):
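    """Shard dataset indices contiguously across distributed ranks.

    Each rank iterates only over its own slice, so a distributed run
    evaluates every sample exactly once overall.
    """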

    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        # Give one extra sample to the first `left` ranks so every index is covered.
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[:rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)


def evaluate_chat_model():
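    """Evaluate every dataset in ``args.datasets`` and save merged outputs.

    Uses the module-level ``args``, ``model``, ``tokenizer``, ``image_size``,
    and ``use_thumbnail`` defined in the ``__main__`` block.
    """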
    random.seed(args.seed)
    prompt = None
    for ds_name in args.datasets:
        dataset = DriveLMDataset(
            root=ds_collections[ds_name]['root'],
            split=ds_collections[ds_name]['split'],
            prompt=prompt,
            image_path=ds_collections[ds_name]['image_root'],
            input_size=image_size,
            dynamic_image_size=args.dynamic,
            use_thumbnail=use_thumbnail,
            max_num=args.max_num,
        )
        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            sampler=InferenceSampler(len(dataset)),
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
            collate_fn=partial(collate_fn, tokenizer=tokenizer),
        )

        # Generation settings are constant per dataset, so build them once.
        generation_config = dict(
            num_beams=args.num_beams,
            max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
            min_new_tokens=ds_collections[ds_name]['min_new_tokens'],
            do_sample=args.temperature > 0,
            temperature=args.temperature,
        )

        outputs = []
        for pixel_values, questions_old, questions, answers, data_ids in tqdm(dataloader):
            pixel_values = pixel_values.to(torch.bfloat16).cuda()
            # batch_size == 1 is asserted in __main__, so questions[0] is the
            # entire batch.
            pred = model.chat(
                tokenizer=tokenizer,
                pixel_values=pixel_values,
                question=questions[0],
                generation_config=generation_config,
            )
            preds = [post_process(pred)]

            for question, pred, answer, data_id, question_old in zip(
                    questions, preds, answers, data_ids, questions_old):
                outputs.append({
                    'question': question_old,
                    'answer': pred,
                    'gt_answers': answer,
                    'id': data_id,
                })

        torch.distributed.barrier()

        world_size = torch.distributed.get_world_size()
        merged_outputs = [None for _ in range(world_size)]
        # Serialize to JSON strings so all_gather_object moves plain text.
        torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))

        merged_outputs = [json.loads(_) for _ in merged_outputs]
        merged_outputs = list(itertools.chain.from_iterable(merged_outputs))

        if torch.distributed.get_rank() == 0:
            print(f'Evaluating {ds_name} ...')
            time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
            results_file = f'{ds_name}_{time_prefix}.json'
            output_path = os.path.join(args.out_dir, results_file)

            with open(output_path, 'w') as f:
                json.dump(merged_outputs, f, indent=4)
            print(f'Results saved to {output_path}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='')
    # Comma-separated keys into ds_collections.
    parser.add_argument('--datasets', type=str, default='DriveLM_val')
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--num-workers', type=int, default=1)
    parser.add_argument('--num-beams', type=int, default=5)
    parser.add_argument('--temperature', type=float, default=0.0)
    parser.add_argument('--out-dir', type=str, default='results')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--dynamic', action='store_true')
    parser.add_argument('--max-num', type=int, default=12)
    parser.add_argument('--load-in-8bit', action='store_true')
    parser.add_argument('--auto', action='store_true')
    args = parser.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    args.datasets = args.datasets.split(',')
    print('datasets:', args.datasets)
    assert args.batch_size == 1, 'Only batch size 1 is supported'

    torch.distributed.init_process_group(
        backend='nccl',
        world_size=int(os.getenv('WORLD_SIZE', '1')),
        rank=int(os.getenv('RANK', '0')),
    )

    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0')))

    if args.auto:
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    kwargs = {'device_map': 'auto'} if args.auto else {}
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
    model = InternVLChatModel.from_pretrained(
        args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
        load_in_8bit=args.load_in_8bit, **kwargs).eval()
    if not args.load_in_8bit and not args.auto:
        model = model.cuda()
    image_size = model.config.force_image_size or model.config.vision_config.image_size
    use_thumbnail = model.config.use_thumbnail

    total_params = sum(p.numel() for p in model.parameters()) / 1e9
    if total_params > 20 or args.dynamic:
        # Large models and dynamic tiling fall back to greedy decoding,
        # presumably to bound memory use.
        args.num_beams = 1
        print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
    else:
        print(f'[test] total_params: {total_params}B')
    print(f'[test] image_size: {image_size}')
    print(f'[test] template: {model.config.template}')
    print(f'[test] dynamic_image_size: {args.dynamic}')
    print(f'[test] use_thumbnail: {use_thumbnail}')
    print(f'[test] max_num: {args.max_num}')

    evaluate_chat_model()