
Commit 7412637

Fix bugs in multi-image inference
1 parent: c62fa4f

2 files changed: +9, -4 lines

streamlit_demo/app.py

Lines changed: 3 additions & 0 deletions
@@ -49,6 +49,7 @@ def get_model_list():
     assert ret.status_code == 200
     ret = requests.post(controller_url + '/list_models')
     models = ret.json()['models']
+    models = [item for item in models if 'InternVL2-Det' not in item]
     return models
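For context, a minimal sketch of what the added filter does; the model names here are hypothetical stand-ins for whatever the controller's `/list_models` endpoint actually returns:

```python
# Hypothetical response payload from the controller's /list_models endpoint.
models = ['InternVL2-8B', 'InternVL2-26B', 'InternVL2-Det-8B']

# The added line hides detection-specialized checkpoints from the demo's model picker.
models = [item for item in models if 'InternVL2-Det' not in item]

print(models)  # ['InternVL2-8B', 'InternVL2-26B']
```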

@@ -141,6 +142,8 @@ def generate_response(messages):
                 else:
                     output = data['text'] + f" (error_code: {data['error_code']})"
                     placeholder.markdown(output)
+        if ('\[' in output and '\]' in output) or ('\(' in output and '\)' in output):
+            output = output.replace('\[', '$').replace('\]', '$').replace('\(', '$').replace('\)', '$')
         placeholder.markdown(output)
     except requests.exceptions.RequestException as e:
         placeholder.markdown(server_error_msg)
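Rationale for the added lines: Streamlit's `st.markdown` renders `$...$` as math but leaves LaTeX-style `\[...\]` and `\(...\)` delimiters as literal text, so model output using them would display raw. A minimal sketch of the rewrite under an assumed output string (raw strings are used here to avoid Python escape warnings; the diff spells the literals as `'\['` etc.):

```python
# Assumed model output using LaTeX-style math delimiters.
output = r'The derivative is \( 2x \) and the integral is \[ x^2 \].'

# Mirror of the added check and rewrite.
if (r'\[' in output and r'\]' in output) or (r'\(' in output and r'\)' in output):
    output = (output.replace(r'\[', '$').replace(r'\]', '$')
                    .replace(r'\(', '$').replace(r'\)', '$'))

print(output)  # The derivative is $ 2x $ and the integral is $ x^2 $.
```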

streamlit_demo/model_worker.py

Lines changed: 6 additions & 4 deletions
@@ -279,7 +279,7 @@ def generate_stream(self, params):
                 max_input_tile_temp = []
                 for image_str in message['image']:
                     pil_images.append(load_image_from_base64(image_str))
-                    prefix += f'Image-{global_image_cnt + 1}: <image>\n\n'
+                    prefix += f'Image-{global_image_cnt + 1}: <image>\n'
                     global_image_cnt += 1
                     max_input_tile_temp.append(max(1, max_input_tiles // len(message['image'])))
                 if len(max_input_tile_temp) > 0:
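The switch from `'\n\n'` to `'\n'` changes how the image labels stack in the prompt prefix. A self-contained sketch under assumed inputs (two placeholder base64 strings):

```python
global_image_cnt = 0
prefix = ''
for image_str in ['<b64-image-1>', '<b64-image-2>']:  # placeholder payloads
    # Single newline per label after this commit (previously '\n\n').
    prefix += f'Image-{global_image_cnt + 1}: <image>\n'
    global_image_cnt += 1

print(prefix)
# Image-1: <image>
# Image-2: <image>
```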
@@ -291,8 +291,8 @@ def generate_stream(self, params):
         question, history = history[-1][0], history[:-1]

         if global_image_cnt == 1:
-            question = question.replace('Image-1: <image>\n\n', '<image>\n')
-            history = [[item[0].replace('Image-1: <image>\n\n', '<image>\n'), item[1]] for item in history]
+            question = question.replace('Image-1: <image>\n', '<image>\n')
+            history = [[item[0].replace('Image-1: <image>\n', '<image>\n'), item[1]] for item in history]

         # Create a new list to store processed sublists
         flattened_list = []
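This hunk keeps the single-image normalization in sync with the prefix change above: the replace now searches for the single-`'\n'` spelling the prefix loop emits, so the numbered label is still stripped when only one image is present. A sketch with an assumed question string:

```python
global_image_cnt = 1
question = 'Image-1: <image>\nDescribe this picture.'

# With exactly one image, drop the numbered label so the prompt matches
# the plain single-image template.
if global_image_cnt == 1:
    question = question.replace('Image-1: <image>\n', '<image>\n')

print(repr(question))  # '<image>\nDescribe this picture.'
```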
@@ -308,7 +308,7 @@ def generate_stream(self, params):

         old_system_message = self.model.system_message
         self.model.system_message = system_message
-        image_tiles = []
+        image_tiles, num_patches_list = [], []
         transform = build_transform(input_size=self.image_size)
         if len(pil_images) > 0:
             for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
@@ -318,6 +318,7 @@ def generate_stream(self, params):
                                               use_thumbnail=self.model.config.use_thumbnail)
                 else:
                     tiles = [pil_image]
+                num_patches_list.append(len(tiles))
                 image_tiles += tiles
             pixel_values = [transform(item) for item in image_tiles]
             pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.bfloat16)
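These two hunks add the bookkeeping that records how many tiles each image contributes, so the stacked tile tensor can later be split back per image. A self-contained sketch; `fake_dynamic_preprocess` is a stand-in for the repo's `dynamic_preprocess`, and the inputs are placeholders:

```python
def fake_dynamic_preprocess(image, max_num):
    # Stand-in for dynamic_preprocess: pretend the image is cut into max_num tiles.
    return [f'{image}-tile{i}' for i in range(max_num)]

pil_images = ['imgA', 'imgB']   # placeholders for PIL images
max_input_tile_list = [4, 2]    # per-image tile budgets

image_tiles, num_patches_list = [], []
for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
    tiles = fake_dynamic_preprocess(pil_image, max_num=current_max_input_tiles)
    num_patches_list.append(len(tiles))  # the added bookkeeping line
    image_tiles += tiles

print(num_patches_list)  # [4, 2]
print(len(image_tiles))  # 6 tiles total, stacked in order
```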
@@ -341,6 +342,7 @@ def generate_stream(self, params):
         thread = Thread(target=self.model.chat, kwargs=dict(
             tokenizer=self.tokenizer,
             pixel_values=pixel_values,
+            num_patches_list=num_patches_list,
             question=question,
             history=history,
             return_history=False,
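Passing `num_patches_list` to `model.chat` is what makes multi-image inference line up: without it, `chat` has to treat all tiles as belonging to a single image, while the per-image counts let it match each `<image>` placeholder to its own slice of `pixel_values`. A runnable sketch of the argument plumbing, with a dummy `chat` standing in for `self.model.chat`:

```python
from threading import Thread

def chat(tokenizer=None, pixel_values=None, num_patches_list=None,
         question='', history=None, return_history=False):
    # Dummy stand-in for self.model.chat; checks only the tile accounting.
    assert sum(num_patches_list) == len(pixel_values)
    print(f'{len(num_patches_list)} images, {len(pixel_values)} tiles total')

thread = Thread(target=chat, kwargs=dict(
    tokenizer=None,
    pixel_values=[0] * 6,        # placeholder for 6 stacked tiles
    num_patches_list=[4, 2],     # 4 tiles for Image-1, 2 for Image-2
    question='Image-1: <image>\nImage-2: <image>\nCompare the two images.',
    history=[],
    return_history=False,
))
thread.start()
thread.join()
```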
