
Commit 0f91eb8

vanbasten23 authored and patrickvonplaten committed
Use w8a8 quantized matmul Pallas kernel (vllm-project#19170)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 7f59720 commit 0f91eb8

File tree

requirements/tpu.txt
tests/tpu/test_quantization_accuracy.py
tests/v1/tpu/test_basic.py
vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py

4 files changed: 50 additions & 19 deletions

requirements/tpu.txt

Lines changed: 5 additions & 5 deletions
@@ -18,9 +18,9 @@ setuptools==78.1.0
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch==2.9.0.dev20250703
-torchvision==0.24.0.dev20250703
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch==2.9.0.dev20250711
+torchvision==0.24.0.dev20250711
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
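To confirm that an environment actually picked up the bumped nightlies, a small sanity-check snippet can be used (illustrative only, not part of this commit; the version strings are the pins from the diff above):

import torch
import torch_xla

# Pins taken from requirements/tpu.txt in this commit.
assert torch.__version__.startswith("2.9.0.dev20250711"), torch.__version__
assert torch_xla.__version__.startswith("2.9.0"), torch_xla.__version__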

tests/tpu/test_quantization_accuracy.py

Lines changed: 4 additions & 4 deletions
@@ -14,7 +14,7 @@
 @dataclass
 class GSM8KAccuracyTestConfig:
     model_name: str
-    excepted_value: float
+    expected_value: float
 
     def get_model_args(self) -> str:
         return (f"pretrained={self.model_name},"
@@ -25,13 +25,13 @@ def get_model_args(self) -> str:
 ACCURACY_CONFIGS = [
     GSM8KAccuracyTestConfig(
         model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
-        excepted_value=0.76),  # no bias
+        expected_value=0.76),  # no bias
     # NOTE(rob): We cannot re-initialize vLLM in the same process for TPU,
     # so only one of these tests can run in a single call to pytest. As
     # a follow up, move this into the LM-EVAL section of the CI.
     # GSM8KAccuracyTestConfig(
     #     model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
-    #     excepted_value=0.66),  # bias in QKV layers
+    #     expected_value=0.66),  # bias in QKV layers
 ]
@@ -45,7 +45,7 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):
         batch_size="auto",
     )
 
-    EXPECTED_VALUE = config.excepted_value
+    EXPECTED_VALUE = config.expected_value
     measured_value = results["results"][TASK][FILTER]
     assert (measured_value - RTOL < EXPECTED_VALUE
             and measured_value + RTOL > EXPECTED_VALUE
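The two one-sided comparisons at the end of test_gsm8k_correctness amount to an absolute tolerance check; a minimal restatement for reference (the numeric values below are made up for illustration and are not the test's RTOL):

def within_tolerance(measured: float, expected: float, tol: float) -> bool:
    # Equivalent to: measured - tol < expected and measured + tol > expected
    return abs(measured - expected) < tol

assert within_tolerance(0.757, 0.76, 0.03)  # hypothetical numbers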

tests/v1/tpu/test_basic.py

Lines changed: 32 additions & 0 deletions
@@ -145,3 +145,35 @@ def test_gemma3_27b_with_text_input_and_tp(
     for output, answer in zip(vllm_outputs, answers):
         generated_text = output[1]
         assert answer in generated_text
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a basic test for TPU only")
+def test_w8a8_quantization(
+    vllm_runner: type[VllmRunner],
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
+    max_tokens = 5
+    tensor_parallel_size = 1
+    max_num_seqs = 4
+
+    prompt = "The next numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with vllm_runner(
+                model,
+                max_num_batched_tokens=64,
+                max_model_len=4096,
+                gpu_memory_utilization=0.7,
+                max_num_seqs=max_num_seqs,
+                tensor_parallel_size=tensor_parallel_size) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+        output = vllm_outputs[0][1]
+
+    assert "1024" in output or "0, 1" in output

vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py

Lines changed: 9 additions & 10 deletions
@@ -90,16 +90,15 @@ def apply_weights(self,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         w_q, w_s, _, _, _ = self._get_weight_params(layer)
 
-        import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
-        out = torch.ops.xla.quantized_matmul(x,
-                                             w_q,
-                                             w_s,
-                                             zero_point=None,
-                                             block_size=-1,
-                                             int4_weight=False,
-                                             quantize_activation=True)
-        # `quantized_matmul` output is fp32, cast it down to bf16 for perf
-        out = out.to(x.dtype)
+        # Required to register custom ops.
+        import torch_xla.experimental.custom_kernel  # noqa: F401
+        out = torch.ops.xla.quantized_matmul_int8(
+            x,
+            w_q,
+            w_s,
+            quantize_activation=True,
+        )
+
         # Explicitly capture control flow to make dynamo happy.
         # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method  # noqa: E501
         return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
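For reference, the new op can also be exercised directly; a rough, self-contained sketch assuming a TPU device and the pinned torch_xla nightly (tensor shapes, dtypes, and the per-channel scale layout here are illustrative assumptions, not taken from the kernel's documentation):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.custom_kernel  # noqa: F401  # registers torch.ops.xla.*

device = xm.xla_device()

# Illustrative w8a8 inputs: bf16 activations, int8 weight, per-output-channel scales.
x = torch.randn(8, 512, dtype=torch.bfloat16).to(device)
w_q = torch.randint(-128, 127, (1024, 512), dtype=torch.int8).to(device)
w_s = torch.rand(1024, dtype=torch.bfloat16).to(device)

# quantize_activation=True selects the w8a8 path backed by the Pallas kernel.
out = torch.ops.xla.quantized_matmul_int8(x, w_q, w_s, quantize_activation=True)
print(out.shape)  # expected: (8, 1024)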
