Commit 495b2df

Enable compile for all hybrid models
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent a6eeae4 commit 495b2df

4 files changed (+22 lines, -13 lines)

vllm/model_executor/models/bamba.py

Lines changed: 6 additions & 5 deletions
@@ -11,6 +11,7 @@
 
 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -122,11 +123,10 @@ def forward(
         hidden_states, residual = self.input_layernorm(
             hidden_states, residual)
 
-        hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
+        output = torch.empty_like(hidden_states)
+        self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
         # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
+        hidden_states, residual = self.pre_ff_layernorm(output, residual)
         hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual

@@ -169,7 +169,7 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
 
         if hasattr(config, "partial_rotary_factor"):
-            rotary_dim = self.head_dim * config.partial_rotary_factor
+            rotary_dim = int(self.head_dim * config.partial_rotary_factor)
         elif hasattr(config, "attn_rotary_emb"):
             rotary_dim = config.attn_rotary_emb  # for backward compatibility
         else:
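Note: the hunk above also wraps the rotary dimension in int(). A quick standalone illustration of why (the values below are made up, not taken from any real config): multiplying an integer head_dim by a float partial_rotary_factor yields a Python float, which breaks downstream code that expects an integer dimension.

# Hypothetical values, for illustration only.
head_dim = 128
partial_rotary_factor = 0.5

print(head_dim * partial_rotary_factor)       # 64.0  (float)
print(int(head_dim * partial_rotary_factor))  # 64    (int)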
@@ -258,6 +258,7 @@ def forward(
         }
 
 
+@support_torch_compile
 class BambaModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
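The forward-pass change above (and its counterparts in the files below) switches the mamba call to a destination-passing style: the caller allocates the output tensor and the mixer writes into it rather than returning a new tensor. A minimal sketch of that calling convention, using a stand-in mixer rather than vLLM's actual classes:

import torch
import torch.nn as nn


class ToyMixer(nn.Module):
    """Stand-in for the mamba mixer: fills a caller-provided buffer."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor) -> None:
        # Write the result into the pre-allocated buffer; nothing is returned.
        output.copy_(self.proj(hidden_states))


mixer = ToyMixer(hidden_size=16)
hidden_states = torch.randn(4, 16)

# Caller allocates the destination, mirroring `output = torch.empty_like(hidden_states)`.
output = torch.empty_like(hidden_states)
mixer(hidden_states, output)
print(output.shape)  # torch.Size([4, 16])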

vllm/model_executor/models/falcon_h1.py

Lines changed: 6 additions & 2 deletions
@@ -10,6 +10,7 @@
 
 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -179,13 +180,15 @@ def forward(
         mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
-        hidden_states = self.mamba(
+        output = torch.empty_like(hidden_states)
+        self.mamba(
             hidden_states,
+            output,
             mamba_cache_params,
             mamba2_metadata=mamba2_metadata,
             mup_vector=self.mup_vector,
         )
-        return hidden_states, residual
+        return output, residual
 
 
 class FalconH1AttentionDecoderLayer(nn.Module):
@@ -398,6 +401,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class FalconH1Model(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
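The @support_torch_compile decorator (imported from vllm.compilation.decorators) marks the decorated model class so that vLLM's compilation machinery can run its forward pass under torch.compile, subject to the engine's compilation config. As a rough plain-PyTorch analogy only, not what the decorator does internally:

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


model = TinyModel()
compiled = torch.compile(model)  # plain torch.compile; vLLM adds config gating and caching on top

x = torch.randn(4, 16)
torch.testing.assert_close(compiled(x), model(x))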

vllm/model_executor/models/mamba2.py

Lines changed: 5 additions & 3 deletions
@@ -10,6 +10,7 @@
 
 from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
@@ -79,11 +80,12 @@ def forward(
         else:
             hidden_states, residual = self.norm(hidden_states, residual)
 
-        hidden_states = self.mixer(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
-        return hidden_states, residual
+        output = torch.empty_like(hidden_states)
+        self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
+        return output, residual
 
 
+@support_torch_compile
 class Mamba2Model(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
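The same two-part pattern repeats here: the mixer fills a pre-allocated buffer and the model class gains the compile decorator. A small self-contained check (illustrative only, not vLLM code) that the buffer-filling style traces cleanly under torch.compile:

import torch


def mixer_step(hidden_states: torch.Tensor) -> torch.Tensor:
    # Destination-passing style: allocate the output, fill it in place, return it.
    output = torch.empty_like(hidden_states)
    output.copy_(hidden_states * 2.0)
    return output


compiled_step = torch.compile(mixer_step)
x = torch.randn(4, 16)
torch.testing.assert_close(compiled_step(x), mixer_step(x))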

vllm/model_executor/models/nemotron_h.py

Lines changed: 5 additions & 3 deletions
@@ -25,6 +25,7 @@
 
 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -172,9 +173,9 @@ def forward(
         else:
             hidden_states, residual = self.norm(hidden_states, residual)
 
-        hidden_states = self.mixer(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
-        return hidden_states, residual
+        output = torch.empty_like(hidden_states)
+        self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
+        return output, residual
 
 
 class NemotronHAttention(nn.Module):
@@ -292,6 +293,7 @@ def forward(
         }
 
 
+@support_torch_compile
 class NemotronHModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
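Taken together, each hybrid decoder layer now follows the same control flow: normalize, let the mixer fill a caller-allocated buffer, and return that buffer alongside the residual, with the whole model class decorated for compilation. A toy sketch of that flow (ToyLayer is a placeholder with simplified residual handling, not a vLLM class):

import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.mixer = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, residual=None):
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states = self.norm(hidden_states + residual)
            residual = hidden_states  # simplified; vLLM's fused norms handle this differently

        # Mixer fills a caller-allocated buffer, as in the diffs above.
        output = torch.empty_like(hidden_states)
        output.copy_(self.mixer(hidden_states))
        return output, residual


layer = torch.compile(ToyLayer(hidden_size=16))
out, res = layer(torch.randn(2, 16))
print(out.shape, res.shape)  # torch.Size([2, 16]) torch.Size([2, 16])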
