
Commit 3066141

integration-vllm-test
stack-info: PR: #2258, branch: drisspg/stack/58

1 parent: 1017c7e

File tree

1 file changed: +245 -0 lines


test/integration/test_vllm.py

Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import importlib.util
import os
import random
from pathlib import Path
from typing import List

import numpy as np
import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

if not TORCH_VERSION_AT_LEAST_2_7:
    pytest.skip("Requires PyTorch 2.7 or higher", allow_module_level=True)


VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None
TRANSFORMERS_AVAILABLE = importlib.util.find_spec("transformers") is not None

if not VLLM_AVAILABLE:
    pytest.skip("vLLM not installed", allow_module_level=True)

if not TRANSFORMERS_AVAILABLE:
    pytest.skip("transformers not installed", allow_module_level=True)

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from vllm import LLM, SamplingParams

from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import (
    CutlassInt4PackedLayout,
    Float8DynamicActivationFloat8WeightConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8WeightOnlyConfig,
)


def get_tests() -> List[TorchAoConfig]:
    """Get all the tests based off of device info"""

    # Helper objects for granularity
    per_tensor = PerTensor()
    per_row = PerRow()

    BASE_TESTS = [TorchAoConfig(Int8WeightOnlyConfig())]
    SM89_TESTS = [
        TorchAoConfig(
            Float8DynamicActivationFloat8WeightConfig(granularity=per_tensor)
        ),
        TorchAoConfig(Float8DynamicActivationFloat8WeightConfig(granularity=per_row)),
    ]
    SM90_ONLY_TESTS = [
        TorchAoConfig(
            Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout())
        )
    ]
    SM100_TESTS = [
        # TorchAoConfig(MXFPInferenceConfig())
    ]  # Failing due to: https://github.com/pytorch/ao/issues/2239

    # Check CUDA availability first
    if not torch.cuda.is_available():
        return []  # No CUDA, no tests

    major, minor = torch.cuda.get_device_capability()

    # Build test list based on compute capability
    all_tests = []

    # Always include base tests if we have CUDA
    all_tests.extend(BASE_TESTS)

    # Add SM89+ tests
    if major > 8 or (major == 8 and minor >= 9):
        all_tests.extend(SM89_TESTS)

    # Add SM100+ tests
    if major >= 10:
        all_tests.extend(SM100_TESTS)

    # Only works on SM90
    if major == 9:
        all_tests.extend(SM90_ONLY_TESTS)

    return all_tests


class TestVLLMIntegration:
    """Integration tests for vLLM with quantized models."""

    @classmethod
    def setup_class(cls):
        """Set up test environment."""
        # Set seeds for reproducibility
        cls.set_seed(42)

        # Set vLLM environment variables
        os.environ["VLLM_USE_V1"] = "1"
        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
        os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"

    @classmethod
    def teardown_class(cls):
        """Clean up after all tests."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    def setup_method(self, method):
        """Clean up before each test method."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    def teardown_method(self, method):
        """Clean up after each test method."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    @staticmethod
    def set_seed(seed):
        """Set random seeds for reproducibility."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def quantize_and_save_model(
        self,
        model_name: str,
        quantization_config: TorchAoConfig,
        output_dir: Path,
    ):
        """Quantize a model and save it to disk."""
        # Load and quantize model
        quantized_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="bfloat16",
            device_map="cuda",
            quantization_config=quantization_config,
        )

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Quick test generation to verify the model works
        test_input = "Hello, world!"
        input_ids = tokenizer(test_input, return_tensors="pt").to(
            quantized_model.device
        )

        with torch.no_grad():
            output = quantized_model.generate(**input_ids, max_new_tokens=5)
            decoded = tokenizer.decode(output[0], skip_special_tokens=True)
            print(f"Quick test - Input: {test_input}, Output: {decoded}")

        # Save quantized model
        print(f"Saving quantized model to {output_dir}...")
        quantized_model.save_pretrained(output_dir, safe_serialization=False)
        tokenizer.save_pretrained(output_dir)

        # Clean up to free memory
        del quantized_model
        torch.cuda.empty_cache()

        return output_dir

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.skipif(not VLLM_AVAILABLE, reason="vLLM not installed")
    @pytest.mark.parametrize(
        "quantization_config", get_tests(), ids=lambda config: f"{config.quant_type}"
    )
    @pytest.mark.parametrize("compile", [True, False])
    @pytest.mark.parametrize(
        "tp_size", [1, 2] if torch.cuda.device_count() > 1 else [1]
    )
    def test_vllm_smoke_test(self, tmp_path, quantization_config, compile, tp_size):
        """Test vLLM generation with quantized models."""
        # Reset compile/dynamo state between parameterized runs
        torch._dynamo.reset()

        # Use a small model for testing
        base_model = "facebook/opt-125m"

        # Create a descriptive name for the output directory
        config_name = str(quantization_config).replace("/", "_").replace(" ", "_")[:50]
        output_dir = tmp_path / f"{config_name}-opt-125m"

        # Quantize the model
        quantized_model_path = self.quantize_and_save_model(
            base_model, quantization_config, output_dir
        )

        # Test generation with vLLM
        sampling_params = SamplingParams(
            temperature=0.8,
            top_p=0.95,
            seed=42,
            max_tokens=16,  # Small for testing
        )

        # Create LLM instance
        llm = LLM(
            model=str(quantized_model_path),
            tensor_parallel_size=tp_size,
            enforce_eager=not compile,
            dtype="bfloat16",
            num_gpu_blocks_override=128,
        )

        # Test prompts
        prompts = [
            "Hello, my name is",
            "The capital of France is",
        ]

        # Generate outputs
        outputs = llm.generate(prompts, sampling_params)

        # Verify outputs
        assert len(outputs) == len(prompts)
        for output in outputs:
            assert output.prompt in prompts
            assert len(output.outputs) > 0
            generated_text = output.outputs[0].text
            assert isinstance(generated_text, str)
            assert len(generated_text) > 0

        # Clean up
        del llm
        torch.cuda.empty_cache()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
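
For reference, the quantize-then-serve flow that test_vllm_smoke_test exercises can be reproduced as a small standalone script. The sketch below is illustrative only and is not part of the commit: the output directory and prompt are arbitrary, Int8WeightOnlyConfig is just one of the configs returned by get_tests(), and it assumes a CUDA machine with torchao, transformers, and vLLM installed.

# Minimal sketch of the flow in test_vllm_smoke_test (assumptions noted above).
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from vllm import LLM, SamplingParams

from torchao.quantization.quant_api import Int8WeightOnlyConfig

model_name = "facebook/opt-125m"     # same small model used by the test
output_dir = "/tmp/opt-125m-int8wo"  # hypothetical save location

# 1. Quantize with torchao via transformers' TorchAoConfig and save to disk.
quantized = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="bfloat16",
    device_map="cuda",
    quantization_config=TorchAoConfig(Int8WeightOnlyConfig()),
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
quantized.save_pretrained(output_dir, safe_serialization=False)
tokenizer.save_pretrained(output_dir)

# 2. Load the saved checkpoint with vLLM and generate.
llm = LLM(model=output_dir, dtype="bfloat16")
params = SamplingParams(temperature=0.8, top_p=0.95, seed=42, max_tokens=16)
for out in llm.generate(["The capital of France is"], params):
    print(out.outputs[0].text)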

0 commit comments