
Commit 39a0a11

[Model] Mistral3 example and test (#1490)
## Purpose ##

* Add support for mistral3
* Related: #1343

## Prerequisites ##

* #1479

## Changes ##

* Added mistral3 example
* This model does not automatically change the dtype of pixel_values to match the dtype of the model, so I had to do so manually in the collator and sample generation
* This model has a [very verbose chat template by default](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json), which may be less conducive to calibration, so I added a custom shortened version

## Testing ##

* Ran example to completion: [nm-testing/Mistral-Small-3.1-24B-Instruct-2503-W4A16-G128](https://huggingface.co/nm-testing/Mistral-Small-3.1-24B-Instruct-2503-W4A16-G128)

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 6d881f7 commit 39a0a11

3 files changed: +100 -0 lines changed
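Before the diffs, a minimal sketch of the dtype workaround mentioned in the description: the Mistral3 processor returns float32 pixel_values and the model does not cast them to its own dtype, so the example casts them manually in the data collator and before sample generation. The shapes and dtypes below are illustrative stand-ins, not values taken from the commit.

```python
import torch

# Stand-ins: the model loads in a half-precision dtype (torch_dtype="auto"),
# while the processor returns float32 pixel values.
model_dtype = torch.bfloat16
pixel_values = torch.zeros(1, 3, 224, 224, dtype=torch.float32)

# Mistral3 does not perform this cast itself, so the example does it manually
# (in the calibration data collator and again before sample generation).
pixel_values = pixel_values.to(model_dtype)
assert pixel_values.dtype == model_dtype
```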
mistral3_chat_template.json

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
{
"chat_template": "{%- set default_system_message = \"You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI\" %}\n\n{{- bos_token }}\n\n{%- if messages[0]['role'] == 'system' %}\n {%- if messages[0]['content'] is string %}\n {%- set system_message = messages[0]['content'] %}\n {%- else %}\n {%- set system_message = messages[0]['content'][0]['text'] %}\n {%- endif %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set system_message = default_system_message %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}\n\n{%- for message in loop_messages %}\n {%- if message['role'] == 'user' %}\n {%- if message['content'] is string %}\n {{- '[INST]' + message['content'] + '[/INST]' }}\n {%- else %}\n {{- '[INST]' }}\n {%- for block in message['content'] %}\n {%- if block['type'] == 'text' %}\n {{- block['text'] }}\n {%- elif block['type'] in ['image', 'image_url'] %}\n {{- '[IMG]' }}\n {%- else %}\n {{- raise_exception('Only text and image blocks are supported in message content!') }}\n {%- endif %}\n {%- endfor %}\n {{- '[/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'system' %}\n {%- if message['content'] is string %}\n {{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}\n {%- else %}\n {{- '[SYSTEM_PROMPT]' + message['content'][0]['text'] + '[/SYSTEM_PROMPT]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {%- if message['content'] is string %}\n {{- message['content'] + eos_token }}\n {%- else %}\n {{- message['content'][0]['text'] + eos_token }}\n {%- endif %}\n {%- else %}\n {{- raise_exception('Only user, system and assistant roles are supported!') }}\n {%- endif %}\n{%- endfor %}"
}
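To see what the shortened template produces at calibration time, here is a small sketch that loads it into the processor and renders a text-plus-image message, mirroring the pattern used in the example script below; the prompt text is illustrative, and the exact rendered string also includes the tokenizer's BOS token.

```python
import json

from transformers import AutoProcessor

model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
processor = AutoProcessor.from_pretrained(model_id)

# Swap in the shortened calibration template, as the example script does.
with open("mistral3_chat_template.json", "r") as file:
    processor.chat_template = json.load(file)["chat_template"]

# A calibration-style message: one text block and one image placeholder.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe this image\n"},
            {"type": "image"},
        ],
    },
]

# Renders roughly:
#   [SYSTEM_PROMPT]You are Mistral Small 3, ...[/SYSTEM_PROMPT][INST]Please describe this image\n[IMG][/INST]
print(processor.apply_chat_template(messages, add_generation_prompt=True))
```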
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
import json
import os

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Mistral3ForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

# Load model.
model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Use a custom calibration chat template, rather than the overly-verbose default
file_path = os.path.join(os.path.dirname(__file__), "mistral3_chat_template.json")
with open(file_path, "r") as file:
    processor.chat_template = json.load(file)["chat_template"]

# Oneshot arguments
DATASET_ID = "flickr30k"
DATASET_SPLIT = "test"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048


# Define a oneshot data collator for multimodal inputs.
# Note: pixel_values is cast to the model dtype here because Mistral3 does not
# cast it automatically (see the commit description above).
def data_collator(batch):
    assert len(batch) == 1
    return {
        key: torch.tensor(value)
        if key != "pixel_values"
        else torch.tensor(value, dtype=model.dtype)
        for key, value in batch[0].items()
    }


# Recipe
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        sequential_targets=["MistralDecoderLayer"],
        ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
    ),
]

# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=DATASET_ID,
    splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe the animal in this image\n"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)  # fix dtype
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")

# Save to disk compressed.
SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
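As a follow-up (not part of this commit), the compressed checkpoint saved above can be loaded for inference; below is a minimal sketch with vLLM, assuming a vLLM build that supports Mistral3 and compressed-tensors W4A16 checkpoints. The directory name matches the SAVE_DIR produced by the example.

```python
from vllm import LLM, SamplingParams

# Directory written by the example above.
SAVE_DIR = "Mistral-Small-3.1-24B-Instruct-2503-W4A16-G128"

# Load the compressed checkpoint; a short context keeps the smoke test light.
llm = LLM(model=SAVE_DIR, max_model_len=4096)

# A text-only prompt is enough to confirm the quantized weights load and decode.
outputs = llm.generate(
    ["Describe the Eiffel Tower in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```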

tests/llmcompressor/transformers/tracing/test_models.py

Lines changed: 8 additions & 0 deletions
@@ -7,6 +7,7 @@
     Idefics3ForConditionalGeneration,
     Llama4ForConditionalGeneration,
     LlavaForConditionalGeneration,
+    Mistral3ForConditionalGeneration,
     MllamaForConditionalGeneration,
     Qwen2_5_VLForConditionalGeneration,
     Qwen2VLForConditionalGeneration,
@@ -86,6 +87,13 @@
             "vision",
             ["torchvision"],
         ),
+        (
+            "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
+            Mistral3ForConditionalGeneration,
+            ["MistralDecoderLayer"],
+            "vision",
+            [],
+        ),
         (
             "google/gemma-3-4b-it",
             Gemma3ForConditionalGeneration,
