@@ -9,12 +9,21 @@
 
 import numpy as np
 import pytest
+import torch
+import torch.multiprocessing as mp
 from PIL import Image, ImageChops
 
+from tests.utils import multi_gpu_test
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import (init_distributed_environment,
+                                             initialize_model_parallel)
 from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import (MediaConnector,
-                                   merge_and_sort_multimodal_metadata)
+                                   merge_and_sort_multimodal_metadata,
+                                   run_dp_sharded_vision_model)
+from vllm.platforms import current_platform
+from vllm.utils import get_open_port, update_environment_variables
 
 if TYPE_CHECKING:
     from vllm.multimodal.hasher import MultiModalHashDict
@@ -413,3 +422,90 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
     assert modalities == expected_modalities
     assert ranges == expected_ranges
     assert hashes == expected_hashes
+
+
+class SimpleLinearModel(torch.nn.Module):
+    """A simple linear vision model for testing."""
+
+    def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
+        super().__init__()
+        self.flatten = torch.nn.Flatten()
+        self.linear = torch.nn.Linear(input_dim, output_dim)
+
+    def forward(self, x: torch.Tensor):
+        # Flatten the input and apply the linear transformation
+        x = self.flatten(x)
+        return self.linear(x)
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize(
+    "batch_size",
+    [
+        1,  # Single image
+        4,  # Small batch
+        5,  # Odd batch size (for testing padding)
+    ],
+)
+def test_run_dp_sharded_vision_model(batch_size: int):
+    world_size = 2
+    # Launch one worker process per GPU
+    mp.spawn(
+        run_dp_sharded_vision_model_vs_direct,
+        args=(
+            world_size,
+            batch_size,
+            get_open_port(),
+        ),
+        nprocs=world_size,
+    )
+
+
+def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
+                                          batch_size: int, master_port: int):
+    """
+    Test that run_dp_sharded_vision_model produces the same results as
+    calling the model directly.
+    """
+
+    # Set the random seed for reproducibility
+    current_platform.seed_everything(0)
+
+    device = torch.device(f"cuda:{local_rank}")
+    torch.cuda.set_device(device)
+    torch.set_default_device(device)
+
+    update_environment_variables({
+        'RANK': str(local_rank),
+        'LOCAL_RANK': str(local_rank),
+        'WORLD_SIZE': str(world_size),
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': str(master_port),
+    })
+
+    # Initialize the distributed environment and model-parallel groups
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+    # Create a test input tensor
+    image_input = torch.randn(batch_size, 3, 224, 224)
+
+    # Create a simple linear model
+    vision_model = SimpleLinearModel()
+
+    # Run the model directly on the full input
+    with torch.inference_mode():
+        direct_output = vision_model(image_input)
+
+    # Run the model through the sharded function
+    with torch.inference_mode():
+        sharded_output = run_dp_sharded_vision_model(image_input, vision_model)
+
+    # Check that the world size is set up correctly
+    assert get_tensor_model_parallel_world_size() == world_size
+
+    # Check that the outputs have the same shape
+    assert direct_output.shape == sharded_output.shape
+
+    # Check that the outputs match (identical up to floating-point tolerance)
+    assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5)
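
Note for readers: run_dp_sharded_vision_model is only imported in this diff, not defined here. As a hedged sketch of the pad/shard/all-gather pattern the test exercises (the batch_size=5 case targets the padding step), the plain torch.distributed version below captures the idea; the helper name sharded_vision_forward_sketch and the use of the default process group instead of vLLM's tensor-parallel group helpers are assumptions, not vLLM's actual implementation.

import torch
import torch.distributed as dist


def sharded_vision_forward_sketch(image_input: torch.Tensor,
                                  vision_model: torch.nn.Module
                                  ) -> torch.Tensor:
    # Assumption: the default process group is initialized and every rank
    # holds the same model weights and the same full input batch.
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Pad the batch to a multiple of world_size so it splits evenly
    # (this is why batch_size=5 on 2 GPUs is an interesting case).
    num_samples = image_input.shape[0]
    pad = (-num_samples) % world_size
    if pad:
        padding = image_input.new_zeros((pad, *image_input.shape[1:]))
        image_input = torch.cat([image_input, padding], dim=0)

    # Each rank runs the model on its own contiguous shard of the batch.
    shard_size = image_input.shape[0] // world_size
    shard = image_input[rank * shard_size:(rank + 1) * shard_size]
    local_out = vision_model(shard)

    # All-gather the per-rank outputs and drop the padded rows, so every
    # rank ends up with the same full-batch result, matching direct_output.
    gathered = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(gathered, local_out)
    return torch.cat(gathered, dim=0)[:num_samples]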