
Commit 9112b44

[Hardware][TPU] Initial support of model parallelism with single worker using SPMD (#18011)

Authored by: lsy323, hosseinsarshar, yaochengji

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: Hossein Sarshar <hossein.sarshar@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@google.com>

1 parent c57d577 · commit 9112b44
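
Based on the example updated in this commit (examples/offline_inference/tpu.py), single-worker SPMD tensor parallelism is driven by the VLLM_XLA_USE_SPMD environment variable together with a tensor_parallel_size equal to the number of TPU chips. The following is a minimal sketch, assuming an 8-chip TPU VM and the same model the example switches to; the prompt and sampling settings are illustrative and not part of the commit:

    # Sketch only: mirrors the --use-spmd path of examples/offline_inference/tpu.py.
    # VLLM_XLA_USE_SPMD must be set before the LLM (and the torch_xla runtime) is created.
    import os

    os.environ["VLLM_XLA_USE_SPMD"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # num_kv_heads = 8, as in the example
        tensor_parallel_size=8,  # one worker drives all 8 chips via SPMD
        max_num_batched_tokens=64,
        max_num_seqs=4,
        max_model_len=128,
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    print(outputs[0].outputs[0].text)

Equivalently, the updated example script can simply be run with its new --use-spmd flag.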

File tree: 11 files changed, +605 −72 lines

.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh

Lines changed: 4 additions & 0 deletions
@@ -155,6 +155,10 @@ run_and_track_test 12 "test_moe_pallas.py" \
   "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
 run_and_track_test 13 "test_lora.py" \
   "VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py"
+run_and_track_test 14 "test_tpu_qkv_linear.py" \
+  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
+run_and_track_test 15 "test_spmd_model_weight_loading.py" \
+  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
 
 # After all tests have been attempted, exit with the overall status.
 if [ "$overall_script_exit_code" -ne 0 ]; then

examples/offline_inference/tpu.py

Lines changed: 23 additions & 6 deletions
@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import argparse
+import os
+
 from vllm import LLM, SamplingParams
 
 prompts = [
@@ -18,14 +21,28 @@
 
 
 def main():
+    parser = argparse.ArgumentParser(description="TPU offline inference example")
+    parser.add_argument("--use-spmd", action="store_true", help="Enable SPMD mode")
+    args = parser.parse_args()
+
+    llm_args = {
+        "model": "Qwen/Qwen2-1.5B-Instruct",
+        "max_num_batched_tokens": 64,
+        "max_num_seqs": 4,
+        "max_model_len": 128,
+    }
+    if args.use_spmd:
+        os.environ["VLLM_XLA_USE_SPMD"] = "1"
+        # The number of chips can only be hardcoded for now: calling
+        # xr.global_runtime_device_count() before initializing the SPMD env
+        # in torch_xla would mess up the distributed environment.
+        llm_args["tensor_parallel_size"] = 8
+        # Use Llama, for num_kv_heads = 8.
+        llm_args["model"] = "meta-llama/Llama-3.1-8B-Instruct"
+
     # Set `enforce_eager=True` to avoid ahead-of-time compilation.
     # In real workloads, `enforce_eager` should be `False`.
-    llm = LLM(
-        model="Qwen/Qwen2-1.5B-Instruct",
-        max_num_batched_tokens=64,
-        max_num_seqs=4,
-        max_model_len=128,
-    )
+    llm = LLM(**llm_args)
     outputs = llm.generate(prompts, sampling_params)
     print("-" * 50)
     for output, answer in zip(outputs, answers):

tests/v1/tpu/test_spmd_model_weight_loading.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+# SPDX-License-Identifier: Apache-2.0
+import gc
+import tempfile
+
+import numpy as np
+import pytest
+import torch_xla.distributed.spmd as xs
+import torch_xla.runtime as xr
+
+from vllm.config import set_current_vllm_config
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.model_executor.model_loader.tpu import TPUModelLoader
+
+
+def _setup_environment(model):
+    engine_args = EngineArgs(model=model, )
+    vllm_config = engine_args.create_engine_config()
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            1,
+            0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            backend="gloo")
+        # Under single-worker mode, the full model is initialized first and
+        # then partitioned using GSPMD.
+        ensure_model_parallel_initialized(1, 1)
+    return vllm_config
+
+
+MESH = None
+
+
+def _get_spmd_mesh():
+    global MESH
+    if MESH is None:
+        xr.use_spmd()
+        num_devices = xr.global_runtime_device_count()
+        mesh_shape = (num_devices, 1)
+        device_ids = np.array(range(num_devices))
+        MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
+    return MESH
+
+
+@pytest.mark.parametrize("model", [
+    "Qwen/Qwen2-1.5B-Instruct",
+    "meta-llama/Llama-3.1-8B-Instruct",
+    "meta-llama/Llama-3.1-70B-Instruct",
+])
+def test_tpu_model_loader(model):
+    # Skip the 70B test if there are fewer than 8 chips.
+    # TODO: Query using the torch_xla API; the query API is not working
+    # with SPMD yet. However, this test runs under SPMD mode.
+    if '70B' in model and xr.global_runtime_device_count() < 8:
+        pytest.skip("Skipping 70B model if the TPU VM has less than 8 chips "
+                    "to avoid OOM.")
+
+    vllm_config = _setup_environment(model)
+    loader = TPUModelLoader(load_config=vllm_config.load_config)
+    mesh = _get_spmd_mesh()
+    model = loader.load_model(vllm_config, vllm_config.model_config, mesh)
+    del model
+    gc.collect()

tests/v1/tpu/test_tpu_qkv_linear.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+# SPDX-License-Identifier: Apache-2.0
+import tempfile
+
+import numpy as np
+import pytest
+import torch
+import torch_xla.distributed.spmd as xs
+import torch_xla.runtime as xr
+
+from vllm.config import set_current_vllm_config
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear
+from vllm.engine.arg_utils import EngineArgs
+from vllm.model_executor.layers.linear import QKVParallelLinear
+
+
+@pytest.fixture(autouse=True)
+def setup_environment():
+    # This is a fake config used for init dist env.
+    # QKVParallelLinear needs dist env to be initialized.
+    engine_args = EngineArgs(
+        model="Qwen/Qwen2-1.5B-Instruct",
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+
+    vllm_config = engine_args.create_engine_config()
+
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            1,
+            0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            backend="gloo")
+        ensure_model_parallel_initialized(1, 1)
+        yield
+
+
+MESH = None
+
+
+def _get_spmd_mesh():
+    global MESH
+    if MESH is None:
+        xr.use_spmd()
+        num_devices = xr.global_runtime_device_count()
+        mesh_shape = (num_devices, 1)
+        device_ids = np.array(range(num_devices))
+        MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y'))
+    return MESH
+
+
+@pytest.mark.parametrize("bias", [False, True])
+# `xr.use_spmd()` will set a global state, and this state is not reversible.
+# Therefore, non-SPMD tests should be run before SPMD tests.
+@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()])
+@pytest.mark.parametrize("device", ['cpu', 'xla'])
+@torch.no_grad()
+def test_xla_qkv_linear(bias, mesh, device):
+    torch.manual_seed(123)
+
+    qkv_linear = QKVParallelLinear(
+        hidden_size=4096,
+        head_size=128,
+        total_num_heads=32,
+        total_num_kv_heads=8,
+        bias=bias,
+        params_dtype=torch.bfloat16,
+        return_bias=False,
+    )
+
+    qkv_linear.weight.data = torch.rand_like(qkv_linear.weight.data) / 10
+    if bias:
+        qkv_linear.bias.data = torch.rand_like(qkv_linear.bias.data)
+
+    xla_qkv_linear = XlaQKVParallelLinear(qkv_linear, mesh=mesh)
+
+    qkv_linear = qkv_linear.to(device)
+    xla_qkv_linear = xla_qkv_linear.to(device)
+    input_tensor = torch.rand(10, 4096, dtype=torch.bfloat16) / 10
+    input_tensor = input_tensor.to(device)
+
+    output = qkv_linear(input_tensor)
+    xla_output = xla_qkv_linear(input_tensor)
+    assert torch.allclose(output.cpu(), xla_output.cpu())
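
The equivalence this test asserts rests on the fact that QKVParallelLinear stores a single fused weight whose output rows are the q, k, and v rows concatenated, while XlaQKVParallelLinear slices that weight back into three matrices and concatenates the three projections. A self-contained illustration of that identity in plain torch, using the same head geometry as the test; this snippet is illustrative only and not part of the commit:

    import torch

    hidden_size, head_size = 4096, 128
    total_num_heads, total_num_kv_heads = 32, 8
    q_size = total_num_heads * head_size      # 4096 output rows for q
    kv_size = total_num_kv_heads * head_size  # 1024 output rows each for k and v

    fused_weight = torch.randn(q_size + 2 * kv_size, hidden_size)  # fused QKV weight
    q_w = fused_weight[:q_size]
    k_w = fused_weight[q_size:q_size + kv_size]
    v_w = fused_weight[q_size + kv_size:]

    x = torch.randn(10, hidden_size)
    fused_out = x @ fused_weight.t()
    split_out = torch.cat([x @ q_w.t(), x @ k_w.t(), x @ v_w.t()], dim=-1)
    assert torch.allclose(fused_out, split_out, atol=1e-4)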

vllm/config.py

Lines changed: 2 additions & 0 deletions
@@ -1901,6 +1901,8 @@ def __post_init__(self) -> None:
         if current_platform.is_neuron():
             # neuron uses single process to control multiple devices
             backend = "uni"
+        elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
+            backend = "uni"
         elif (current_platform.is_cuda()
               and cuda_device_count_stateless() < self.world_size):
             if not ray_found:

vllm/distributed/tpu_distributed_utils.py

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+# SPDX-License-Identifier: Apache-2.0
+from collections import OrderedDict
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch_xla.distributed.spmd as xs
+from torch.nn.parameter import Parameter
+
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+
+logger = init_logger(__name__)
+
+
+class XlaQKVParallelLinear(nn.Module):
+
+    def __init__(self,
+                 qkv_linear: nn.Module,
+                 mesh: Optional["xs.Mesh"] = None):
+        super().__init__()
+        assert isinstance(qkv_linear, QKVParallelLinear)
+        self.skip_bias_add = qkv_linear.skip_bias_add
+        self.return_bias = qkv_linear.return_bias
+        assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD."
+
+        self.q_weight: Parameter
+        self.k_weight: Parameter
+        self.v_weight: Parameter
+        self.q_bias: Optional[Parameter]
+        self.k_bias: Optional[Parameter]
+        self.v_bias: Optional[Parameter]
+        self._load_weights_from_qkv_linear(qkv_linear)
+        if mesh is not None:
+            self._shard_weight(mesh)
+
+    def _shard_weight(self, mesh: "xs.Mesh"):
+        self.q_weight = Parameter(self.q_weight.to('xla'), requires_grad=False)
+        self.k_weight = Parameter(self.k_weight.to('xla'), requires_grad=False)
+        self.v_weight = Parameter(self.v_weight.to('xla'), requires_grad=False)
+        xs.mark_sharding(self.q_weight, mesh, ('x', None))
+        xs.mark_sharding(self.k_weight, mesh, ('x', None))
+        xs.mark_sharding(self.v_weight, mesh, ('x', None))
+        if self.q_bias is not None:
+            assert self.k_bias is not None and self.v_bias is not None, \
+                "QKVParallelLinear should have q, k, and v biases together."
+            self.q_bias = Parameter(self.q_bias.to('xla'), requires_grad=False)
+            xs.mark_sharding(self.q_bias, mesh, ('x', ))
+            self.k_bias = Parameter(self.k_bias.to('xla'), requires_grad=False)
+            xs.mark_sharding(self.k_bias, mesh, ('x', ))
+            self.v_bias = Parameter(self.v_bias.to('xla'), requires_grad=False)
+            xs.mark_sharding(self.v_bias, mesh, ('x', ))
+
+    def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module):
+        q_proj_size, k_proj_size, _ = qkv_linear.output_sizes
+        # The weight of the qkv linear is a concatenation of the q, k, and v
+        # weights along the output dimension.
+        qkv_weight = qkv_linear.weight.data.cpu()
+        q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False)
+        k_weight = Parameter(qkv_weight[q_proj_size:q_proj_size + k_proj_size],
+                             requires_grad=False)
+        v_weight = Parameter(qkv_weight[q_proj_size + k_proj_size:],
+                             requires_grad=False)
+        self.register_parameter("q_weight", q_weight)
+        self.register_parameter("k_weight", k_weight)
+        self.register_parameter("v_weight", v_weight)
+
+        if qkv_linear.bias is not None:
+            q_bias = Parameter(qkv_linear.bias[:q_proj_size],
+                               requires_grad=False)
+            k_bias = Parameter(qkv_linear.bias[q_proj_size:q_proj_size +
+                                               k_proj_size],
+                               requires_grad=False)
+            v_bias = Parameter(qkv_linear.bias[q_proj_size + k_proj_size:],
+                               requires_grad=False)
+            self.register_parameter("q_bias", q_bias)
+            self.register_parameter("k_bias", k_bias)
+            self.register_parameter("v_bias", v_bias)
+        else:
+            self.register_parameter("q_bias", None)
+            self.register_parameter("k_bias", None)
+            self.register_parameter("v_bias", None)
+
+    def forward(self, input):
+        # Same forward functionality as QKVParallelLinear, but the q, k, and v
+        # projections are done separately.
+        q_bias = self.q_bias if not self.skip_bias_add else None
+        k_bias = self.k_bias if not self.skip_bias_add else None
+        v_bias = self.v_bias if not self.skip_bias_add else None
+        q_proj = F.linear(input, self.q_weight, q_bias)
+        k_proj = F.linear(input, self.k_weight, k_bias)
+        v_proj = F.linear(input, self.v_weight, v_bias)
+        # The q/k/v projections will be split outside of the QKVParallelLinear.
+        # Because XlaQKVParallelLinear replaces QKVParallelLinear, we need to
+        # concatenate the q, k, and v projections to match the output shape of
+        # the QKVParallelLinear implementation, even if it seems redundant.
+        # The concat and the following split will be no-ops and should be
+        # optimized away by the compiler.
+        qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1)
+        output_bias = torch.cat([q_bias, k_bias, v_bias], dim=-1) if \
+            self.skip_bias_add else None
+        if not self.return_bias:
+            return qkv_proj
+        return qkv_proj, output_bias
+
+
+def partition_column_parallel_linear(layer: torch.nn.Module,
+                                     mesh: xs.Mesh) -> torch.nn.Module:
+    assert isinstance(layer, ColumnParallelLinear)
+    xs.mark_sharding(layer.weight, mesh, ('x', None))
+    logger.debug("Applied column-parallel sharding to %s", layer)
+    return layer
+
+
+def partition_row_parallel_linear(layer: torch.nn.Module,
+                                  mesh: xs.Mesh) -> torch.nn.Module:
+    assert isinstance(layer, RowParallelLinear)
+    xs.mark_sharding(layer.weight, mesh, (None, 'x'))
+    logger.debug("Applied row-parallel sharding to %s", layer)
+    return layer
+
+
+def partition_qkv_parallel_linear(layer: torch.nn.Module,
+                                  mesh: xs.Mesh) -> torch.nn.Module:
+    assert isinstance(layer, QKVParallelLinear)
+    xla_layer = XlaQKVParallelLinear(layer, mesh)
+    logger.debug("Applied qkv parallel sharding to %s", layer)
+    return xla_layer
+
+
+MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict([
+    ("QKVParallelLinear", partition_qkv_parallel_linear),
+    ("ColumnParallelLinear", partition_column_parallel_linear),
+    ("RowParallelLinear", partition_row_parallel_linear),
+])
+
+
+def get_fqn(module):
+    # Get the fully qualified name of the module.
+    return module.__class__.__qualname__
+
+
+def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None:
+    """
+    Recursively walk a PyTorch model and apply the appropriate sharding based
+    on the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
+
+    Args:
+        model: torch.nn.Module to process
+        mesh: An XLA SPMD mesh object used for sharding
+    """
+
+    def _process_module(module, name=None, parent=None):
+        for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items():
+            if get_fqn(module) == module_type:
+                wrapped_module = wrapping_func(module, mesh)
+
+                assert parent is not None and name is not None, (
+                    "Top-level module is not expected to be wrapped.")
+                if wrapped_module is not module:
+                    # The wrapped module and the original module are different
+                    # Python objects; the original module should be replaced
+                    # by the wrapped_module in its parent.
+                    logger.debug("replace %s with %s", module, wrapped_module)
+                    setattr(parent, name, wrapped_module)
+
+                module = wrapped_module
+                break
+
+        for child_name, child_module in list(module.named_children()):
+            _process_module(child_module, child_name, module)
+
+    _process_module(model)
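
Putting the pieces together, shard_model is the entry point the single-worker path calls after the full model has been materialized: build a one-dimensional ('x') mesh over all chips, then let the wrapping functions above annotate or replace the parallel linear layers. A minimal sketch of that flow, assuming the model has already been loaded onto the XLA device; the helper name below is hypothetical and not part of the commit:

    # Illustrative sketch of the single-worker SPMD flow added by this commit.
    import numpy as np
    import torch
    import torch_xla.distributed.spmd as xs
    import torch_xla.runtime as xr

    from vllm.distributed.tpu_distributed_utils import shard_model


    def shard_loaded_model(model: torch.nn.Module) -> torch.nn.Module:
        # Hypothetical helper: `model` is assumed to be a fully loaded vLLM model
        # (built from ColumnParallelLinear / RowParallelLinear / QKVParallelLinear)
        # that has already been moved to the 'xla' device.
        xr.use_spmd()
        num_devices = xr.global_runtime_device_count()
        mesh = xs.Mesh(np.array(range(num_devices)), (num_devices, 1), ('x', 'y'))
        shard_model(model, mesh)  # annotates/replaces supported layers in place
        return model

This mirrors what the new TPUModelLoader test exercises: load the full model once in a single process, then mark GSPMD shardings so XLA partitions it across the chips.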
