Is batch inference enabled now? #635

Cola-any opened this issue Apr 15, 2025 · 1 comment

@Cola-any

I found that GPU memory usage is relatively low when batch_size=1 is set for inference.
I want to make full use of the GPU by setting a larger batch_size, but I ran into the following error. Can anyone help me?

Traceback (most recent call last):
  File "/home/user/MLLM/LLaVA-NeXT/lmms-eval/lmms_eval/__main__.py", line 330, in cli_evaluate
    results, samples = cli_evaluate_single(args)
  File "/home/user/MLLM/LLaVA-NeXT/lmms-eval/lmms_eval/__main__.py", line 471, in cli_evaluate_single
    results = evaluator.simple_evaluate(
  File "/home/user/MLLM/LLaVA-NeXT/lmms-eval/lmms_eval/utils.py", line 533, in _wrapper
    return fn(*args, **kwargs)
  File "/home/user/MLLM/LLaVA-NeXT/lmms-eval/lmms_eval/evaluator.py", line 177, in simple_evaluate
    lm = lmms_eval.models.get_model(model).create_from_arg_string(
  File "/home/user/MLLM/LLaVA-NeXT/lmms-eval/lmms_eval/api/model.py", line 111, in create_from_arg_string
    return cls(**args, **args2)
  File "/home/user/MLLM/LLaVA-NeXT/lmms-eval/lmms_eval/models/llava_onevision.py", line 148, in __init__
    assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
AssertionError: Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue.
By the way, I am evaluating LLaVA-OneVision on videomme.
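The run is along these lines (the checkpoint name and output path here are illustrative, not my exact paths; --batch_size is what trips the assertion):

python3 -m lmms_eval \
    --model llava_onevision \
    --model_args pretrained=lmms-lab/llava-onevision-qwen2-7b-ov \
    --tasks videomme \
    --batch_size 8 \
    --output_path ./logs/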

CLARKBENHAM (Contributor) commented Apr 17, 2025

There are some issues with their input munging before the model requests are made when batch > 1. It didn't work for me either.

For the ai2d task with the llava model and --model_args pretrained=lmms-lab/llama3-llava-next-8b,conv_template=llava_llama_3, I stepped through it in the debugger.

At models/llava.py:391

cont = self.model.generate(
    input_ids,
    attention_mask=attention_masks,
    pad_token_id=pad_token_ids,
    images=image_tensor,
    image_sizes=gen_kwargs["image_sizes"],
    do_sample=True if gen_kwargs["temperature"] > 0 else False,
    temperature=gen_kwargs["temperature"],
    top_p=gen_kwargs["top_p"],
    num_beams=gen_kwargs["num_beams"],
    max_new_tokens=gen_kwargs["max_new_tokens"],
    use_cache=self.use_cache,
)
text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)

The inputs seem to be the right shape for batch size 8, but the output tensor is [1, 3] and decodes to a single gibberish string.

> input_ids.shape
torch.Size([8, 234])
> len(image_tensor)
8
> [i.shape for i in image_tensor]
[torch.Size([3, 3, 336, 336]), torch.Size([5, 3, 336, 336]), torch.Size([3, 3, 336, 336]), torch.Size([5, 3, 336, 336]), torch.Size([5, 3, 336, 336]), torch.Size([3, 3, 336, 336]), torch.Size([5, 3, 336, 336]), torch.Size([5, 3, 336, 336])]
> attention_masks.sum(axis=1)
tensor([234, 214, 179, 172, 166, 164, 155, 153], device='cuda:0')
> text_outputs
['\nD']
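
(Side note: batch_decode returns one string per row of the ids it is given, so a [1, 3] cont can only ever yield a single string, no matter how many prompts went into generate(). A quick sketch with an arbitrary tokenizer, gpt2 used purely for illustration:)

from transformers import AutoTokenizer

# batch_decode yields one string per row; a single row -> a single string.
tok = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer demonstrates this
print(tok.batch_decode([[198, 35]], skip_special_tokens=True))  # -> ['\nD']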

Stepping into llava/model/language_model/llava_llama.py:137, I see that the preparation step mangles the dimensions.

> inputs.shape
torch.Size([8, 151])
> position_ids
None
> attention_mask.sum(axis=1)
tensor([151, 148, 142, 142, 140, 139, 139, 137], device='cuda:0')
Then after (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes):
> inputs_embeds.shape
torch.Size([1, 1426, 4096])
> print(attention_mask.sum(), attention_mask.shape)
tensor(1426, device='cuda:0') torch.Size([1, 1426])
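
So prepare_inputs_labels_for_multimodal is collapsing the batch: 8 padded rows go in, one 1426-token row comes out. For contrast, a minimal sketch of what the re-padding after image-token expansion has to produce (names and shapes here are illustrative, not the actual LLaVA internals):

import torch
from torch.nn.utils.rnn import pad_sequence

# After each <image> token is expanded into its patch embeddings, the
# per-sample sequences have different lengths, so a batched prep step must
# re-pad them into a [batch, max_len, hidden] tensor with a matching mask.
hidden = 4096
per_sample = [torch.randn(n, hidden) for n in (200, 185, 170)]  # ragged lengths

inputs_embeds = pad_sequence(per_sample, batch_first=True)   # [3, 200, 4096]
attention_mask = torch.zeros(inputs_embeds.shape[:2], dtype=torch.long)
for i, seq in enumerate(per_sample):
    attention_mask[i, : seq.shape[0]] = 1

# The buggy path instead returns a single row (the [1, 1426, 4096]
# inputs_embeds above), i.e. the batch dimension is gone by the time
# generate() runs, so only one sequence gets decoded.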
