
add DeepseekV3 AWQ mapping #1619


Open · cjackal wants to merge 10 commits into main

Conversation


@cjackal cjackal commented Jul 3, 2025

SUMMARY:

Add AWQ activation-smooth mapping for DeepseekV3ForCausalLM.

TEST PLAN:

examples/quantizing_moe/deepseek_r1_example.py, but with the recipe adapted to use AWQModifier instead:

from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.transformers import oneshot

# Select model and load it.

# This script takes about 48 hours on 1xA100 to complete.
# Future improvements will reduce this runtime (#1561, #1558).

# For DeepSeek-R1, we require a full precision model in order to properly calibrate
# `DeepSeek-R1-0528-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16

model_id = "unsloth/DeepSeek-R1-0528-BF16"
config = AutoConfig.from_pretrained(model_id)
del config.quantization_config  # fp8 qconfig no longer applies to bf16 model
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype="auto", config=config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = prepare_for_calibration(model)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# since the MoE gate layers are sensitive to quantization, we add them to the ignore
# list so they remain at full precision
recipe = AWQModifier(
    targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"]
)

# Apply algorithms.
# due to the large size of DeepSeekV3, we specify sequential targets such that
# only one MLP is loaded into GPU memory at a time
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"],
)

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @cjackal, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for quantizing DeepseekV3 models using the AWQ (Activation-aware Weight Quantization) method. It specifically adds the necessary activation-smooth mappings required for DeepseekV3ForCausalLM within the llmcompressor framework, enabling more efficient deployment of these models.

Highlights

  • DeepseekV3 AWQ Support: I've added specific AWQMapping configurations for DeepseekV3ForCausalLM to enable activation-smooth quantization. These mappings define the relationships between various normalization and projection layers (e.g., input_layernorm to q_a_proj and kv_a_proj_with_mqa, up_proj to down_proj) crucial for the AWQ algorithm; a sketch of such registry entries follows this list.
  • Registry Update: I've registered the new DeepseekV3 mappings in the AWQ_MAPPING_REGISTRY to ensure DeepseekV3ForCausalLM models can correctly utilize the AWQ modifier within the llmcompressor framework.
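
For orientation, here is a sketch of what such registry entries might look like. This is illustrative rather than the PR's actual diff: the AWQMapping constructor, field order, and module path are assumptions modeled on the existing mappings file, and only the input_layernorm and up_proj relationships are confirmed by the summary above; the remaining entry is an assumed analogue.

# Hypothetical sketch, not the PR's diff. Assumes AWQMapping(smooth_layer, balance_layers)
# and a registry keyed by architecture name in llmcompressor/modifiers/awq/mappings.py.
from llmcompressor.modifiers.awq.mappings import AWQ_MAPPING_REGISTRY, AWQMapping

_deepseek_v3_mappings = [
    # Smooth the attention input norm against both low-rank attention projections.
    AWQMapping(
        "re:.*input_layernorm$",
        ["re:.*q_a_proj$", "re:.*kv_a_proj_with_mqa$"],
    ),
    # Smooth the MLP/expert up projection against the down projection.
    AWQMapping("re:.*up_proj$", ["re:.*down_proj$"]),
    # Assumed additional pair, by analogy with other MoE model mappings.
    AWQMapping(
        "re:.*post_attention_layernorm$",
        ["re:.*gate_proj$", "re:.*up_proj$"],
    ),
]

AWQ_MAPPING_REGISTRY["DeepseekV3ForCausalLM"] = _deepseek_v3_mappings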

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for Activation-aware Weight Quantization (AWQ) for the DeepseekV3 model architecture. This is achieved by defining a new set of layer mappings specific to DeepseekV3 and registering them. The changes are clear and follow the existing structure for defining architecture-specific mappings. I have one suggestion to improve code maintainability.

Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>

github-actions bot commented Jul 3, 2025

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
Author

cjackal commented Jul 3, 2025

It correctly counts the number of calibration steps (14971, the same number as with GPTQModifier), but somehow some weights are not loaded in the forward path. Need to investigate further.

(9/14971): Calibrating: 100%|█████████▉| 511/512 [00:14<00:00, 29.70it/s]
(9/14971): Calibrating: 100%|██████████| 512/512 [00:14<00:00, 35.13it/s]
Smoothing: 0%| | 0/2 [00:00<?, ?it/s]
Smoothing: 0%| | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/jovyan/git/awq-test/quant_r1.py", line 123, in <module>
    oneshot(
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/llmcompressor/entrypoints/oneshot.py", line 308, in oneshot
    one_shot()
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/llmcompressor/entrypoints/oneshot.py", line 149, in __call__
    self.apply_recipe_modifiers(
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/llmcompressor/entrypoints/oneshot.py", line 192, in apply_recipe_modifiers
    pipeline(self.model, calibration_dataloader, self.dataset_args)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/llmcompressor/pipelines/independent/pipeline.py", line 45, in __call__
    pipeline(model, dataloader, dataset_args)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/llmcompressor/pipelines/sequential/pipeline.py", line 88, in __call__
    LifecycleCallbacks.sequential_epoch_end()
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/llmcompressor/core/session_functions.py", line 154, in sequential_epoch_end
    return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/llmcompressor/core/session_functions.py", line 78, in event
    return active_session().event(event_type, **kwargs)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/llmcompressor/core/session.py", line 179, in event
    mod_data = self._lifecycle.event(
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/llmcompressor/core/lifecycle.py", line 204, in event
    data = mod.update_event(state=self.state, event=event, **kwargs)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/llmcompressor/modifiers/modifier.py", line 119, in update_event
    self.on_event(state, event, **kwargs)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/llmcompressor/modifiers/awq/base.py", line 250, in on_event
    self._apply_smoothing(state.model)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/llmcompressor/modifiers/awq/base.py", line 461, in _apply_smoothing
    weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 308, in _fn
    result = fn(*args, **kwargs)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 149, in _fn
    result = fn(**bound.arguments)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/torch/_refs/__init__.py", line 2768, in cat
    utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)
  File "/home/jovyan/git/awq-test/.venv/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 779, in check_same_device
    raise RuntimeError(msg)
RuntimeError: Tensor on device meta is not on the expected device cuda:0!
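
For context on the failure above, here is a minimal standalone sketch (not from this PR) of the error mode: torch.cat refuses to mix a materialized tensor with one that is still a meta-device placeholder, which is what happens when a balance layer's weight was never onloaded before smoothing.

import torch

# Minimal reproduction of the RuntimeError in the traceback above: one tensor
# lives on a real device, the other is a "meta" placeholder with no storage.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
onloaded = torch.empty(8, 16, device=device)
offloaded = torch.empty(8, 16, device="meta")

try:
    torch.cat([onloaded, offloaded], dim=0)
except RuntimeError as err:
    print(err)  # e.g. "Tensor on device meta is not on the expected device ..."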

@dsikka dsikka requested a review from brian-dellabetta July 3, 2025 15:18
rahul-tuli previously approved these changes Jul 4, 2025
@rahul-tuli rahul-tuli dismissed their stale review July 4, 2025 11:48

Need to verify the error reported

Author

cjackal commented Jul 4, 2025

It seems like GPTQ does not use torch.compile but my AWQ test script does. Let me turn torch.compile off and see if it works.

@brian-dellabetta
Collaborator

Thank you @cjackal for the contribution! Please let us know how it goes with torch.compile off; I am not sure why you are hitting meta-device errors in this case.

@casper-hansen

Any updates on this? With the new Kimi K2 release, there is a lot of renewed interest in quantizing the DeepSeek V3 architecture.

Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Collaborator

brian-dellabetta commented Jul 14, 2025

Any updates on this? With the new Kimi K2 release, there is a lot of renewed interest in quantizing the DeepSeek V3 architecture.

Hi @casper-hansen, we were talking about Kimi K2 internally today and whether it's feasible to run AWQ or our other compression algorithms on a single GPU, layer by layer. These mappings look correct, though I haven't had a chance to validate on DeepSeek or K2. But since K2 uses the same DeepseekV3ForCausalLM architecture, it is more a question of whether the user has the hardware capabilities to run AWQ.

I am trying to validate @cjackal 's example script to try to get this PR in soon.

@casper-hansen

@brian-dellabetta DeepSeek V3 and R1 were quantized successfully in AutoAWQ. You do need a machine with a lot of system RAM, but it works just like any other model. We also had to convert to bfloat16 before quantizing.

Author

cjackal commented Jul 15, 2025

Sorry for the late response; it wasn't torch.compile that caused the error.

I am still debugging it, but I'd first like to share my findings here so that the repo maintainers have a chance to give quick guidance on a fix.

The error is first hit when the AWQ weight mapping whose parent module is DeepseekV3MoECalibrate is run. I am new to the llm-compressor codebase, but unlike AutoAWQ, the weight mappings in llm-compressor are implicit, so it looks like there is no way for the user to explicitly designate the parent name.
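
To make the implicit resolution concrete, here is an illustrative sketch (a hypothetical helper, not llm-compressor's actual resolver) of how regex-style mappings are matched against module names, which is why the parent module, e.g. DeepseekV3MoECalibrate, is inferred rather than designated by the user:

import re
from typing import Any, Iterable, List, Tuple

def resolve_awq_layers(
    named_modules: Iterable[Tuple[str, Any]], pattern: str
) -> List[Tuple[str, Any]]:
    """Return (name, module) pairs whose dotted names match an AWQ-style
    pattern such as "re:.*kv_a_proj_with_mqa$" (helper is hypothetical)."""
    regex = (
        re.compile(pattern[len("re:"):])
        if pattern.startswith("re:")
        else re.compile(re.escape(pattern))
    )
    return [(name, mod) for name, mod in named_modules if regex.search(name)]

# e.g. resolve_awq_layers(model.named_modules(), "re:.*mlp\\.experts.*up_proj$")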

Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
@brian-dellabetta
Collaborator

Hi @cjackal , @casper-hansen

I had to push some updates to cjackal's branch to prevent GPU OOM errors, but it is working through the layers with experts now:

2025-07-15T16:56:59.526747+0000 | on_initialize | INFO - No AWQModifier.mappings provided, inferring from model...
Resolving mapping 1/5 (0 skipped): 100%|██████████████████████████████████████| 61/61 [00:02<00:00, 27.78it/s]
Resolving mapping 2/5 (0 skipped): 100%|██████████████████████████████████████| 61/61 [00:00<00:00, 6834.03it/s]
Resolving mapping 3/5 (0 skipped): 100%|██████████████████████████████████████| 61/61 [00:00<00:00, 6852.70it/s]
Resolving mapping 4/5 (0 skipped): 100%|██████████████████████████████████████| 61/61 [00:02<00:00, 28.15it/s]
Resolving mapping 5/5 (0 skipped): 100%|██████████████████████████████████████| 14909/14909 [00:03<00:00, 4474.10it/s]
2025-07-15T16:57:15.673339+0000 | initialize | INFO - Compression lifecycle initialized for 1 modifiers
2025-07-15T16:57:15.673489+0000 | IndependentPipeline | INFO - Inferred `SequentialPipeline` for `AWQModifier`
Preparing cache: 100%|██████████████| 128/128 [00:00<00:00, 1987.98it/s]
(1/62): Calibrating: 100%|██████████| 128/128 [00:03<00:00, 34.98it/s]
Smoothing: 100%|████████████████████| 5/5 [00:15<00:00,  3.12s/it]
(1/62): Propagating: 100%|██████████| 128/128 [00:00<00:00, 417.62it/s]
(2/62): Calibrating: 100%|██████████| 128/128 [00:01<00:00, 70.98it/s]
Smoothing: 100%|████████████████████| 5/5 [00:15<00:00,  3.06s/it]
(2/62): Propagating: 100%|██████████| 128/128 [00:00<00:00, 407.90it/s]
(3/62): Calibrating: 100%|██████████| 128/128 [00:01<00:00, 96.58it/s]
Smoothing: 100%|████████████████████| 5/5 [00:15<00:00,  3.07s/it]
(3/62): Propagating: 100%|██████████| 128/128 [00:00<00:00, 405.84it/s]
(4/62): Calibrating: 100%|██████████| 128/128 [00:19<00:00,  6.47it/s]
Smoothing: 100%|████████████████████| 261/261 [06:42<00:00,  1.54s/it]
(4/62): Propagating: 100%|██████████| 128/128 [00:07<00:00, 18.10it/s]
(5/62): Calibrating: 100%|██████████| 128/128 [02:59<00:00,  1.40s/it]
Smoothing: 100%|████████████████████| 261/261 [06:50<00:00,  1.57s/it]
(5/62): Propagating: 100%|██████████| 128/128 [00:07<00:00, 16.55it/s]
(6/62): Calibrating: 100%|██████████| 128/128 [02:45<00:00,  1.29s/it]
Smoothing:   0%|      #manually exited here

This is hitting roughly 60GB peak GPU RAM (128 samples with 512 max sequence length). Extrapolating, this would take 8-10 hours on an H100, but I'm running on a noisy server with lots of other processes running. @casper-hansen do you recall what memory/time requirements AutoAWQ had for DeepseekV3 with 128 samples and 512 max sequence length?
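
As a rough sanity check on that extrapolation (assumed inputs read off the log above: roughly 3 minutes calibrating plus 7 minutes smoothing per expert layer, and 58 of DeepSeek-V3's 61 decoder layers use experts):

# Back-of-the-envelope check of the 8-10 hour estimate; all numbers are
# assumptions taken from the log above, not measurements of the full run.
moe_layers = 58
minutes_per_layer = 3 + 7  # ~3 min calibrating + ~7 min smoothing
print(f"~{moe_layers * minutes_per_layer / 60:.1f} hours")  # ~9.7 hours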

@cjackal I made a few modifications to your script, along with merging into this branch some new changes we've recently pushed to main; this resolves the error you were hitting. My script is attached below. Feel free to try it out; maybe torch.compile can help out too. I was experimenting with that on #1557, but other priorities have popped up.

I need to switch gears for the next few days, but will keep an eye on this thread. I can revisit towards the end of the week.

import torch
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor import oneshot

# Select model and load it.

# This script takes about 48 hours on 1xA100 to complete.
# Future improvements will reduce this runtime (#1561, #1558).

# For DeepSeek-R1, we require a full precision model in order to properly calibrate
# `DeepSeek-R1-0528-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16

model_id = "unsloth/DeepSeek-R1-0528-BF16"
config = AutoConfig.from_pretrained(model_id)
if hasattr(config, "quantization_config"):
    del config.quantization_config  # fp8 qconfig no longer applies to bf16 model
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype="auto", config=config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = prepare_for_calibration(model)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 128 samples keeps runtime and memory manageable here;
# increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 512

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# since the MoE gate layers are sensitive to quantization, we add them to the ignore
# list so they remain at full precision
recipe = AWQModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=["lm_head", "re:.*mlp.gate$"],
    offload_device=torch.device("cpu"),
)

# Apply algorithms.
# model is loaded sequentially, automatically onto GPU
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

Author

cjackal commented Jul 16, 2025

@brian-dellabetta Let me test again with the head commit and remove the draft tag from this PR if successful.

@casper-hansen

@brian-dellabetta about 24 hours to quantize the full model. First few layers go fast. I don’t have memory stats.

@brian-dellabetta
Collaborator

Thanks @cjackal and @casper-hansen for the information! Feeling good that our implementation seems to be working with such a large model. @cjackal please let me know how it goes; we are looking into K2 as well.

@cjackal cjackal force-pushed the main branch 2 times, most recently from b737681 to 5288bec on July 16, 2025 15:57
cjackal added 2 commits July 17, 2025 00:58
….models`

Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
Author

cjackal commented Jul 16, 2025

So sequential onloading was the culprit in my test script; I shouldn't have copy-pasted the GPTQ script blindly.

@brian-dellabetta With your test script I got past the 9th layer in about an hour (I also run on an H100, but mine is fully isolated), which looks promising. Let me mark this PR as ready for review and wait for the job to complete.

BTW, current main is incompatible with transformers<4.52, as the star import in transformers.models was only introduced in transformers==4.52.0, so the test script raises an import error here.

@cjackal cjackal marked this pull request as ready for review July 16, 2025 16:05
Collaborator

brian-dellabetta commented Jul 16, 2025

Thanks @cjackal, glad to hear it's working now. Thanks for the heads-up on the transformers issue; 4.52.0 broke some other things for us, so we recommend always being on the latest version, even though our pins are very loose.

Collaborator

@brian-dellabetta brian-dellabetta left a comment


One nit, otherwise this looks good. I can approve CI/CD to run.

cjackal and others added 2 commits July 17, 2025 00:35
Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
…nsformers versions

Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
Collaborator

@brian-dellabetta brian-dellabetta left a comment


Thank you for the contribution and tests! I will kick off CI/CD and hopefully get this in soon

Comment on lines +4 to +7
from transformers.models.llama4.configuration_llama4 import (
    Llama4Config,
    Llama4TextConfig,
)
Collaborator


Note: these changes allow a successful run with older versions of transformers permitted by our transformers>4 pin.
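
Relatedly, a hypothetical sketch (not part of this PR's diff) of how the import could also be guarded for transformers releases that predate the llama4 module entirely:

# Hypothetical fallback, not the PR's change: tolerate transformers releases
# that do not ship the llama4 configuration module at all.
try:
    from transformers.models.llama4.configuration_llama4 import (
        Llama4Config,
        Llama4TextConfig,
    )
except ImportError:  # transformers too old to include llama4
    Llama4Config = Llama4TextConfig = None  # callers must handle the absence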

@brian-dellabetta brian-dellabetta added the ready When a PR is ready for review label Jul 17, 2025