Skip to content

Commit 4807582

Browse files
vadiklyutiywwl2755-google
authored andcommitted
[PERF] Speedup of MRoPE prepare inputs (vllm-project#19939)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
1 parent 0af7378 commit 4807582

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import math
2727
from typing import Any, Optional, Union
2828

29+
import numpy as np
2930
import torch
3031
import torch.nn as nn
3132
from transformers import PretrainedConfig
@@ -1458,15 +1459,14 @@ def get_next_input_positions(
14581459
]
14591460

14601461
@staticmethod
1461-
def get_next_input_positions_tensor(
1462-
mrope_position_delta: int,
1463-
context_len: int,
1464-
seq_len: int,
1465-
) -> torch.Tensor:
1466-
return torch.arange(
1467-
mrope_position_delta + context_len,
1468-
mrope_position_delta + seq_len,
1469-
).expand(3, -1)
1462+
def get_next_input_positions_tensor(out: np.ndarray, out_offset: int,
1463+
mrope_position_delta: int,
1464+
context_len: int, num_new_tokens: int):
1465+
1466+
values = np.arange(mrope_position_delta + context_len,
1467+
mrope_position_delta + context_len + num_new_tokens,
1468+
dtype=out.dtype)
1469+
out[:, out_offset:out_offset + num_new_tokens] = values
14701470

14711471
@classmethod
14721472
def omni_get_updates_use_audio_in_video(

vllm/v1/worker/gpu_model_runner.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def __init__(
262262
dtype=torch.int64,
263263
device="cpu",
264264
pin_memory=self.pin_memory)
265+
self.mrope_positions_np = self.mrope_positions_cpu.numpy()
265266

266267
# Only relevant for models using ALiBi (e.g, MPT)
267268
self.use_alibi = check_use_alibi(model_config)
@@ -889,15 +890,13 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
889890
dst_start = mrope_pos_ptr
890891
dst_end = mrope_pos_ptr + completion_part_len
891892

892-
self.mrope_positions_cpu[:, dst_start:dst_end] = \
893-
MRotaryEmbedding.get_next_input_positions_tensor(
894-
req.mrope_position_delta,
895-
context_len=num_computed_tokens +
896-
prompt_part_len,
897-
seq_len=num_computed_tokens +
898-
prompt_part_len +
899-
completion_part_len,
900-
)
893+
MRotaryEmbedding.get_next_input_positions_tensor(
894+
out=self.mrope_positions_np,
895+
out_offset=dst_start,
896+
mrope_position_delta=req.mrope_position_delta,
897+
context_len=num_computed_tokens + prompt_part_len,
898+
num_new_tokens=completion_part_len,
899+
)
901900

902901
mrope_pos_ptr += completion_part_len
903902

0 commit comments

Comments
 (0)