Skip to content

Commit f168b85

Browse files
cryptopic and Siqi Yan authored
Unit Test for run_dp_sharded_vision_model (#19103)
Signed-off-by: Siqi Yan <siqi@meta.com> Co-authored-by: Siqi Yan <siqi@meta.com>
1 parent da511d5 commit f168b85

File tree

1 file changed

+97
-1
lines changed

1 file changed

+97
-1
lines changed

tests/multimodal/test_utils.py

Lines changed: 97 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,12 +9,21 @@
99

1010
import numpy as np
1111
import pytest
12+
import torch
13+
import torch.multiprocessing as mp
1214
from PIL import Image, ImageChops
1315

16+
from tests.utils import multi_gpu_test
17+
from vllm.distributed import get_tensor_model_parallel_world_size
18+
from vllm.distributed.parallel_state import (init_distributed_environment,
19+
initialize_model_parallel)
1420
from vllm.multimodal.image import convert_image_mode
1521
from vllm.multimodal.inputs import PlaceholderRange
1622
from vllm.multimodal.utils import (MediaConnector,
17-
merge_and_sort_multimodal_metadata)
23+
merge_and_sort_multimodal_metadata,
24+
run_dp_sharded_vision_model)
25+
from vllm.platforms import current_platform
26+
from vllm.utils import get_open_port, update_environment_variables
1827

1928
if TYPE_CHECKING:
2029
from vllm.multimodal.hasher import MultiModalHashDict
@@ -413,3 +422,90 @@ def test_merge_and_sort_multimodal_metadata_with_interleaving():
413422
assert modalities == expected_modalities
414423
assert ranges == expected_ranges
415424
assert hashes == expected_hashes
425+
426+
427+
class SimpleLinearModel(torch.nn.Module):
    """A simple linear vision model for testing.

    Flattens each sample of the input batch and applies a single linear
    projection to it.

    Args:
        input_dim: Number of features per sample after flattening. The
            default is 3 * 224 * 224.
        output_dim: Number of output features per sample.
    """

    def __init__(self, input_dim: int = 3 * 224 * 224, output_dim: int = 32):
        super().__init__()
        self.flatten = torch.nn.Flatten()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the linear projection of the flattened input ``x``."""
        # Flatten the input and apply linear transformation
        x = self.flatten(x)
        return self.linear(x)
439+
440+
441+
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
    "batch_size",
    [
        1,  # Single image
        4,  # Small batch
        5,  # Odd batch size (for testing padding)
    ],
)
def test_run_dp_sharded_vision_model(batch_size: int):
    """Check run_dp_sharded_vision_model against a direct model call.

    Spawns ``world_size`` worker processes that each run
    ``run_dp_sharded_vision_model_vs_direct`` for the given batch size.
    """
    world_size = 2
    # Launch processes. mp.spawn passes the process index as the first
    # positional argument, ahead of the values listed in ``args``.
    mp.spawn(
        run_dp_sharded_vision_model_vs_direct,
        args=(
            world_size,
            batch_size,
            get_open_port(),
        ),
        nprocs=world_size,
    )
462+
463+
464+
def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int,
                                          batch_size: int, master_port: int):
    """
    Test that run_dp_sharded_vision_model produces the same results as
    calling the model directly.

    Intended to run as one worker process per rank (see
    test_run_dp_sharded_vision_model).

    Args:
        local_rank: Index of this worker; also used as the CUDA device
            index and as the distributed rank.
        world_size: Total number of worker processes; used as the tensor
            model parallel size.
        batch_size: Number of images in the random test input.
        master_port: Port exported as MASTER_PORT for the distributed
            setup.
    """

    # Set random seed for reproducibility. Every rank uses the same seed,
    # presumably so that all ranks build the same input and model weights.
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(master_port),
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # Check that the world size is setup correctly before running anything,
    # so a misconfigured environment fails here rather than inside the
    # sharded call.
    assert get_tensor_model_parallel_world_size() == world_size

    # Create a test input tensor
    image_input = torch.randn(batch_size, 3, 224, 224)

    # Create a simple linear model
    vision_model = SimpleLinearModel()

    with torch.inference_mode():
        # Run the model directly on the full input
        direct_output = vision_model(image_input)

        # Run the model through the sharded function
        sharded_output = run_dp_sharded_vision_model(image_input,
                                                     vision_model)

    # Check that the outputs have the same shape
    assert direct_output.shape == sharded_output.shape

    # Check that the outputs are close (they should be identical).
    # assert_close reports the mismatching elements on failure.
    torch.testing.assert_close(sharded_output,
                               direct_output,
                               rtol=1e-5,
                               atol=1e-5)

0 commit comments

Comments
 (0)