
Commit 966de12

G-z-wczczup authored and committed
add mini_internvl (#633)
* add mini-internvl
* Update README.md
* Update README.md
* Update internvl2_2b_internlm2_1_8b_dynamic_res_finetune_medical.sh
1 parent 0cda495 commit 966de12

21 files changed, +2601 -0 lines changed
Lines changed: 302 additions & 0 deletions
@@ -0,0 +1,302 @@
# Distributed evaluation of an InternVL chat model on the DriveLM validation split.
import argparse
import itertools
import json
import os
import random
import re
import time
from functools import partial

import torch
from datasets import concatenate_datasets, load_dataset
from internvl.model.internvl_chat import InternVLChatModel
from internvl.train.dataset import build_transform, dynamic_preprocess
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer

# Dataset registry: annotation file, stitched-image root and generation limits per benchmark.
ds_collections = {
    'DriveLM_val': {
        'root': 'InternVL-Domain-Adaptation-Data/val/drivelm_val.jsonl',
        'max_new_tokens': 200,
        'min_new_tokens': 1,
        'split': 'validation',
        'image_root': 'InternVL-Domain-Adaptation-Data/images/drivelm/stitch',
    }
}

def post_process(pred):
    """Rewrite <c..., CAM_*, [x, y]> tags from stitched-canvas coordinates back to
    pixel coordinates in the original camera image."""
    pred = pred.strip()
    pattern = r"<c[^,]*,\s*[^,]*,\s*\[\s*-?[0-9]*\.?[0-9]+\s*,\s*-?[0-9]*\.?[0-9]+\s*\]\s*>"
    # Position (column, row) of each camera tile on the 3x2 stitched canvas.
    mapping = {
        'CAM_FRONT_LEFT': [0, 0], 'CAM_FRONT': [1, 0], 'CAM_FRONT_RIGHT': [2, 0],
        'CAM_BACK_LEFT': [0, 1], 'CAM_BACK': [1, 1], 'CAM_BACK_RIGHT': [2, 1],
    }
    patch_size = 448
    width = patch_size * 2   # width of one camera tile on the canvas
    height = patch_size      # height of one camera tile on the canvas
    whole_img_width = width * 3
    whole_img_height = height * 2
    matches = re.findall(pattern, pred)
    for object_id in matches:

        object_id_c = object_id.replace('<', '').replace('>', '')
        try:
            ctag = object_id_c.split(',')[0]
            cxcy = json.loads(','.join(object_id_c.split(',')[2:]))
            cam = object_id_c.split(',')[1].strip()  # the pattern allows whitespace after the comma
            if cam in mapping:
                mx, my = mapping[cam]
                # old_wide,old_height = images_size[cam]
                old_wide, old_height = 1600, 900
                cx, cy = cxcy
                # Predicted coordinates are normalized to [0, 1000] over the whole canvas.
                cx = (cx / 1000) * whole_img_width
                cy = (cy / 1000) * whole_img_height
                # Shift into the camera's own tile, then rescale to the original resolution.
                cx -= mx * width
                cy -= my * height
                cx = cx / width * old_wide
                cy = cy / height * old_height
                # cx =max(0,min(old_wide,cx))
                # cy =max(0,min(old_height,cy))
                cx = round(max(0, min(old_wide, cx)), 1)
                cy = round(max(0, min(old_height, cy)), 1)
                new_object_id = f'<{ctag},{cam},{cx},{cy}>'

                pred = pred.replace(object_id, new_object_id)
        except Exception as e:
            print(e)
    return pred

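# For illustration, an example of the mapping implemented above (tag and values chosen
# arbitrarily): "<c1,CAM_FRONT,[500, 250]>" is interpreted on the 2688x896 stitched
# canvas as (0.500*2688, 0.250*896) = (1344, 224); removing the CAM_FRONT tile offset
# (896, 0) gives (448, 224), which rescales to (448/896*1600, 224/448*900) = (800.0, 450.0)
# in the original 1600x900 camera image, so the tag becomes "<c1,CAM_FRONT,800.0,450.0>".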

def collate_fn(batches, tokenizer):
    # Stack all image tiles in the batch and collect the per-sample fields.
    pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0)
    questions = [_['question'] for _ in batches]
    questions_old = [_['question_old'] for _ in batches]
    answers = [_['answer'] for _ in batches]
    data_ids = [_['data_id'] for _ in batches]
    # images_sizes = [_['images_size'] for _ in batches]
    return pixel_values, questions_old, questions, answers, data_ids


class DriveLMDataset(torch.utils.data.Dataset):
    """JSONL-backed eval dataset: each line holds an image path and a QA conversation."""

    def __init__(self, root, split, prompt, image_path, input_size=224, dynamic_image_size=False,
                 use_thumbnail=False, max_num=6):
        # run for each subject

        with open(root, 'r') as f:
            self.data = [json.loads(line) for line in f.readlines()]
        # data_val = json.load(f)
        # merge all dataset
        # self.data = concatenate_datasets(sub_dataset_list)
        self.prompt = prompt
        self.input_size = input_size
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.max_num = max_num
        self.transform = build_transform(is_train=False, input_size=input_size)
        self.image_path = image_path

        # with open(image_meta,"r") as f:
        #     self.image_meta = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        data = self.data[idx]
        data_id = data['id']
        question = data['conversations'][0]['value'].strip()
        question_old = data['question_old']
        image_file = os.path.join(self.image_path, data['image'])
        image = Image.open(image_file).convert('RGB')
        # question_type = data['question_type']

        # choices = eval(data['options'])
        answer = data['conversations'][1]['value'].strip()

        if self.dynamic_image_size:
            # images = []

            # Tile the stitched image into up to `max_num` patches of size `input_size`.
            pil_image = dynamic_preprocess(image, image_size=self.input_size,
                                           use_thumbnail=self.use_thumbnail,
                                           max_num=self.max_num)
            images = pil_image
        else:
            images = [image]
        pixel_values = [self.transform(image) for image in images]
        pixel_values = torch.stack(pixel_values)

        # image_id = os.path.basename(image_file).split(".")[0]
        # images_size = self.image_meta[image_id]["images_size"]

        return {
            'question_old': question_old,
            'question': question,
            'pixel_values': pixel_values,
            # 'images_size':images_size,
            'answer': answer,
            'data_id': data_id
        }

class InferenceSampler(torch.utils.data.sampler.Sampler):
    """Shard dataset indices contiguously across distributed ranks for inference."""

    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        # The first `left` ranks take one extra sample each.
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[:rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)

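# For illustration: with total_size=10 and world_size=3, shard_sizes is [4, 3, 3],
# so rank 0 iterates range(0, 4), rank 1 iterates range(4, 7) and rank 2 range(7, 10).
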
def evaluate_chat_model():

    random.seed(args.seed)
    prompt = None
    for ds_name in args.datasets:
        dataset = DriveLMDataset(
            root=ds_collections[ds_name]['root'],
            split=ds_collections[ds_name]['split'],
            prompt=prompt,
            image_path=ds_collections[ds_name]['image_root'],
            # image_meta = ds_collections[ds_name]["image_meta"],
            input_size=image_size,
            dynamic_image_size=args.dynamic,
            use_thumbnail=use_thumbnail,
            max_num=args.max_num
        )
        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            sampler=InferenceSampler(len(dataset)),
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
            collate_fn=partial(collate_fn, tokenizer=tokenizer),
        )

        outputs = []
        for _, (pixel_values, questions_old, questions, answers, data_ids) in tqdm(enumerate(dataloader)):
            pixel_values = pixel_values.to(torch.bfloat16).cuda()
            generation_config = dict(
                num_beams=args.num_beams,
                max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
                min_new_tokens=ds_collections[ds_name]['min_new_tokens'],
                do_sample=True if args.temperature > 0 else False,
                temperature=args.temperature,
            )
            pred = model.chat(
                tokenizer=tokenizer,
                pixel_values=pixel_values,
                question=questions[0],
                generation_config=generation_config
            )

            # preds = [pred]
            # if len(options[0]) == 0:
            #     preds = [pred]
            # else:
            # Map predicted canvas coordinates back to original camera coordinates.
            preds = [post_process(pred)]

            for question, pred, answer, data_id, question_old in zip(questions, preds, answers, data_ids, questions_old):
                outputs.append({
                    'question': question_old,
                    'answer': pred,
                    'gt_answers': answer,
                    'id': data_id
                })

        # Gather the per-rank results on every rank, then write them out on rank 0 only.
        torch.distributed.barrier()

        world_size = torch.distributed.get_world_size()
        merged_outputs = [None for _ in range(world_size)]
        torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))

        merged_outputs = [json.loads(_) for _ in merged_outputs]
        merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]

        if torch.distributed.get_rank() == 0:

            print(f'Evaluating {ds_name} ...')
            time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
            results_file = f'{ds_name}_{time_prefix}.json'
            output_path = os.path.join(args.out_dir, results_file)

            with open(output_path, 'w') as f:
                json.dump(merged_outputs, f, indent=4)
            print('Results saved to {}'.format(output_path))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, default='')
    parser.add_argument('--datasets', type=str, default='MMMU_dev')
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--num-workers', type=int, default=1)
    parser.add_argument('--num-beams', type=int, default=5)
    parser.add_argument('--temperature', type=float, default=0.0)
    parser.add_argument('--out-dir', type=str, default='results')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--dynamic', action='store_true')
    parser.add_argument('--max-num', type=int, default=12)
    parser.add_argument('--load-in-8bit', action='store_true')
    parser.add_argument('--auto', action='store_true')
    args = parser.parse_args()

    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    args.datasets = args.datasets.split(',')
    print('datasets:', args.datasets)
    assert args.batch_size == 1, 'Only batch size 1 is supported'

    torch.distributed.init_process_group(
        backend='nccl',
        world_size=int(os.getenv('WORLD_SIZE', '1')),
        rank=int(os.getenv('RANK', '0')),
    )

    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))

    if args.auto:
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    kwargs = {'device_map': 'auto'} if args.auto else {}
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
    model = InternVLChatModel.from_pretrained(
        args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
        load_in_8bit=args.load_in_8bit, **kwargs).eval()
    if not args.load_in_8bit and not args.auto:
        model = model.cuda()
    image_size = model.config.force_image_size or model.config.vision_config.image_size
    use_thumbnail = model.config.use_thumbnail

    total_params = sum(p.numel() for p in model.parameters()) / 1e9
    if total_params > 20 or args.dynamic:
        args.num_beams = 1
        print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
    else:
        print(f'[test] total_params: {total_params}B')
    print(f'[test] image_size: {image_size}')
    print(f'[test] template: {model.config.template}')
    print(f'[test] dynamic_image_size: {args.dynamic}')
    print(f'[test] use_thumbnail: {use_thumbnail}')
    print(f'[test] max_num: {args.max_num}')

    evaluate_chat_model()
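Usage note: the script initializes torch.distributed (nccl backend) from the WORLD_SIZE, RANK and LOCAL_RANK environment variables and shards the data with InferenceSampler, so it is presumably launched with torchrun (or an equivalent shell wrapper), one process per GPU. A hypothetical invocation, with placeholder paths since the script's filename and checkpoint are not shown in this diff:

torchrun --nproc_per_node=8 <path/to/this_script.py> --checkpoint <path/to/Mini-InternVL-checkpoint> --datasets DriveLM_val --dynamic

With --dynamic, num_beams is forced to 1, and at the default temperature of 0.0 do_sample is False, so decoding is greedy.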
