# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
import typing

import pytest
import ray
import torch
import torch.distributed as dist

from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.cuda_communicator import (
    CudaCommunicator)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
                                             get_tp_group)
from vllm.platforms import current_platform

from ..utils import (ensure_model_parallel_initialized,
                     init_test_distributed_environment, multi_process_parallel)

torch.manual_seed(42)
random.seed(44)

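# Each test tensor holds 4M elements (8 MiB per rank in bfloat16).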
test_size_elements = 4 * 1024 * 1024


@ray.remote(num_gpus=1, max_calls=1)
def symm_mem_allreduce(
    monkeypatch: pytest.MonkeyPatch,
    tp_size,
    pp_size,
    rank,
    distributed_init_port,
):
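    """Per-rank Ray worker: checks SymmMemCommunicator all-reduce (called
    directly and via tensor_model_parallel_all_reduce) against a plain
    torch.distributed all-reduce reference."""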
    with monkeypatch.context() as m:
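        # Clear any CUDA_VISIBLE_DEVICES pinning so the device can be
        # selected by global rank below.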
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)

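        # Bring up torch.distributed and vLLM's tensor/pipeline-parallel
        # groups for this rank.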
        init_test_distributed_environment(tp_size, pp_size, rank,
                                          distributed_init_port)
        ensure_model_parallel_initialized(tp_size, pp_size)

        dtype = torch.bfloat16

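        # The symmetric-memory communicator hangs off the TP group's CUDA
        # communicator; skip if it was not built or is disabled.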
        cuda_communicator = typing.cast(CudaCommunicator,
                                        get_tp_group().device_communicator)
        symm_mem_comm = cuda_communicator.symm_mem_comm
        if symm_mem_comm is None or symm_mem_comm.disabled:
            pytest.skip("SymmMemCommunicator is not available or disabled.")

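        # Exercise the SymmMemCommunicator's all_reduce directly on a random
        # bf16 tensor.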
        inp_direct_symm_mem = torch.randint(1,
                                            23, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
            pytest.skip(
                "SymmMemCommunicator isn't used for this world size and input size."
            )

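        # Keep a pristine copy of the input to feed the reference all-reduce.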
        original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
        out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
        assert out_direct_symm_mem is not None

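        # Reference result: a plain torch.distributed all-reduce over the
        # same TP process group.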
        group = get_tensor_model_parallel_group().device_group
        dist.all_reduce(original_inp_direct_symm_mem, group=group)
        torch.testing.assert_close(out_direct_symm_mem,
                                   original_inp_direct_symm_mem,
                                   atol=2.5,
                                   rtol=0.1)

        # Test tensor_model_parallel_all_reduce which should use symm_mem
        inp_tensor_parallel = torch.randint(-23,
                                            1, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        original_inp_tensor_parallel = inp_tensor_parallel.clone()
        out_tensor_parallel = tensor_model_parallel_all_reduce(
            inp_tensor_parallel)
        dist.all_reduce(original_inp_tensor_parallel, group=group)
        torch.testing.assert_close(out_tensor_parallel,
                                   original_inp_tensor_parallel,
                                   atol=2.5,
                                   rtol=0.1)


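# Driver test: each parametrized case runs the symm_mem_allreduce worker on
# every rank (one GPU per worker).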
@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.parametrize("tp_size", [2, 4])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
@pytest.mark.parametrize("test_target", [symm_mem_allreduce])
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
                            pipeline_parallel_size, test_target):
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    # Enable SymmMemCommunicator
    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")

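    # Launch tp_size * pipeline_parallel_size worker processes and run the
    # check in each of them.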
    multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size,
                           test_target)