Torch compile first call of prepare_cos_sin method #1395

Open · wants to merge 13 commits into base: habana_main
Changes from 6 commits
118 changes: 45 additions & 73 deletions vllm/v1/worker/hpu_model_runner.py
@@ -218,42 +218,9 @@ def forward_hook(module, args, output):
modify_model_layers(child_module, suffix_list, n, counter)


def get_path_to_rope(model: torch.nn.Module):
"""Dynamically get the path to the RotaryEmbedding layer in the model.
This function will recursively search through the module hierarchy to find
a RotaryEmbedding layer and return the full path to that layer as a list
of names.
If no such layer is found, it returns None.
"""

def find_rope_layer(parent, path):
# Base case: check if this parent is None
if parent is None:
return None

# Check if the current layer is a RotaryEmbedding
if hasattr(parent, 'named_children'):
for child_name, child_module in parent.named_children():
# If the current child is of type RotaryEmbedding,
# return the full path
if child_module.__class__.__name__.endswith("RotaryEmbedding"):
return path + [child_name]
# Otherwise, recurse into this child to check its children
result = find_rope_layer(child_module, path + [child_name])
if result is not None:
return result
return None

# Start the search from the top level model
path_to_rope = find_rope_layer(model, [])

# Return the result if found, otherwise None
return path_to_rope


class HpuModelAdapter(torch.nn.Module):

def __init__(self, model, vllm_config, layer_names):
def __init__(self, model, vllm_config):
super().__init__()
self.model = model
self.prefill_use_fusedsdpa = get_config(
@@ -263,7 +230,38 @@ def __init__(self, model, vllm_config, layer_names):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.dtype = vllm_config.model_config.dtype
self.layer_names = layer_names
self._rotary_embed_module = self._get_rotary_embedding_module(
self.model)
self._rotary_prepare_cos_sin = self._get_prepare_cos_sin()

def _get_rotary_embedding_module(self, model: torch.nn.Module):
"""
Dynamically get the RotaryEmbedding layer in the model.
This function will recursively search through the module
hierarchy to find and return a RotaryEmbedding layer.
If no such layer is found, it returns None.
"""
if model is None:
return None

if model.__class__.__name__.endswith("RotaryEmbedding"):
return model

if hasattr(model, 'children'):
for child in model.children():
result = self._get_rotary_embedding_module(child)
if result is not None:
return result
return None

def _get_prepare_cos_sin(self):
if self._rotary_embed_module is not None:
return self._rotary_embed_module.prepare_cos_sin
return None

def _reset_rotary_cos_sin(self):
delattr(self._rotary_embed_module, "cos")
delattr(self._rotary_embed_module, "sin")

def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
dtype):
@@ -349,34 +347,6 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
device, dtype)
return attn_metadata

def _prepare_cos_sin(self, positions):
"""Navigate through the model using the provided path and call
the prepare_cos_sin method on the 'RotaryEmbedding' layer."""

current_module = self.model # Start from the top level of the model

for layer in self.layer_names:
if layer.isdigit(): # Check if the layer is an index
layer = int(layer)

# Check if the current layer is a name in a module
if isinstance(
layer,
str) and not isinstance(layer, int): # Name-based access
current_module = getattr(current_module, layer)
elif isinstance(layer,
int): # Indexed-based access (like Modulelist)
current_module = list(current_module._modules.values())[layer]

# At the end, we should be at the RotaryEmbedding layer.
if hasattr(current_module, 'prepare_cos_sin'):
current_module.prepare_cos_sin(
positions, recompute_cos_sin=self.recompute_cos_sin)
else:
raise AttributeError(
"The module at the end of the path does not have \
a 'prepare_cos_sin' method.")

def forward(self, *args, **kwargs):
# TODO(kzawora): something goes VERY WRONG when operating on
# kwargs['attn_metadata'].slot_mapping, compared to untrimmed metadata
@@ -388,13 +358,16 @@ def forward(self, *args, **kwargs):
kwargs['attn_metadata'] = self._update_metadata(
kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
input_ids.device, self.dtype)
if self.layer_names is not None:
self._prepare_cos_sin(kwargs['positions'])
if self._rotary_prepare_cos_sin is not None:
self._rotary_prepare_cos_sin(
kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin)
attn_meta = kwargs.pop('attn_metadata')
if 'kv_caches' in kwargs:
kwargs.pop('kv_caches')
with set_forward_context(attn_meta, self.vllm_config):
hidden_states = self.model(*args, **kwargs)
if self._rotary_prepare_cos_sin is not None:
self._reset_rotary_cos_sin()
return hidden_states

def compute_logits(self, *args, **kwargs):
@@ -1786,21 +1759,17 @@ def load_model(self) -> None:
get_target_layer_suffix_list(
model_config.model_type if model_config is not None else None),
hidden_layer_markstep_interval)
path_to_rope = get_path_to_rope(self.model)
torch.hpu.synchronize()

with HabanaMemoryProfiler() as m: # noqa: SIM117
self.model = _maybe_wrap_in_hpu_graph(self.model,
vllm_config=self.vllm_config,
layer_names=path_to_rope)
vllm_config=self.vllm_config)
self.model_memory_usage = m.consumed_device_memory
logger.info("Wrapping in HPUGraph took %.4f GB",
self.model_memory_usage / float(2**30))

with HabanaMemoryProfiler() as m:
self._maybe_compile(self.model,
vllm_config=self.vllm_config,
layer_names=path_to_rope)
self._maybe_compile(self.model)
self.model_memory_usage = m.consumed_device_memory
logger.info("Compilation took %.4f GB",
self.model_memory_usage / float(2**30))
@@ -1810,10 +1779,13 @@ def _maybe_compile(self, *args, **kwargs):
) and not self.vllm_config.model_config.enforce_eager:
if os.getenv('VLLM_REGIONAL_COMPILATION',
'true').strip().lower() in ("1", "true"):
compiled_methods = ['_update_metadata']
compiled_methods = [
'_update_metadata', '_rotary_prepare_cos_sin'
]
for method_name in compiled_methods:
method = getattr(self.model, method_name)
self._compile_region(self.model, method_name, method)
if method is not None:
self._compile_region(self.model, method_name, method)
self.regional_compilation_layers_list = [
RMSNorm, VocabParallelEmbedding
]
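A minimal, self-contained sketch of the pattern this diff introduces may help reviewers: the adapter now walks the module tree once at construction time, caches the `RotaryEmbedding` submodule and its bound `prepare_cos_sin` method, and drops the cached `cos`/`sin` tables after each forward. The `DummyRotaryEmbedding` class and the toy model layout below are illustrative stand-ins, not vLLM's real classes.

```python
# Sketch of the lookup-and-cache pattern from the diff above.
# DummyRotaryEmbedding and the nested Sequential are toy stand-ins.
import torch


class DummyRotaryEmbedding(torch.nn.Module):
    """Toy stand-in: prepare_cos_sin caches cos/sin tables as attributes."""

    def prepare_cos_sin(self, positions, recompute_cos_sin=False):
        self.cos = torch.cos(positions.float())
        self.sin = torch.sin(positions.float())


def get_rotary_embedding_module(model):
    """Depth-first search for the first *RotaryEmbedding submodule, else None."""
    if model is None:
        return None
    if model.__class__.__name__.endswith("RotaryEmbedding"):
        return model
    for child in model.children():
        found = get_rotary_embedding_module(child)
        if found is not None:
            return found
    return None


# Nest the layer a couple of levels deep so the search has to recurse.
model = torch.nn.Sequential(
    torch.nn.Sequential(DummyRotaryEmbedding(), torch.nn.Identity()))

rope = get_rotary_embedding_module(model)
prepare_cos_sin = rope.prepare_cos_sin if rope is not None else None

if prepare_cos_sin is not None:
    prepare_cos_sin(torch.arange(8))  # before the forward pass
    assert hasattr(rope, "cos") and hasattr(rope, "sin")
    # After the forward pass the cached tables are dropped, mirroring
    # _reset_rotary_cos_sin, so stale cos/sin never leak across calls.
    delattr(rope, "cos")
    delattr(rope, "sin")
```

Caching the bound method once also makes it addressable by name on the adapter (`_rotary_prepare_cos_sin`), which is what lets `_maybe_compile` add it to the regional-compilation whitelist further down.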
4 changes: 2 additions & 2 deletions vllm/worker/hpu_enc_dec_model_runner.py
@@ -39,8 +39,8 @@

class HpuModelAdapterEncoderDecoder(HpuModelAdapter):

def __init__(self, model, vllm_config, layer_names, is_causal, sampler):
super().__init__(model, vllm_config, layer_names, is_causal, sampler)
def __init__(self, model, vllm_config, is_causal, sampler):
super().__init__(model, vllm_config, is_causal, sampler)

def _set_cross_block_mapping(self, metadata, batch_size, device, dtype):
mask = torch.arange(0,
120 changes: 44 additions & 76 deletions vllm/worker/hpu_model_runner.py
@@ -248,42 +248,9 @@ def forward_hook(module, args, output):
modify_model_layers(child_module, suffix_list, n, counter)


def get_path_to_rope(model: torch.nn.Module):
"""Dynamically get the path to the RotaryEmbedding layer in the model.
This function will recursively search through the module hierarchy to find
a RotaryEmbedding layer and return the full path to that layer as a list
of names.
If no such layer is found, it returns None.
"""

def find_rope_layer(parent, path):
# Base case: check if this parent is None
if parent is None:
return None

# Check if the current layer is a RotaryEmbedding
if hasattr(parent, 'named_children'):
for child_name, child_module in parent.named_children():
# If the current child is of type RotaryEmbedding,
# return the full path
if child_module.__class__.__name__.endswith("RotaryEmbedding"):
return path + [child_name]
# Otherwise, recurse into this child to check its children
result = find_rope_layer(child_module, path + [child_name])
if result is not None:
return result
return None

# Start the search from the top level model
path_to_rope = find_rope_layer(model, [])

# Return the result if found, otherwise None
return path_to_rope


class HpuModelAdapter(torch.nn.Module):

def __init__(self, model, vllm_config, layer_names, is_causal, sampler):
def __init__(self, model, vllm_config, is_causal, sampler):
super().__init__()
self.model = model
self.prefill_use_fusedsdpa = get_config(
@@ -294,10 +261,41 @@ def __init__(self, model, vllm_config, layer_names, is_causal, sampler):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.dtype = vllm_config.model_config.dtype
self.layer_names = layer_names
self.is_pooler = hasattr(self.model, "_pooler")
self.is_causal = is_causal
self.use_merged_prefill = VLLM_MERGED_PREFILL
self._rotary_embed_module = self._get_rotary_embedding_module(
self.model)
self._rotary_prepare_cos_sin = self._get_prepare_cos_sin()

def _get_rotary_embedding_module(self, model: torch.nn.Module):
"""
Dynamically get the RotaryEmbedding layer in the model.
This function will recursively search through the module
hierarchy to find and return a RotaryEmbedding layer.
If no such layer is found, it returns None.
"""
if model is None:
return None

if model.__class__.__name__.endswith("RotaryEmbedding"):
return model

if hasattr(model, 'children'):
for child in model.children():
result = self._get_rotary_embedding_module(child)
if result is not None:
return result
return None

def _get_prepare_cos_sin(self):
if self._rotary_embed_module is not None:
return self._rotary_embed_module.prepare_cos_sin
return None

def _reset_rotary_cos_sin(self):
delattr(self._rotary_embed_module, "cos")
delattr(self._rotary_embed_module, "sin")

def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
dtype):
@@ -415,38 +413,6 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
device, dtype)
return attn_metadata

def _prepare_cos_sin(self, positions):
"""Navigate through the model using the provided path and call
the prepare_cos_sin method on the 'RotaryEmbedding' layer."""

current_module = self.model # Start from the top level of the model

for layer in self.layer_names:
if layer.isdigit(): # Check if the layer is an index
layer = int(layer)

# Check if the current layer is a name in a module
if isinstance(
layer,
str) and not isinstance(layer, int): # Name-based access
current_module = getattr(current_module, layer)
elif isinstance(layer,
int): # Indexed-based access (like ModuleList)
module_list = list(current_module._modules.values())
if layer >= len(module_list):
# for MTP models, last layer is MTP layer
layer = -1
current_module = module_list[layer]

# At the end, we should be at the RotaryEmbedding layer.
if hasattr(current_module, 'prepare_cos_sin'):
current_module.prepare_cos_sin(
positions, recompute_cos_sin=self.recompute_cos_sin)
else:
raise AttributeError(
"The module at the end of the path does not have \
a 'prepare_cos_sin' method.")

def forward(self, *args, **kwargs):
kwargs = kwargs.copy()
selected_token_indices = kwargs.pop('selected_token_indices')
@@ -464,13 +430,16 @@ def forward(self, *args, **kwargs):
LoraMask.setLoraMask(kwargs.pop('lora_mask'))
model_config = getattr(self.model, "config", None)
model_is_mrope = uses_mrope(model_config)
if self.layer_names is not None and not model_is_mrope:
self._prepare_cos_sin(kwargs['positions'])
if self._rotary_prepare_cos_sin is not None and not model_is_mrope:
self._rotary_prepare_cos_sin(
kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin)
attn_meta = kwargs.pop('attn_metadata')
if 'kv_caches' in kwargs:
kwargs.pop('kv_caches')
with set_forward_context(attn_meta, self.vllm_config, virtual_engine):
hidden_states = self.model(*args, **kwargs)
if self._rotary_prepare_cos_sin is not None and not model_is_mrope:
self._reset_rotary_cos_sin()
if not get_pp_group().is_last_rank:
return hidden_states
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -970,7 +939,6 @@ def load_model(self) -> None:
model_config.
model_type if model_config is not None else None),
hidden_layer_markstep_interval)
path_to_rope = get_path_to_rope(self.model)
torch.hpu.synchronize()

if self.is_pooler:
@@ -979,15 +947,12 @@ def load_model(self) -> None:
self.model = self._maybe_wrap_in_hpu_graph(
self.model,
vllm_config=self.vllm_config,
layer_names=path_to_rope,
is_causal=self.is_causal,
sampler=self.sampler)
msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
logger.info(msg)
with HabanaMemoryProfiler() as m_wrap:
self._maybe_compile(self.model,
vllm_config=self.vllm_config,
layer_names=path_to_rope)
self._maybe_compile(self.model)
msg = f"Compiling took {m_wrap.get_summary_string()}"
logger.info(msg)

@@ -1029,10 +994,13 @@ def _maybe_compile(self, *args, **kwargs):
) and not self.vllm_config.model_config.enforce_eager:
if os.getenv('VLLM_REGIONAL_COMPILATION',
'true').strip().lower() in ("1", "true"):
compiled_methods = ['_update_metadata']
compiled_methods = [
'_update_metadata', '_rotary_prepare_cos_sin'
]
for method_name in compiled_methods:
method = getattr(self.model, method_name)
self._compile_region(self.model, method_name, method)
if method is not None:
self._compile_region(self.model, method_name, method)

self.regional_compilation_layers_list = [
RMSNorm, VocabParallelEmbedding
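To illustrate the `_maybe_compile` change, here is a hedged sketch of the new None guard: each whitelisted method is wrapped with `torch.compile` only if it resolves to a real callable, since `_rotary_prepare_cos_sin` is `None` for models without a `RotaryEmbedding` layer. `Adapter` and `compile_region` below are illustrative stand-ins (the real runner uses its own `_compile_region` helper and an HPU backend); `backend="eager"` just keeps the sketch portable.

```python
# Sketch of the None-guarded regional compilation used in _maybe_compile.
# Adapter and compile_region are illustrative stand-ins, not vLLM's exact API.
import torch


class Adapter(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # None when the wrapped model has no RotaryEmbedding layer.
        self._rotary_prepare_cos_sin = None

    def _update_metadata(self, x):
        return x * 2


def compile_region(module, name, method):
    # Swap the attribute for a torch.compile-wrapped version of the method.
    setattr(module, name, torch.compile(method, backend="eager", dynamic=False))


adapter = Adapter()
for method_name in ['_update_metadata', '_rotary_prepare_cos_sin']:
    method = getattr(adapter, method_name)
    if method is not None:  # the guard this PR adds
        compile_region(adapter, method_name, method)

print(adapter._update_metadata(torch.ones(2)))  # tensor([2., 2.])
```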