|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import importlib.util |
| 8 | +import os |
| 9 | +import random |
| 10 | +from pathlib import Path |
| 11 | +from typing import List |
| 12 | + |
| 13 | +import numpy as np |
| 14 | +import pytest |
| 15 | +import torch |
| 16 | + |
| 17 | +from torchao.utils import TORCH_VERSION_AT_LEAST_2_7 |
| 18 | + |
| 19 | +if not TORCH_VERSION_AT_LEAST_2_7: |
| 20 | + pytest.skip("Requires PyTorch 2.7 or higher", allow_module_level=True) |
| 21 | + |
| 22 | + |
| 23 | +VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None |
| 24 | +TRANSFORMERS_AVAILABLE = importlib.util.find_spec("transformers") is not None |
| 25 | + |
| 26 | +if not VLLM_AVAILABLE: |
| 27 | + pytest.skip("vLLM not installed", allow_module_level=True) |
| 28 | + |
| 29 | +if not TRANSFORMERS_AVAILABLE: |
| 30 | + pytest.skip("transformers not installed", allow_module_level=True) |
| 31 | + |
| 32 | +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig |
| 33 | +from vllm import LLM, SamplingParams |
| 34 | + |
| 35 | +from torchao.quantization.granularity import PerRow, PerTensor |
| 36 | +from torchao.quantization.quant_api import ( |
| 37 | + CutlassInt4PackedLayout, |
| 38 | + Float8DynamicActivationFloat8WeightConfig, |
| 39 | + Int8DynamicActivationInt4WeightConfig, |
| 40 | + Int8WeightOnlyConfig, |
| 41 | +) |
| 42 | + |
| 43 | + |
| 44 | +def get_tests() -> List[TorchAoConfig]: |
| 45 | + """Get all the tests based off of device info""" |
| 46 | + |
| 47 | + # Helper objects for granularity |
| 48 | + per_tensor = PerTensor() |
| 49 | + per_row = PerRow() |
| 50 | + |
| 51 | + BASE_TESTS = [TorchAoConfig(Int8WeightOnlyConfig())] |
| 52 | + SM89_TESTS = [ |
| 53 | + TorchAoConfig( |
| 54 | + Float8DynamicActivationFloat8WeightConfig(granularity=per_tensor) |
| 55 | + ), |
| 56 | + TorchAoConfig(Float8DynamicActivationFloat8WeightConfig(granularity=per_row)), |
| 57 | + ] |
| 58 | + SM90_ONLY_TESTS = [ |
| 59 | + TorchAoConfig( |
| 60 | + Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout()) |
| 61 | + ) |
| 62 | + ] |
| 63 | + SM100_TESTS = [ |
| 64 | + # TorchAoConfig(MXFPInferenceConfig()) |
| 65 | + ] # Failing for : https://github.com/pytorch/ao/issues/2239 |
| 66 | + |
| 67 | + # Check CUDA availability first |
| 68 | + if not torch.cuda.is_available(): |
| 69 | + return [] # No CUDA, no tests |
| 70 | + |
| 71 | + major, minor = torch.cuda.get_device_capability() |
| 72 | + |
| 73 | + # Build test list based on compute capability |
| 74 | + all_tests = [] |
| 75 | + |
| 76 | + # Always include base tests if we have CUDA |
| 77 | + all_tests.extend(BASE_TESTS) |
| 78 | + |
| 79 | + # Add SM89+ tests |
| 80 | + if major > 8 or (major == 8 and minor >= 9): |
| 81 | + all_tests.extend(SM89_TESTS) |
| 82 | + |
| 83 | + # Add SM100+ tests |
| 84 | + if major >= 10: |
| 85 | + all_tests.extend(SM100_TESTS) |
| 86 | + |
| 87 | + # Only work for sm 90 |
| 88 | + if major == 9: |
| 89 | + all_tests.extend(SM90_ONLY_TESTS) |
| 90 | + |
| 91 | + return all_tests |
| 92 | + |
| 93 | + |
| 94 | +class TestVLLMIntegration: |
| 95 | + """Integration tests for vLLM with quantized models.""" |
| 96 | + |
| 97 | + @classmethod |
| 98 | + def setup_class(cls): |
| 99 | + """Set up test environment.""" |
| 100 | + # Set seeds for reproducibility |
| 101 | + cls.set_seed(42) |
| 102 | + |
| 103 | + # Set vLLM environment variables |
| 104 | + os.environ["VLLM_USE_V1"] = "1" |
| 105 | + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" |
| 106 | + os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1" |
| 107 | + |
| 108 | + @classmethod |
| 109 | + def teardown_class(cls): |
| 110 | + """Clean up after all tests.""" |
| 111 | + torch.cuda.empty_cache() |
| 112 | + import gc |
| 113 | + |
| 114 | + gc.collect() |
| 115 | + |
| 116 | + def setup_method(self, method): |
| 117 | + """Clean up before each test method.""" |
| 118 | + torch.cuda.empty_cache() |
| 119 | + import gc |
| 120 | + |
| 121 | + gc.collect() |
| 122 | + |
| 123 | + def teardown_method(self, method): |
| 124 | + """Clean up after each test method.""" |
| 125 | + torch.cuda.empty_cache() |
| 126 | + import gc |
| 127 | + |
| 128 | + gc.collect() |
| 129 | + |
| 130 | + @staticmethod |
| 131 | + def set_seed(seed): |
| 132 | + """Set random seeds for reproducibility.""" |
| 133 | + random.seed(seed) |
| 134 | + np.random.seed(seed) |
| 135 | + torch.manual_seed(seed) |
| 136 | + torch.cuda.manual_seed_all(seed) |
| 137 | + |
| 138 | + def quantize_and_save_model( |
| 139 | + self, |
| 140 | + model_name: str, |
| 141 | + quantization_config: TorchAoConfig, |
| 142 | + output_dir: Path, |
| 143 | + ): |
| 144 | + """Quantize a model and save it to disk.""" |
| 145 | + # Load and quantize model |
| 146 | + quantized_model = AutoModelForCausalLM.from_pretrained( |
| 147 | + model_name, |
| 148 | + torch_dtype="bfloat16", |
| 149 | + device_map="cuda", |
| 150 | + quantization_config=quantization_config, |
| 151 | + ) |
| 152 | + |
| 153 | + # Load tokenizer |
| 154 | + tokenizer = AutoTokenizer.from_pretrained(model_name) |
| 155 | + |
| 156 | + # Quick test generation to verify model works |
| 157 | + test_input = "Hello, world!" |
| 158 | + input_ids = tokenizer(test_input, return_tensors="pt").to( |
| 159 | + quantized_model.device |
| 160 | + ) |
| 161 | + |
| 162 | + with torch.no_grad(): |
| 163 | + output = quantized_model.generate(**input_ids, max_new_tokens=5) |
| 164 | + decoded = tokenizer.decode(output[0], skip_special_tokens=True) |
| 165 | + print(f"Quick test - Input: {test_input}, Output: {decoded}") |
| 166 | + |
| 167 | + # Save quantized model |
| 168 | + print(f"Saving quantized model to {output_dir}...") |
| 169 | + quantized_model.save_pretrained(output_dir, safe_serialization=False) |
| 170 | + tokenizer.save_pretrained(output_dir) |
| 171 | + |
| 172 | + # Clean up to free memory |
| 173 | + del quantized_model |
| 174 | + torch.cuda.empty_cache() |
| 175 | + |
| 176 | + return output_dir |
| 177 | + |
| 178 | + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") |
| 179 | + @pytest.mark.skipif(not VLLM_AVAILABLE, reason="vLLM not installed") |
| 180 | + @pytest.mark.parametrize( |
| 181 | + "quantization_config", get_tests(), ids=lambda config: f"{config.quant_type}" |
| 182 | + ) |
| 183 | + @pytest.mark.parametrize("compile", [True, False]) |
| 184 | + @pytest.mark.parametrize( |
| 185 | + "tp_size", [1, 2] if torch.cuda.device_count() > 1 else [1] |
| 186 | + ) |
| 187 | + def test_vllm_smoke_test(self, tmp_path, quantization_config, compile, tp_size): |
| 188 | + """Test vLLM generation with quantized models.""" |
| 189 | + # Skip per_row tests if not supported |
| 190 | + torch._dynamo.reset() |
| 191 | + |
| 192 | + # Use a small model for testing |
| 193 | + base_model = "facebook/opt-125m" |
| 194 | + |
| 195 | + # Create a descriptive name for the output directory |
| 196 | + config_name = str(quantization_config).replace("/", "_").replace(" ", "_")[:50] |
| 197 | + output_dir = tmp_path / f"{config_name}-opt-125m" |
| 198 | + |
| 199 | + # Quantize the model |
| 200 | + quantized_model_path = self.quantize_and_save_model( |
| 201 | + base_model, quantization_config, output_dir |
| 202 | + ) |
| 203 | + |
| 204 | + # Test generation with vLLM |
| 205 | + sampling_params = SamplingParams( |
| 206 | + temperature=0.8, |
| 207 | + top_p=0.95, |
| 208 | + seed=42, |
| 209 | + max_tokens=16, # Small for testing |
| 210 | + ) |
| 211 | + |
| 212 | + # Create LLM instance |
| 213 | + llm = LLM( |
| 214 | + model=str(quantized_model_path), |
| 215 | + tensor_parallel_size=tp_size, |
| 216 | + enforce_eager=not compile, |
| 217 | + dtype="bfloat16", |
| 218 | + num_gpu_blocks_override=128, |
| 219 | + ) |
| 220 | + |
| 221 | + # Test prompts |
| 222 | + prompts = [ |
| 223 | + "Hello, my name is", |
| 224 | + "The capital of France is", |
| 225 | + ] |
| 226 | + |
| 227 | + # Generate outputs |
| 228 | + outputs = llm.generate(prompts, sampling_params) |
| 229 | + |
| 230 | + # Verify outputs |
| 231 | + assert len(outputs) == len(prompts) |
| 232 | + for output in outputs: |
| 233 | + assert output.prompt in prompts |
| 234 | + assert len(output.outputs) > 0 |
| 235 | + generated_text = output.outputs[0].text |
| 236 | + assert isinstance(generated_text, str) |
| 237 | + assert len(generated_text) > 0 |
| 238 | + |
| 239 | + # Clean up |
| 240 | + del llm |
| 241 | + torch.cuda.empty_cache() |
| 242 | + |
| 243 | + |
| 244 | +if __name__ == "__main__": |
| 245 | + pytest.main([__file__, "-v"]) |
0 commit comments