
Exporting Mistral to ONNX Model #2312

@EricJi150

Description

System Info

Python 3.12
/app/models/pixtral_official_checkpoint directory contains the weights from https://huggingface.co/mistral-community/pixtral-12b

Who can help?

@JingyaHuang @echarlaix @michaelbenayoun

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction (minimal, reproducible, runnable)

from optimum.exporters.onnx import main_export
from transformers import LlavaForConditionalGeneration

def export(output_path='onnx_language_model'):
    # Local copy of https://huggingface.co/mistral-community/pixtral-12b
    model_id = '/app/models/pixtral_official_checkpoint'
    model = LlavaForConditionalGeneration.from_pretrained(model_id)

    # Extract the Mistral language model and save it on its own so that
    # main_export can load it as a plain text-generation checkpoint.
    language_model = model.language_model
    language_model.config.use_cache = True
    language_model.save_pretrained('pytorch_language_model')

    export_kwargs = dict(
        model_name_or_path='pytorch_language_model',
        output=output_path,
        task="text-generation-with-past",
        device='cpu',
        trust_remote_code=True,
        use_cache=True,
    )
    main_export(**export_kwargs)

def main():
    export()

if __name__ == "__main__":
    main()
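
For reference, this is how I intended to consume the exported model afterwards, with the KV cache enabled. It is only a minimal sketch: the tokenizer path, prompt, and generation settings are placeholders.

from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

# Load the exported ONNX language model with the KV cache enabled.
ort_model = ORTModelForCausalLM.from_pretrained('onnx_language_model', use_cache=True)
tokenizer = AutoTokenizer.from_pretrained('/app/models/pixtral_official_checkpoint')

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = ort_model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))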

Expected behavior

I expect the Mistral language model used in Pixtral to be exported to an ONNX model. However, when I set use_cache=True, the export fails because the KV cache cannot be updated due to mismatched tensor sizes:

File "/app/.venv/lib/python3.12/site-packages/transformers/cache_utils.py", line 446, in update
    self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 160 but got size 128 for tensor number 1 in the list.

I don't get this error when I set use_cache=False and task='text-generation'; however, I would like the exported ONNX model to use the KV cache during inference.
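
For what it's worth, the 160 vs. 128 mismatch looks like a head-dimension issue: the checkpoint's text config sets head_dim to 128, while hidden_size // num_attention_heads appears to be 160, and the dummy past key/values built for the export seem to use the latter. This is only a guess; the snippet below just prints the two values from the config (same checkpoint path as above).

from transformers import AutoConfig

config = AutoConfig.from_pretrained('/app/models/pixtral_official_checkpoint')
text_config = config.text_config  # Mistral config of the Pixtral language model

# If these two numbers differ, dummy KV-cache tensors shaped with
# hidden_size // num_attention_heads will not match caches shaped with head_dim.
print("hidden_size // num_attention_heads =",
      text_config.hidden_size // text_config.num_attention_heads)
print("head_dim =", getattr(text_config, "head_dim", None))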
