Skip to content

Commit e9604ff

Browse files
🐛 fix tensor parallel (#301)
# Description: Fixes a bug introduced in #283 where the non-driver workers did not cache the output tokens for the next decode iteration. This change also allows TP tests with TP=2 to run on CPU, so that we can catch these bugs in GHA runs. --------- Signed-off-by: Joe Runde <joe@joerun.de> Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com> Co-authored-by: Prashant Gupta <prashantgupta@us.ibm.com>
1 parent cadf224 commit e9604ff

File tree

6 files changed

+23
-14
lines changed

6 files changed

+23
-14
lines changed

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def remote_openai_server(request):
125125

126126
if 'tp_size' in params:
127127
tp_size = params['tp_size']
128-
skip_unsupported_tp_size(int(tp_size))
128+
skip_unsupported_tp_size(int(tp_size), backend)
129129
server_args.extend(["--tensor-parallel-size", str(tp_size)])
130130

131131
try:

tests/e2e/test_spyre_basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
pytest.param(2, marks=pytest.mark.multi),
2424
pytest.param(4, marks=pytest.mark.multi),
2525
pytest.param(8, marks=pytest.mark.multi),
26-
])
26+
],
27+
ids=lambda val: f"TP({val})")
2728
@pytest.mark.parametrize("backend", get_spyre_backend_list())
2829
def test_output(
2930
model: str,
@@ -45,7 +46,7 @@ def test_output(
4546
After debugging, DISABLE_ASSERTS should be reset to 'False'.
4647
'''
4748

48-
skip_unsupported_tp_size(tp_size)
49+
skip_unsupported_tp_size(tp_size, backend)
4950

5051
prompts = get_chicken_soup_prompts(4)
5152

tests/e2e/test_spyre_online.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
pytest.param(2, marks=pytest.mark.multi),
1010
pytest.param(4, marks=pytest.mark.multi),
1111
pytest.param(8, marks=pytest.mark.multi),
12-
])
12+
],
13+
ids=lambda val: f"TP({val})")
1314
@pytest.mark.parametrize("backend", get_spyre_backend_list())
1415
@pytest.mark.parametrize("warmup_shape", [[
1516
(64, 20, 1),

tests/e2e/test_spyre_prompt_logprobs.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
@pytest.mark.parametrize("backend", get_spyre_backend_list())
2020
@pytest.mark.parametrize("model", get_spyre_model_list())
2121
@pytest.mark.parametrize("tp_size", [
22-
pytest.param(1, id="tp_size"),
23-
pytest.param(2, marks=pytest.mark.multi, id="tp_size"),
24-
pytest.param(4, marks=pytest.mark.multi, id="tp_size")
25-
])
22+
pytest.param(1),
23+
pytest.param(2, marks=pytest.mark.multi),
24+
pytest.param(4, marks=pytest.mark.multi)
25+
],
26+
ids=lambda val: f"TP({val})")
2627
def test_prompt_logprobs(
2728
backend: str,
2829
model: str,
@@ -33,7 +34,7 @@ def test_prompt_logprobs(
3334
This test checks the prompt_logprobs output from vllm against a reference
3435
implementation using huggingface.
3536
'''
36-
skip_unsupported_tp_size(tp_size)
37+
skip_unsupported_tp_size(tp_size, backend)
3738
num_prompt_logprobs = 5
3839

3940
prompts = get_chicken_soup_prompts(4)

tests/spyre_util.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,13 @@ def create_random_request(
548548
**extra_kwargs)
549549

550550

551-
def skip_unsupported_tp_size(size: int):
551+
def skip_unsupported_tp_size(size: int, backend: str):
552+
if backend in ["eager", "inductor"]:
553+
# Spyre cards aren't required for running TP on CPU backends
554+
# But it's really slow to run tp > 2
555+
if size > 2:
556+
pytest.skip("Skipping TP test on CPU with TP size > 2")
557+
return
552558
cards = int(os.getenv("AIU_WORLD_SIZE", "0"))
553559
if cards < size:
554560
pytest.skip(f"Cannot run TP size {size}: "

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -397,10 +397,6 @@ def execute_model(
397397
masks=model_input.input_masks,
398398
is_prompt=model_input.is_prompt)
399399

400-
# Only perform sampling in the driver worker.
401-
if not self.is_driver_worker:
402-
return EMPTY_MODEL_RUNNER_OUTPUT
403-
404400
# Compute the logits.
405401
logits = self.model.compute_logits(hidden_states, None)
406402

@@ -434,6 +430,10 @@ def execute_model(
434430
prompt_logprobs_dicts = self._get_prompt_logprobs_dict(
435431
logits=logits, model_inputs=model_input)
436432

433+
# Only return outputs from the driver worker
434+
if not self.is_driver_worker:
435+
return EMPTY_MODEL_RUNNER_OUTPUT
436+
437437
model_output = ModelRunnerOutput(
438438
req_ids=list(req_id_to_index.keys()),
439439
req_id_to_index=req_id_to_index,

0 commit comments

Comments (0)