diff --git a/vllm/v1/worker/hpu_model_runner.py b/vllm/v1/worker/hpu_model_runner.py
index 6894b167142..747c090a438 100644
--- a/vllm/v1/worker/hpu_model_runner.py
+++ b/vllm/v1/worker/hpu_model_runner.py
@@ -252,42 +252,9 @@ def forward_hook(module, args, output):
         modify_model_layers(child_module, suffix_list, n, counter)
 
 
-def get_path_to_rope(model: torch.nn.Module):
-    """Dynamically get the path to the RotaryEmbedding layer in the model.
-    This function will recursively search through the module hierarchy to find
-    a RotaryEmbedding layer and return the full path to that layer as a list
-    of names.
-    If no such layer is found, it returns None.
-    """
-
-    def find_rope_layer(parent, path):
-        # Base case: check if this parent is None
-        if parent is None:
-            return None
-
-        # Check if the current layer is a RotaryEmbedding
-        if hasattr(parent, 'named_children'):
-            for child_name, child_module in parent.named_children():
-                # If the current child is of type RotaryEmbedding,
-                # return the full path
-                if child_module.__class__.__name__.endswith("RotaryEmbedding"):
-                    return path + [child_name]
-                # Otherwise, recurse into this child to check its children
-                result = find_rope_layer(child_module, path + [child_name])
-                if result is not None:
-                    return result
-        return None
-
-    # Start the search from the top level model
-    path_to_rope = find_rope_layer(model, [])
-
-    # Return the result if found, otherwise None
-    return path_to_rope
-
-
 class HpuModelAdapter(torch.nn.Module):
 
-    def __init__(self, model, vllm_config, layer_names):
+    def __init__(self, model, vllm_config):
         super().__init__()
         self.model = model
         self.prefill_use_fusedsdpa = get_config(
@@ -297,7 +264,38 @@ def __init__(self, model, vllm_config, layer_names):
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
         self.dtype = vllm_config.model_config.dtype
-        self.layer_names = layer_names
+        self._rotary_embed_module = self._get_rotary_embedding_module(
+            self.model)
+        self._rotary_prepare_cos_sin = self._get_prepare_cos_sin()
+
+    def _get_rotary_embedding_module(self, model: torch.nn.Module):
+        """
+        Dynamically get the RotaryEmbedding layer in the model.
+        This function will recursively search through the module
+        hierarchy to find and return a RotaryEmbedding layer.
+        If no such layer is found, it returns None.
+        """
+        if model is None:
+            return None
+
+        if model.__class__.__name__.endswith("RotaryEmbedding"):
+            return model
+
+        if hasattr(model, 'children'):
+            for child in model.children():
+                result = self._get_rotary_embedding_module(child)
+                if result is not None:
+                    return result
+        return None
+
+    def _get_prepare_cos_sin(self):
+        if self._rotary_embed_module is not None:
+            return self._rotary_embed_module.prepare_cos_sin
+        return None
+
+    def _reset_rotary_cos_sin(self):
+        delattr(self._rotary_embed_module, "cos")
+        delattr(self._rotary_embed_module, "sin")
 
     def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
                        dtype):
@@ -386,34 +384,6 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
                                                     device, dtype)
         return attn_metadata
 
-    def _prepare_cos_sin(self, positions):
-        """Navigate through the model using the provided path and call
-        the prepare_cos_sin method on the 'RotaryEmbedding' layer."""
-
-        current_module = self.model  # Start from the top level of the model
-
-        for layer in self.layer_names:
-            if layer.isdigit():  # Check if the layer is an index
-                layer = int(layer)
-
-            # Check if the current layer is a name in a module
-            if isinstance(
-                    layer,
-                    str) and not isinstance(layer, int):  # Name-based access
-                current_module = getattr(current_module, layer)
-            elif isinstance(layer,
-                            int):  # Indexed-based access (like Modulelist)
-                current_module = list(current_module._modules.values())[layer]
-
-        # At the end, we should be at the RotaryEmbedding layer.
-        if hasattr(current_module, 'prepare_cos_sin'):
-            current_module.prepare_cos_sin(
-                positions, recompute_cos_sin=self.recompute_cos_sin)
-        else:
-            raise AttributeError(
-                "The module at the end of the path does not have \
-                a 'prepare_cos_sin' method.")
-
     def forward(self, *args, **kwargs):
         # TODO(kzawora): something goes VERY WRONG when operating on
         # kwargs['attn_metadata'].slot_mapping, compared to untrimmed metadata
@@ -425,13 +395,16 @@ def forward(self, *args, **kwargs):
         kwargs['attn_metadata'] = self._update_metadata(
             kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
             input_ids.device, self.dtype)
-        if self.layer_names is not None:
-            self._prepare_cos_sin(kwargs['positions'])
+        if self._rotary_prepare_cos_sin is not None:
+            self._rotary_prepare_cos_sin(
+                kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin)
         attn_meta = kwargs.pop('attn_metadata')
         if 'kv_caches' in kwargs:
             kwargs.pop('kv_caches')
         with set_forward_context(attn_meta, self.vllm_config):
             hidden_states = self.model(*args, **kwargs)
+        if self._rotary_prepare_cos_sin is not None:
+            self._reset_rotary_cos_sin()
         return hidden_states
 
     def compute_logits(self, *args, **kwargs):
@@ -1707,21 +1680,17 @@ def load_model(self) -> None:
             get_target_layer_suffix_list(
                 model_config.model_type if model_config is not None else None),
             hidden_layer_markstep_interval)
-        path_to_rope = get_path_to_rope(self.model)
         torch.hpu.synchronize()
 
         with HabanaMemoryProfiler() as m:  # noqa: SIM117
             self.model = _maybe_wrap_in_hpu_graph(self.model,
-                                                  vllm_config=self.vllm_config,
-                                                  layer_names=path_to_rope)
+                                                  vllm_config=self.vllm_config)
         self.model_memory_usage = m.consumed_device_memory
         logger.info("Wrapping in HPUGraph took %.4f GB",
                     self.model_memory_usage / float(2**30))
 
         with HabanaMemoryProfiler() as m:
-            self._maybe_compile(self.model,
-                                vllm_config=self.vllm_config,
-                                layer_names=path_to_rope)
+            self._maybe_compile(self.model)
         self.model_memory_usage = m.consumed_device_memory
         logger.info("Compilation took %.4f GB",
                     self.model_memory_usage / float(2**30))
@@ -1731,10 +1700,13 @@ def _maybe_compile(self, *args, **kwargs):
         ) and not self.vllm_config.model_config.enforce_eager:
             if os.getenv('VLLM_REGIONAL_COMPILATION',
                          'true').strip().lower() in ("1", "true"):
-                compiled_methods = ['_update_metadata']
+                compiled_methods = [
+                    '_update_metadata', '_rotary_prepare_cos_sin'
+                ]
                 for method_name in compiled_methods:
                     method = getattr(self.model, method_name)
-                    self._compile_region(self.model, method_name, method)
+                    if method is not None:
+                        self._compile_region(self.model, method_name, method)
                 self.regional_compilation_layers_list = [
                     RMSNorm, VocabParallelEmbedding
                 ]
diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py
index e68e76ae76e..d6b4f2e3de5 100644
--- a/vllm/worker/hpu_enc_dec_model_runner.py
+++ b/vllm/worker/hpu_enc_dec_model_runner.py
@@ -40,8 +40,8 @@ class HpuModelAdapterEncoderDecoder(HpuModelAdapter):
 
-    def __init__(self, model, vllm_config, layer_names, is_causal, sampler):
-        super().__init__(model, vllm_config, layer_names, is_causal, sampler)
+    def __init__(self, model, vllm_config, is_causal, sampler):
+        super().__init__(model, vllm_config, is_causal, sampler)
 
         # We only wrap the language model in HPU graph because some Ops in
         # vision model will fallback to CPU and cause the graph building fail.
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index b3b6bb4749d..178caa4791d 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -288,42 +288,9 @@ def forward_hook(module, args, output):
         modify_model_layers(child_module, suffix_list, n, counter)
 
 
-def get_path_to_rope(model: torch.nn.Module):
-    """Dynamically get the path to the RotaryEmbedding layer in the model.
-    This function will recursively search through the module hierarchy to find
-    a RotaryEmbedding layer and return the full path to that layer as a list
-    of names.
-    If no such layer is found, it returns None.
-    """
-
-    def find_rope_layer(parent, path):
-        # Base case: check if this parent is None
-        if parent is None:
-            return None
-
-        # Check if the current layer is a RotaryEmbedding
-        if hasattr(parent, 'named_children'):
-            for child_name, child_module in parent.named_children():
-                # If the current child is of type RotaryEmbedding,
-                # return the full path
-                if child_module.__class__.__name__.endswith("RotaryEmbedding"):
-                    return path + [child_name]
-                # Otherwise, recurse into this child to check its children
-                result = find_rope_layer(child_module, path + [child_name])
-                if result is not None:
-                    return result
-        return None
-
-    # Start the search from the top level model
-    path_to_rope = find_rope_layer(model, [])
-
-    # Return the result if found, otherwise None
-    return path_to_rope
-
-
 class HpuModelAdapter(torch.nn.Module):
 
-    def __init__(self, model, vllm_config, layer_names, is_causal, sampler):
+    def __init__(self, model, vllm_config, is_causal, sampler):
         super().__init__()
         self.model = model
         self.prefill_use_fusedsdpa = get_config(
@@ -334,7 +301,6 @@ def __init__(self, model, vllm_config, layer_names, is_causal, sampler):
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
         self.dtype = vllm_config.model_config.dtype
-        self.layer_names = layer_names
         self.is_pooler = hasattr(self.model, "_pooler")
         self.is_causal = is_causal
         self.use_merged_prefill = get_config().merged_prefill
@@ -354,6 +320,39 @@ def __init__(self, model, vllm_config, layer_names, is_causal, sampler):
             self.model.visual = htorch.hpu.wrap_in_hpu_graph(
                 self.model.visual, disable_tensor_cache=True)
 
+        self._rotary_embed_module = self._get_rotary_embedding_module(
+            self.model)
+        self._rotary_prepare_cos_sin = self._get_prepare_cos_sin()
+
+    def _get_rotary_embedding_module(self, model: torch.nn.Module):
+        """
+        Dynamically get the RotaryEmbedding layer in the model.
+        This function will recursively search through the module
+        hierarchy to find and return a RotaryEmbedding layer.
+        If no such layer is found, it returns None.
+        """
+        if model is None:
+            return None
+
+        if model.__class__.__name__.endswith("RotaryEmbedding"):
+            return model
+
+        if hasattr(model, 'children'):
+            for child in model.children():
+                result = self._get_rotary_embedding_module(child)
+                if result is not None:
+                    return result
+        return None
+
+    def _get_prepare_cos_sin(self):
+        if self._rotary_embed_module is not None:
+            return self._rotary_embed_module.prepare_cos_sin
+        return None
+
+    def _reset_rotary_cos_sin(self):
+        delattr(self._rotary_embed_module, "cos")
+        delattr(self._rotary_embed_module, "sin")
+
     def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
                        dtype):
         if (attn_metadata is None
@@ -470,38 +469,6 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
                                                     device, dtype)
         return attn_metadata
 
-    def _prepare_cos_sin(self, positions):
-        """Navigate through the model using the provided path and call
-        the prepare_cos_sin method on the 'RotaryEmbedding' layer."""
-
-        current_module = self.model  # Start from the top level of the model
-
-        for layer in self.layer_names:
-            if layer.isdigit():  # Check if the layer is an index
-                layer = int(layer)
-
-            # Check if the current layer is a name in a module
-            if isinstance(
-                    layer,
-                    str) and not isinstance(layer, int):  # Name-based access
-                current_module = getattr(current_module, layer)
-            elif isinstance(layer,
-                            int):  # Indexed-based access (like ModuleList)
-                module_list = list(current_module._modules.values())
-                if layer >= len(module_list):
-                    # for MTP models, last layer is MTP layer
-                    layer = -1
-                current_module = module_list[layer]
-
-        # At the end, we should be at the RotaryEmbedding layer.
-        if hasattr(current_module, 'prepare_cos_sin'):
-            current_module.prepare_cos_sin(
-                positions, recompute_cos_sin=self.recompute_cos_sin)
-        else:
-            raise AttributeError(
-                "The module at the end of the path does not have \
-                a 'prepare_cos_sin' method.")
-
     def compute_input_embeddings_for_mrope(self, **kwargs):
         if not self.model_is_mrope:
             return None
@@ -543,8 +510,9 @@ def forward(self, *args, **kwargs):
                 input_ids.device, self.dtype)
         if 'lora_mask' in kwargs:
             LoraMask.setLoraMask(kwargs.pop('lora_mask'))
-        if self.layer_names is not None and not self.model_is_mrope:
-            self._prepare_cos_sin(kwargs['positions'])
+        if self._rotary_prepare_cos_sin is not None and not self.model_is_mrope:
+            self._rotary_prepare_cos_sin(
+                kwargs['positions'], recompute_cos_sin=self.recompute_cos_sin)
         if self.model_is_mrope:
             # inputs_embeds was computed on execute_model
             # now we always want to use the inputs_embeds
@@ -561,6 +529,9 @@ def forward(self, *args, **kwargs):
                 virtual_engine,
                 dp_awared_padding=self.dp_awared_padding):
             hidden_states = self.model(*args, **kwargs)
+        if self._rotary_prepare_cos_sin is not None and \
+                not self.model_is_mrope:
+            self._reset_rotary_cos_sin()
         if not get_pp_group().is_last_rank:
             return hidden_states
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -1069,7 +1040,6 @@ def load_model(self) -> None:
                     model_config.
                     model_type if model_config is not None else None),
                 hidden_layer_markstep_interval)
-            path_to_rope = get_path_to_rope(self.model)
         torch.hpu.synchronize()
 
         if self.is_pooler:
@@ -1078,15 +1048,12 @@ def load_model(self) -> None:
             self.model = self._maybe_wrap_in_hpu_graph(
                 self.model,
                 vllm_config=self.vllm_config,
-                layer_names=path_to_rope,
                 is_causal=self.is_causal,
                 sampler=self.sampler)
         msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
         logger.info(msg)
 
         with HabanaMemoryProfiler() as m_wrap:
-            self._maybe_compile(self.model,
-                                vllm_config=self.vllm_config,
-                                layer_names=path_to_rope)
+            self._maybe_compile(self.model)
         msg = f"Compiling took {m_wrap.get_summary_string()}"
         logger.info(msg)
@@ -1143,10 +1110,13 @@ def _maybe_compile(self, *args, **kwargs):
         ) and not self.vllm_config.model_config.enforce_eager:
             if os.getenv('VLLM_REGIONAL_COMPILATION',
                          'true').strip().lower() in ("1", "true"):
-                compiled_methods = ['_update_metadata']
+                compiled_methods = [
+                    '_update_metadata', '_rotary_prepare_cos_sin'
+                ]
                 for method_name in compiled_methods:
                     method = getattr(self.model, method_name)
-                    self._compile_region(self.model, method_name, method)
+                    if method is not None:
+                        self._compile_region(self.model, method_name, method)
                 self.regional_compilation_layers_list = [
                     RMSNorm, VocabParallelEmbedding
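
Note (not part of the patch): the change above drops the string-path bookkeeping (get_path_to_rope plus _prepare_cos_sin navigation) in favor of finding the RotaryEmbedding submodule once, caching its bound prepare_cos_sin, calling it before each forward, and then deleting the cached cos/sin tensors via _reset_rotary_cos_sin. A minimal standalone sketch of that lookup and call pattern follows; the Toy* classes and names are illustrative stand-ins, not vLLM code — only the "*RotaryEmbedding" class-name suffix and the prepare_cos_sin/cos/sin attributes mirror the real layer.

# Standalone sketch, assuming only torch is available.
import torch


class ToyRotaryEmbedding(torch.nn.Module):

    def prepare_cos_sin(self, positions, recompute_cos_sin=False):
        # The real layer caches cos/sin tensors on itself; emulate that here.
        self.cos = torch.cos(positions.float())
        self.sin = torch.sin(positions.float())


class ToyAttention(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.rotary_emb = ToyRotaryEmbedding()


class ToyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([ToyAttention() for _ in range(2)])


def get_rotary_embedding_module(module):
    # Depth-first search for the first submodule whose class name ends with
    # "RotaryEmbedding", mirroring _get_rotary_embedding_module above.
    if module is None:
        return None
    if module.__class__.__name__.endswith("RotaryEmbedding"):
        return module
    for child in module.children():
        found = get_rotary_embedding_module(child)
        if found is not None:
            return found
    return None


model = ToyModel()
rope = get_rotary_embedding_module(model)
assert rope is model.layers[0].rotary_emb

# The adapter caches the bound method once, calls it before each forward
# pass, and afterwards deletes the cached tensors (the role played by
# _reset_rotary_cos_sin), so stale cos/sin never leak across iterations.
prepare_cos_sin = rope.prepare_cos_sin
prepare_cos_sin(torch.arange(4), recompute_cos_sin=False)
delattr(rope, "cos")
delattr(rope, "sin")

This also shows why _maybe_compile can now guard on "method is not None": for a model with no RotaryEmbedding submodule, _rotary_prepare_cos_sin stays None and is simply skipped during regional compilation.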