From 7e7a2422f4d96d681105c434689e256a2f7f102f Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Wed, 16 Jul 2025 02:18:57 -0700 Subject: [PATCH 1/5] Add test case for compiling multiple graphs Signed-off-by: Yong Hoon Shin --- .../compile/piecewise/test_multiple_graphs.py | 283 ++++++++++++++++++ vllm/compilation/decorators.py | 20 +- 2 files changed, 302 insertions(+), 1 deletion(-) create mode 100644 tests/compile/piecewise/test_multiple_graphs.py diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py new file mode 100644 index 00000000000..dd9da56bfd1 --- /dev/null +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test (piecewise) compilation with a simple model where multiple submodules +are compiled and graph captured separately. +""" +import torch +from torch import nn +from torch.library import Library + +from vllm.compilation.backends import set_model_tag +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) +from vllm.envs import VLLM_USE_V1 +from vllm.forward_context import set_forward_context +from vllm.utils import direct_register_custom_op + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + +BATCH_SIZE = 32 +MLP_SIZE = 128 +HIDDEN_SIZE = 1024 +RANDOM_SEED = 0 + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + out.copy_(q) + out += k + out += v + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + return + + +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + +@support_torch_compile +class ParentModel(nn.Module): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class Attention(nn.Module): + def __init__(self, mlp_size: int, hidden_size: int) -> None: + super().__init__() + self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False) + self.post_attn = nn.Linear(hidden_size, mlp_size, bias=False) + + nn.init.xavier_normal_(self.pre_attn.weight.data, + generator=torch.Generator().manual_seed( + RANDOM_SEED), + gain=0.001) + nn.init.xavier_normal_(self.post_attn.weight.data, + generator=torch.Generator().manual_seed( + RANDOM_SEED), + gain=0.001) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pre_attn(x) + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = self.post_attn(x) + return x + +@support_torch_compile +class CompiledAttention(nn.Module): + def __init__(self, + *, + mlp_size: int, + hidden_size: int, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + self.attn = Attention(mlp_size, hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.attn(x) + +@support_torch_compile +class CompiledAttentionTwo(CompiledAttention): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.attn(x) + x + +@ignore_torch_compile +class 
SimpleModel(ParentModel): + def __init__(self, + *, + mlp_size: int, + hidden_size: int, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.attn_one = Attention(mlp_size, hidden_size) + self.attn_two = Attention(mlp_size, hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.attn_one(x) + x = self.attn_two(x) + x + return x + +@ignore_torch_compile +class SimpleModelWithTwoGraphs(ParentModel): + def __init__(self, + *, + mlp_size: int, + hidden_size: int, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix) + # Test will fail without `set_model_tag`` here with error: + # "ValueError: too many values to unpack (expected 3)" + # This is because CompiledAttention and CompiledAttentionTwo + # have different implmentations but the same torch.compile + # cache dir will be used by default + with set_model_tag("attn_one"): + self.attn_one = CompiledAttention( + mlp_size=mlp_size, + hidden_size=hidden_size, + vllm_config=vllm_config, + prefix=f"{prefix}.attn_one", + ) + with set_model_tag("attn_two"): + self.attn_two = CompiledAttentionTwo( + mlp_size=mlp_size, + hidden_size=hidden_size, + vllm_config=vllm_config, + prefix=f"{prefix}.attn_two", + ) + + self.hidden_states = torch.zeros((BATCH_SIZE, MLP_SIZE)).cuda() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bsz = x.shape[0] + # CUDAGraph expects same tensor addresses for each run + self.hidden_states[:bsz].copy_(x) + x = self.attn_one(self.hidden_states[:bsz]) + self.hidden_states[:bsz].copy_(x) + x = self.attn_two(self.hidden_states[:bsz]) + return x + + +def test_ignore_torch_compile_decorator(): + assert VLLM_USE_V1 + + # piecewise + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + ) + ) + + with set_current_vllm_config(vllm_config): + model = SimpleModel( + mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix='' + ).eval().cuda() + + with compilation_counter.expect( + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, + ), set_forward_context({}, vllm_config=vllm_config): + # first run is for compile + model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + + # run cudagraph captured sizes + model(torch.randn(2, MLP_SIZE).cuda()) + model(torch.randn(1, MLP_SIZE).cuda()) + + +@torch.inference_mode +def run_model(vllm_config, model: nn.Module): + with set_forward_context({}, vllm_config=vllm_config): + # Pre-allocate memory for CUDAGraph which expects + # static tensor addresses + inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda() + + # First run is for compile + model(inputs) + + # Run CUDAGraph captured sizes + model(inputs[:2]) + model(inputs[:1]) + + inputs[:2].fill_(1.0) + output = model(inputs[:2]) + + output = output.cpu() + return output.cpu() + +def test_multi_graph_piecewise_compile(monkeypatch): + assert VLLM_USE_V1 + + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + outputs = [] + + # piecewise compile + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + ) + ) + + with set_current_vllm_config(vllm_config): + model = 
SimpleModelWithTwoGraphs( + mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix='' + ).eval().cuda() + + with compilation_counter.expect( + num_graphs_seen=2, # two graphs for the model + num_piecewise_graphs_seen=6, + # attn_one, attn_two each has 3 piecewise graphs + # (pre_attn, post_attn, silly_attention) each + num_piecewise_capturable_graphs_seen=4, + # attn_one, attn_two has pre_attn and post_attn each, total=4 + num_backend_compilations=4, # num_piecewise_capturable_graphs_seen + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + outputs.append(run_model(vllm_config, model)) + + # no compile or cudagraph + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.NO_COMPILATION, + ) + ) + + with set_current_vllm_config(vllm_config): + model = SimpleModelWithTwoGraphs( + mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix='' + ).eval().cuda() + + with compilation_counter.expect( + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, + ): + outputs.append(run_model(vllm_config, model)) + + assert torch.allclose(outputs[0], outputs[1]) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 05e4ca9f08b..ca461e23035 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -20,9 +20,26 @@ logger = init_logger(__name__) +IGNORE_COMPILE_KEY = "_ignore_compile_vllm" + _T = TypeVar("_T", bound=type[nn.Module]) +def ignore_torch_compile(cls: _T) -> _T: + """ + A decorator to ignore support_torch_compile decorator + on the class. This is useful when a parent class has + a support_torch_compile decorator, but we don't want to + compile the class `cls` that inherits the parent class. + This only ignores compiling the forward of the class the + decorator is applied to. If the class has one or more submodules + that have support_torch_compile decorator applied, compile will + not be ignored for those submodules. 
+ """ + setattr(cls, IGNORE_COMPILE_KEY, True) + return cls + + @overload def support_torch_compile( *, @@ -156,7 +173,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): self.do_not_compile = \ vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS - ] or not supports_dynamo() + ] or not supports_dynamo() or getattr( + self, IGNORE_COMPILE_KEY, False) if self.do_not_compile: return compilation_counter.num_models_seen += 1 From 1167d998cf14c5094e673df5094015c8db33a3f0 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Wed, 16 Jul 2025 02:50:29 -0700 Subject: [PATCH 2/5] Lint and minor fixes Signed-off-by: Yong Hoon Shin --- .../compile/piecewise/test_multiple_graphs.py | 183 +++++++++--------- vllm/compilation/decorators.py | 4 +- 2 files changed, 94 insertions(+), 93 deletions(-) diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index dd9da56bfd1..1b6056506c6 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -10,7 +10,8 @@ from vllm.compilation.backends import set_model_tag from vllm.compilation.counter import compilation_counter -from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile +from vllm.compilation.decorators import (ignore_torch_compile, + support_torch_compile) from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, set_current_vllm_config) from vllm.envs import VLLM_USE_V1 @@ -46,33 +47,37 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, target_lib=silly_lib, ) + @support_torch_compile class ParentModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs) -> None: super().__init__() - + def forward(self, x: torch.Tensor) -> torch.Tensor: return x class Attention(nn.Module): + def __init__(self, mlp_size: int, hidden_size: int) -> None: super().__init__() self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False) self.post_attn = nn.Linear(hidden_size, mlp_size, bias=False) - nn.init.xavier_normal_(self.pre_attn.weight.data, - generator=torch.Generator().manual_seed( - RANDOM_SEED), - gain=0.001) - nn.init.xavier_normal_(self.post_attn.weight.data, - generator=torch.Generator().manual_seed( - RANDOM_SEED), - gain=0.001) + # Initialize to same weights for testing + nn.init.xavier_normal_( + self.pre_attn.weight.data, + generator=torch.Generator().manual_seed(RANDOM_SEED), + gain=0.001) + nn.init.xavier_normal_( + self.post_attn.weight.data, + generator=torch.Generator().manual_seed(RANDOM_SEED), + gain=0.001) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.pre_attn(x) @@ -82,35 +87,41 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.post_attn(x) return x + @support_torch_compile class CompiledAttention(nn.Module): + def __init__(self, - *, - mlp_size: int, - hidden_size: int, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + *, + mlp_size: int, + hidden_size: int, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: super().__init__() self.attn = Attention(mlp_size, hidden_size) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.attn(x) + @support_torch_compile class CompiledAttentionTwo(CompiledAttention): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.attn(x) + x + @ignore_torch_compile class SimpleModel(ParentModel): + def __init__(self, - *, - mlp_size: int, - 
hidden_size: int, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + *, + mlp_size: int, + hidden_size: int, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: super().__init__(vllm_config=vllm_config, prefix=prefix) self.attn_one = Attention(mlp_size, hidden_size) self.attn_two = Attention(mlp_size, hidden_size) @@ -120,19 +131,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.attn_two(x) + x return x + @ignore_torch_compile class SimpleModelWithTwoGraphs(ParentModel): + def __init__(self, - *, - mlp_size: int, - hidden_size: int, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + *, + mlp_size: int, + hidden_size: int, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: super().__init__(vllm_config=vllm_config, prefix=prefix) # Test will fail without `set_model_tag`` here with error: # "ValueError: too many values to unpack (expected 3)" - # This is because CompiledAttention and CompiledAttentionTwo + # This is because CompiledAttention and CompiledAttentionTwo # have different implmentations but the same torch.compile # cache dir will be used by default with set_model_tag("attn_one"): @@ -149,9 +162,9 @@ def __init__(self, vllm_config=vllm_config, prefix=f"{prefix}.attn_two", ) - + self.hidden_states = torch.zeros((BATCH_SIZE, MLP_SIZE)).cuda() - + def forward(self, x: torch.Tensor) -> torch.Tensor: bsz = x.shape[0] # CUDAGraph expects same tensor addresses for each run @@ -166,29 +179,25 @@ def test_ignore_torch_compile_decorator(): assert VLLM_USE_V1 # piecewise - vllm_config = VllmConfig( - compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - ) - ) + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) with set_current_vllm_config(vllm_config): - model = SimpleModel( - mlp_size=MLP_SIZE, - hidden_size=HIDDEN_SIZE, - vllm_config=vllm_config, - prefix='' - ).eval().cuda() + model = SimpleModel(mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix='').eval().cuda() with compilation_counter.expect( - num_graphs_seen=0, - num_piecewise_graphs_seen=0, - num_piecewise_capturable_graphs_seen=0, - num_backend_compilations=0, - num_cudagraph_captured=0, + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, ), set_forward_context({}, vllm_config=vllm_config): # first run is for compile model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) @@ -196,7 +205,7 @@ def test_ignore_torch_compile_decorator(): # run cudagraph captured sizes model(torch.randn(2, MLP_SIZE).cuda()) model(torch.randn(1, MLP_SIZE).cuda()) - + @torch.inference_mode def run_model(vllm_config, model: nn.Module): @@ -218,65 +227,57 @@ def run_model(vllm_config, model: nn.Module): output = output.cpu() return output.cpu() + def test_multi_graph_piecewise_compile(monkeypatch): assert VLLM_USE_V1 monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") - + outputs = [] # piecewise compile - vllm_config = VllmConfig( - compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - ) - ) + vllm_config = VllmConfig(compilation_config=CompilationConfig( + 
level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) with set_current_vllm_config(vllm_config): - model = SimpleModelWithTwoGraphs( - mlp_size=MLP_SIZE, - hidden_size=HIDDEN_SIZE, - vllm_config=vllm_config, - prefix='' - ).eval().cuda() + model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix='').eval().cuda() with compilation_counter.expect( - num_graphs_seen=2, # two graphs for the model - num_piecewise_graphs_seen=6, - # attn_one, attn_two each has 3 piecewise graphs - # (pre_attn, post_attn, silly_attention) each - num_piecewise_capturable_graphs_seen=4, - # attn_one, attn_two has pre_attn and post_attn each, total=4 - num_backend_compilations=4, # num_piecewise_capturable_graphs_seen - num_cudagraph_captured=8, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=2, # two graphs for the model + num_piecewise_graphs_seen=6, + # attn_one, attn_two each has 3 piecewise graphs + # (pre_attn, post_attn, silly_attention) each + num_piecewise_capturable_graphs_seen=4, + # attn_one, attn_two has pre_attn and post_attn each, total=4 + num_backend_compilations=4, # num_piecewise_capturable_graphs_seen + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): outputs.append(run_model(vllm_config, model)) # no compile or cudagraph - vllm_config = VllmConfig( - compilation_config=CompilationConfig( - level=CompilationLevel.NO_COMPILATION, - ) - ) + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.NO_COMPILATION, )) with set_current_vllm_config(vllm_config): - model = SimpleModelWithTwoGraphs( - mlp_size=MLP_SIZE, - hidden_size=HIDDEN_SIZE, - vllm_config=vllm_config, - prefix='' - ).eval().cuda() + model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix='').eval().cuda() with compilation_counter.expect( - num_graphs_seen=0, - num_piecewise_graphs_seen=0, - num_piecewise_capturable_graphs_seen=0, - num_backend_compilations=0, - num_cudagraph_captured=0, + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, ): outputs.append(run_model(vllm_config, model)) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index ca461e23035..2b8ec178e33 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -173,8 +173,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): self.do_not_compile = \ vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS - ] or not supports_dynamo() or getattr( - self, IGNORE_COMPILE_KEY, False) + ] or not supports_dynamo() or ( + IGNORE_COMPILE_KEY in self.__class__.__dict__) if self.do_not_compile: return compilation_counter.num_models_seen += 1 From 2cb8383804989fcb8a967e6ac4bfd7062484bff0 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Wed, 16 Jul 2025 09:19:37 -0700 Subject: [PATCH 3/5] Fix grandchild support_torch_compile ignored Signed-off-by: Yong Hoon Shin --- .../compile/piecewise/test_multiple_graphs.py | 97 +++++++++++++------ vllm/compilation/decorators.py | 21 +++- 2 files changed, 85 insertions(+), 33 deletions(-) diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index 
1b6056506c6..70b376112cf 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -112,26 +112,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.attn(x) + x -@ignore_torch_compile -class SimpleModel(ParentModel): - - def __init__(self, - *, - mlp_size: int, - hidden_size: int, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: - super().__init__(vllm_config=vllm_config, prefix=prefix) - self.attn_one = Attention(mlp_size, hidden_size) - self.attn_two = Attention(mlp_size, hidden_size) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.attn_one(x) - x = self.attn_two(x) + x - return x - - @ignore_torch_compile class SimpleModelWithTwoGraphs(ParentModel): @@ -175,9 +155,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def test_ignore_torch_compile_decorator(): +def test_ignore_torch_compile_decorator(monkeypatch): assert VLLM_USE_V1 + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + # piecewise vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, @@ -186,12 +168,54 @@ def test_ignore_torch_compile_decorator(): cudagraph_capture_sizes=[1, 2], )) + @support_torch_compile + class A(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + @ignore_torch_compile + class B(A): + ... + + @support_torch_compile + class C(B): + ... + + with set_current_vllm_config(vllm_config): + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ), set_forward_context({}, vllm_config=vllm_config): + # first run is for compile + mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + # run cudagraph captured sizes + mod_A(torch.randn(2, MLP_SIZE).cuda()) + mod_A(torch.randn(1, MLP_SIZE).cuda()) + with set_current_vllm_config(vllm_config): - model = SimpleModel(mlp_size=MLP_SIZE, - hidden_size=HIDDEN_SIZE, - vllm_config=vllm_config, - prefix='').eval().cuda() + mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() + # B's ignore_torch_compile should override A's support_torch_compile with compilation_counter.expect( num_graphs_seen=0, num_piecewise_graphs_seen=0, @@ -199,12 +223,25 @@ def test_ignore_torch_compile_decorator(): num_backend_compilations=0, num_cudagraph_captured=0, ), set_forward_context({}, vllm_config=vllm_config): - # first run is for compile - model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + mod_B(torch.randn(2, MLP_SIZE).cuda()) + mod_B(torch.randn(1, MLP_SIZE).cuda()) - # run cudagraph captured sizes - model(torch.randn(2, MLP_SIZE).cuda()) - model(torch.randn(1, MLP_SIZE).cuda()) + with set_current_vllm_config(vllm_config): + mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() + + # C's support_torch_compile should override B's ignore_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, 
+ num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ), set_forward_context({}, vllm_config=vllm_config): + mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + mod_C(torch.randn(2, MLP_SIZE).cuda()) + mod_C(torch.randn(1, MLP_SIZE).cuda()) @torch.inference_mode diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 2b8ec178e33..f3592324d8c 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -32,7 +32,12 @@ def ignore_torch_compile(cls: _T) -> _T: a support_torch_compile decorator, but we don't want to compile the class `cls` that inherits the parent class. This only ignores compiling the forward of the class the - decorator is applied to. If the class has one or more submodules + decorator is applied to. + + If the parent has ignore_torch_compile but the child has + support_torch_compile, the child will still be compiled. + + If the class has one or more submodules that have support_torch_compile decorator applied, compile will not be ignored for those submodules. """ @@ -40,6 +45,13 @@ def ignore_torch_compile(cls: _T) -> _T: return cls +def _should_ignore_torch_compile(cls) -> bool: + """ + Check if the class should be ignored for torch.compile. + """ + return getattr(cls, IGNORE_COMPILE_KEY, False) + + @overload def support_torch_compile( *, @@ -165,6 +177,8 @@ def _support_torch_compile( old_init = cls.__init__ + setattr(cls, IGNORE_COMPILE_KEY, False) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) self.vllm_config = vllm_config @@ -173,10 +187,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): self.do_not_compile = \ vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS - ] or not supports_dynamo() or ( - IGNORE_COMPILE_KEY in self.__class__.__dict__) + ] or not supports_dynamo() or _should_ignore_torch_compile( + self.__class__) if self.do_not_compile: return + compilation_counter.num_models_seen += 1 TorchCompileWrapperWithCustomDispatcher.__init__( self, compilation_level=vllm_config.compilation_config.level) From 05001704dc8dd77a89182638d56ad6138c95a8b9 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Wed, 16 Jul 2025 10:13:59 -0700 Subject: [PATCH 4/5] Update comment Signed-off-by: Yong Hoon Shin --- tests/compile/piecewise/test_multiple_graphs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index 70b376112cf..ae330b8ea0d 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -123,11 +123,11 @@ def __init__(self, prefix: str = '', **kwargs) -> None: super().__init__(vllm_config=vllm_config, prefix=prefix) - # Test will fail without `set_model_tag`` here with error: + # Test will fail without set_model_tag here with error: # "ValueError: too many values to unpack (expected 3)" # This is because CompiledAttention and CompiledAttentionTwo # have different implmentations but the same torch.compile - # cache dir will be used by default + # cache dir will be used as default prefix is 'model_tag' with set_model_tag("attn_one"): self.attn_one = CompiledAttention( mlp_size=mlp_size, From d54221ff9a5c0c9115a765d57041f3bd0c4eee1d Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Wed, 16 Jul 2025 10:17:38 -0700 Subject: [PATCH 5/5] 
Randomize test inputs Signed-off-by: Yong Hoon Shin --- tests/compile/piecewise/test_multiple_graphs.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index ae330b8ea0d..67b225a0fcc 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -245,12 +245,8 @@ class C(B): @torch.inference_mode -def run_model(vllm_config, model: nn.Module): +def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor): with set_forward_context({}, vllm_config=vllm_config): - # Pre-allocate memory for CUDAGraph which expects - # static tensor addresses - inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda() - # First run is for compile model(inputs) @@ -258,7 +254,6 @@ def run_model(vllm_config, model: nn.Module): model(inputs[:2]) model(inputs[:1]) - inputs[:2].fill_(1.0) output = model(inputs[:2]) output = output.cpu() @@ -286,6 +281,10 @@ def test_multi_graph_piecewise_compile(monkeypatch): vllm_config=vllm_config, prefix='').eval().cuda() + # Pre-allocate memory for CUDAGraph which expects + # static tensor addresses + inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda() + with compilation_counter.expect( num_graphs_seen=2, # two graphs for the model num_piecewise_graphs_seen=6, @@ -297,7 +296,7 @@ def test_multi_graph_piecewise_compile(monkeypatch): num_cudagraph_captured=8, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): - outputs.append(run_model(vllm_config, model)) + outputs.append(run_model(vllm_config, model, inputs)) # no compile or cudagraph vllm_config = VllmConfig(compilation_config=CompilationConfig( @@ -316,6 +315,6 @@ def test_multi_graph_piecewise_compile(monkeypatch): num_backend_compilations=0, num_cudagraph_captured=0, ): - outputs.append(run_model(vllm_config, model)) + outputs.append(run_model(vllm_config, model, inputs)) assert torch.allclose(outputs[0], outputs[1])
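
Note (not part of the patch): a condensed sketch of the decorator semantics this series adds, mirroring the A/B/C hierarchy exercised in test_ignore_torch_compile_decorator above. The trivial forward body is illustrative only; the test's real forward calls the silly.attention custom op so the graph is split piecewise.

import torch
from torch import nn

from vllm.compilation.decorators import (ignore_torch_compile,
                                         support_torch_compile)
from vllm.config import VllmConfig


@support_torch_compile
class A(nn.Module):
    # Compiled: decorated directly with support_torch_compile.

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '',
                 **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 3


@ignore_torch_compile
class B(A):
    # Not compiled: ignore_torch_compile on the subclass overrides the
    # support_torch_compile inherited from A.
    ...


@support_torch_compile
class C(B):
    # Compiled again: re-applying support_torch_compile on the child
    # overrides B's ignore_torch_compile.
    ...

As in the tests, each class is constructed under set_current_vllm_config(vllm_config) and its forward runs inside set_forward_context({}, vllm_config=vllm_config).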