
Commit f4bcd2d

integration-vllm-test
stack-info: PR: #2258, branch: drisspg/stack/58
1 parent 1017c7e commit f4bcd2d

1 file changed (+281 -0)

test/integration/test_vllm.py

Lines changed: 281 additions & 0 deletions
@@ -0,0 +1,281 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import importlib.util
import os
import random
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

if not TORCH_VERSION_AT_LEAST_2_7:
    pytest.skip("Requires PyTorch 2.7 or higher", allow_module_level=True)


VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None
TRANSFORMERS_AVAILABLE = importlib.util.find_spec("transformers") is not None

if not VLLM_AVAILABLE:
    pytest.skip("vLLM not installed", allow_module_level=True)

if not TRANSFORMERS_AVAILABLE:
    pytest.skip("transformers not installed", allow_module_level=True)

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from vllm import LLM, SamplingParams

from torchao.prototype.mx_formats import MXGemmKernelChoice
from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import (
    CutlassInt4PackedLayout,
    Float8DynamicActivationFloat8WeightConfig,
    GemliteUIntXWeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
)


def get_tests() -> List[Tuple[str, str]]:
    """Get all the tests based on device info"""

    BASE_TESTS = [("int8_weight_only", "per_tensor")]
    SM89_TESTS = [("fp8", "per_tensor"), ("fp8", "per_row")]
    SM90_ONLY_TESTS = [("A8W4", "per_tensor")]
    SM100_TESTS = [("mxfp8", "per_tensor")]

    # Check CUDA availability first
    if not torch.cuda.is_available():
        return []  # No CUDA, no tests

    major, minor = torch.cuda.get_device_capability()

    # Build test list based on compute capability
    all_tests = []

    # Always include base tests if we have CUDA
    all_tests.extend(BASE_TESTS)

    # Add SM89+ tests
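    # (note: the fp8 configs rely on hardware float8 matmul support, which is
    # available from compute capability 8.9 onward, hence the gate below)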
    if major > 8 or (major == 8 and minor >= 9):
        all_tests.extend(SM89_TESTS)

    # Add SM100+ tests
    if major >= 10:
        all_tests.extend(SM100_TESTS)

    # Only works on SM90
    if major == 9:
        all_tests.extend(SM90_ONLY_TESTS)

    return all_tests


class TestVLLMIntegration:
    """Integration tests for vLLM with quantized models."""

    @classmethod
    def setup_class(cls):
        """Set up test environment."""
        # Set seeds for reproducibility
        cls.set_seed(42)

        # Set vLLM environment variables
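        # (note: VLLM_USE_V1 opts into vLLM's V1 engine; disabling V1
        # multiprocessing keeps the engine core in this process, which is
        # assumed here to make failures easier to surface under pytest)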
os.environ["VLLM_USE_V1"] = "1"
96+
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
97+
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
98+
99+
@classmethod
100+
def teardown_class(cls):
101+
"""Clean up after all tests."""
102+
torch.cuda.empty_cache()
103+
import gc
104+
105+
gc.collect()
106+
107+
def setup_method(self, method):
108+
"""Clean up before each test method."""
109+
torch.cuda.empty_cache()
110+
import gc
111+
112+
gc.collect()
113+
114+
def teardown_method(self, method):
115+
"""Clean up after each test method."""
116+
torch.cuda.empty_cache()
117+
import gc
118+
119+
gc.collect()
120+
121+
@staticmethod
122+
def set_seed(seed):
123+
"""Set random seeds for reproducibility."""
124+
random.seed(seed)
125+
np.random.seed(seed)
126+
torch.manual_seed(seed)
127+
torch.cuda.manual_seed_all(seed)
128+
129+
def get_quantization_config(self, quant_type: str, granularity: str = "per_tensor"):
130+
"""Create TorchAo quantization config based on provided parameters."""
131+
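        # (note: get_tests() above selects only a subset of these branches;
        # e.g. autoquant, gemlite, A4W4 and mxfp4 are not parametrized
        # automatically)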
        granularity_mapping = {
            "per_row": PerRow(),
            "per_tensor": PerTensor(),
        }

        gran = granularity_mapping[granularity]

        if quant_type == "autoquant":
            return TorchAoConfig("autoquant", min_sqnr=40.0)
        elif quant_type == "fp8":
            return TorchAoConfig(
                Float8DynamicActivationFloat8WeightConfig(granularity=gran)
            )
        elif quant_type == "int4_weight_only":
            return TorchAoConfig(Int4WeightOnlyConfig(group_size=128))
        elif quant_type == "int8_weight_only":
            return TorchAoConfig(Int8WeightOnlyConfig())
        elif quant_type == "int8_dynamic_act_int8_weight":
            return TorchAoConfig(Int8DynamicActivationInt8WeightConfig())
        elif quant_type == "gemlite":
            return TorchAoConfig(GemliteUIntXWeightOnlyConfig())
        elif quant_type == "A4W4":
            return TorchAoConfig(Int4DynamicActivationInt4WeightConfig())
        elif quant_type == "A8W4":
            return TorchAoConfig(
                Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout())
            )
        elif quant_type == "mxfp8":
            return TorchAoConfig(MXFPInferenceConfig())
        elif quant_type == "mxfp4":
            return TorchAoConfig(
                MXFPInferenceConfig(
                    activation_dtype=torch.float4_e2m1fn_x2,
                    weight_dtype=torch.float4_e2m1fn_x2,
                    block_size=32,
                    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
                )
            )
        else:
            raise ValueError(f"Unsupported quantization type: {quant_type}")

    def quantize_and_save_model(
        self,
        model_name: str,
        quant_type: str,
        output_dir: Path,
        granularity: str = "per_tensor",
    ):
        """Quantize a model and save it to disk."""
        # Get quantization config
        quantization_config = self.get_quantization_config(quant_type, granularity)

        # Load and quantize model
        print(f"Loading and quantizing model with {quant_type}...")
        quantized_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="bfloat16",
            device_map="cuda",
            quantization_config=quantization_config,
        )

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Quick test generation to verify model works
        test_input = "Hello, world!"
        input_ids = tokenizer(test_input, return_tensors="pt").to(
            quantized_model.device
        )

        with torch.no_grad():
            output = quantized_model.generate(**input_ids, max_new_tokens=5)
            decoded = tokenizer.decode(output[0], skip_special_tokens=True)
            print(f"Quick test - Input: {test_input}, Output: {decoded}")

        # Save quantized model
        print(f"Saving quantized model to {output_dir}...")
        quantized_model.save_pretrained(output_dir, safe_serialization=False)
        tokenizer.save_pretrained(output_dir)

        # Clean up to free memory
        del quantized_model
        torch.cuda.empty_cache()

        return output_dir

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.skipif(not VLLM_AVAILABLE, reason="vLLM not installed")
    @pytest.mark.parametrize("quant_type,granularity", get_tests())
    @pytest.mark.parametrize("compile", [True, False])
    @pytest.mark.parametrize(
        "tp_size", [1, 2] if torch.cuda.device_count() > 1 else [1]
    )
    def test_vllm_smoke_test(self, tmp_path, quant_type, granularity, compile, tp_size):
        """Test vLLM generation with quantized models."""
        # Skip per_row tests if not supported
        torch._dynamo.reset()
        if granularity == "per_row" and not torch.cuda.get_device_capability()[0] >= 9:
            pytest.skip("Per-row quantization requires SM90+")

        # Use a small model for testing
        base_model = "facebook/opt-125m"

        # Quantize the model
        output_dir = tmp_path / f"{quant_type}-{granularity}-opt-125m"
        quantized_model_path = self.quantize_and_save_model(
            base_model, quant_type, output_dir, granularity
        )

        # Test generation with vLLM
        sampling_params = SamplingParams(
            temperature=0.8,
            top_p=0.95,
            seed=42,
            max_tokens=16,  # Small for testing
        )

        # Create LLM instance
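        # (enforce_eager=True makes vLLM skip its torch.compile / CUDA-graph
        # path, so compile=False exercises plain eager execution)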
        llm = LLM(
            model=str(quantized_model_path),
            tensor_parallel_size=tp_size,
            enforce_eager=not compile,
            dtype="bfloat16",
            num_gpu_blocks_override=128,
        )

        # Test prompts
        prompts = [
            "Hello, my name is",
            "The capital of France is",
        ]

        # Generate outputs
        outputs = llm.generate(prompts, sampling_params)

        # Verify outputs
        assert len(outputs) == len(prompts)
        for output in outputs:
            assert output.prompt in prompts
            assert len(output.outputs) > 0
            generated_text = output.outputs[0].text
            assert isinstance(generated_text, str)
            assert len(generated_text) > 0

        # Clean up
        del llm
        torch.cuda.empty_cache()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
