@@ -279,7 +279,7 @@ def generate_stream(self, params):
279
279
max_input_tile_temp = []
280
280
for image_str in message ['image' ]:
281
281
pil_images .append (load_image_from_base64 (image_str ))
282
- prefix += f'Image-{ global_image_cnt + 1 } : <image>\n \n '
282
+ prefix += f'Image-{ global_image_cnt + 1 } : <image>\n '
283
283
global_image_cnt += 1
284
284
max_input_tile_temp .append (max (1 , max_input_tiles // len (message ['image' ])))
285
285
if len (max_input_tile_temp ) > 0 :
@@ -291,8 +291,8 @@ def generate_stream(self, params):
291
291
question , history = history [- 1 ][0 ], history [:- 1 ]
292
292
293
293
if global_image_cnt == 1 :
294
- question = question .replace ('Image-1: <image>\n \n ' , '<image>\n ' )
295
- history = [[item [0 ].replace ('Image-1: <image>\n \n ' , '<image>\n ' ), item [1 ]] for item in history ]
294
+ question = question .replace ('Image-1: <image>\n ' , '<image>\n ' )
295
+ history = [[item [0 ].replace ('Image-1: <image>\n ' , '<image>\n ' ), item [1 ]] for item in history ]
296
296
297
297
# Create a new list to store processed sublists
298
298
flattened_list = []
@@ -308,7 +308,7 @@ def generate_stream(self, params):
308
308
309
309
old_system_message = self .model .system_message
310
310
self .model .system_message = system_message
311
- image_tiles = []
311
+ image_tiles , num_patches_list = [], []
312
312
transform = build_transform (input_size = self .image_size )
313
313
if len (pil_images ) > 0 :
314
314
for current_max_input_tiles , pil_image in zip (max_input_tile_list , pil_images ):
@@ -318,6 +318,7 @@ def generate_stream(self, params):
318
318
use_thumbnail = self .model .config .use_thumbnail )
319
319
else :
320
320
tiles = [pil_image ]
321
+ num_patches_list .append (len (tiles ))
321
322
image_tiles += tiles
322
323
pixel_values = [transform (item ) for item in image_tiles ]
323
324
pixel_values = torch .stack (pixel_values ).to (self .model .device , dtype = torch .bfloat16 )
@@ -341,6 +342,7 @@ def generate_stream(self, params):
341
342
thread = Thread (target = self .model .chat , kwargs = dict (
342
343
tokenizer = self .tokenizer ,
343
344
pixel_values = pixel_values ,
345
+ num_patches_list = num_patches_list ,
344
346
question = question ,
345
347
history = history ,
346
348
return_history = False ,
0 commit comments