Labels: bug (Something isn't working)
Description
System Info
- Python 3.12
- `/app/models/pixtral_official_checkpoint` contains the weights from https://huggingface.co/mistral-community/pixtral-12b
Who can help?
@JingyaHuang @echarlaix @michaelbenayoun
Information
- The official example scripts
- My own modified scripts

Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction (minimal, reproducible, runnable)
```python
from optimum.exporters.onnx import main_export
from transformers import LlavaForConditionalGeneration


def export(output_path='onnx_language_model'):
    model_id = '/app/models/pixtral_official_checkpoint'
    model = LlavaForConditionalGeneration.from_pretrained(model_id)

    # Extract the Mistral language model from the Pixtral (Llava-style) checkpoint
    language_model = model.language_model
    language_model.config.use_cache = True
    language_model.save_pretrained('pytorch_language_model')

    export_kwargs = dict(
        model_name_or_path='pytorch_language_model',
        output=output_path,
        task="text-generation-with-past",
        device='cpu',
        trust_remote_code=True,
        use_cache=True,
    )
    main_export(**export_kwargs)


def main():
    export()


if __name__ == "__main__":
    main()
```
Expected behavior
I expect the Mistral language model used in Pixtral to be exported to ONNX. However, when I set `use_cache=True`, the cache fails to update due to mismatched tensor sizes:

```
File "/app/.venv/lib/python3.12/site-packages/transformers/cache_utils.py", line 446, in update
    self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 160 but got size 128 for tensor number 1 in the list.
```

I don't get this error when I set `use_cache=False` and `task='text-generation'`; however, I would like the ONNX model to use the cache during inference.
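The 160 vs. 128 mismatch is consistent with the exporter deriving the head dimension as `hidden_size // num_attention_heads` instead of reading the explicit `head_dim` from the config. This is only a guess at the cause; the sketch below uses values I believe match the Pixtral-12B language-model config (`hidden_size=5120`, `num_attention_heads=32`, `head_dim=128`) to show how the two numbers in the error could arise.

```python
# Assumed Pixtral-12B text_config values (hypothetical illustration, not
# pulled from the exporter's code).
hidden_size = 5120
num_attention_heads = 32
head_dim = 128  # explicitly set in the config

# A dummy past-key-values generator that derives the head dimension from
# hidden_size would allocate caches with the wrong last dimension:
derived_head_dim = hidden_size // num_attention_heads

print(derived_head_dim)  # 160 -- matches "Expected size 160" in the error
print(head_dim)          # 128 -- matches "got size 128" from key_states
```

If this is the cause, the dummy cache inputs fed to the export have shape `(..., seq_len, 160)` while the model's real `key_states` have shape `(..., seq_len, 128)`, so the `torch.cat` in `cache_utils.update` fails exactly as shown in the traceback.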