Skip to content

Add test case for compiling multiple graphs #21044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
321 changes: 321 additions & 0 deletions tests/compile/piecewise/test_multiple_graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,321 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test (piecewise) compilation with a simple model where multiple submodules
are compiled and graph captured separately.
"""
import torch
from torch import nn
from torch.library import Library

from vllm.compilation.backends import set_model_tag
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import (ignore_torch_compile,
support_torch_compile)
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import set_forward_context
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT") # noqa

BATCH_SIZE = 32
MLP_SIZE = 128
HIDDEN_SIZE = 1024
RANDOM_SEED = 0


def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
out.copy_(q)
out += k
out += v


def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out: torch.Tensor) -> None:
return


direct_register_custom_op(
op_name="attention",
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
)


@support_torch_compile
class ParentModel(nn.Module):

def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class Attention(nn.Module):

def __init__(self, mlp_size: int, hidden_size: int) -> None:
super().__init__()
self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False)
self.post_attn = nn.Linear(hidden_size, mlp_size, bias=False)

# Initialize to same weights for testing
nn.init.xavier_normal_(
self.pre_attn.weight.data,
generator=torch.Generator().manual_seed(RANDOM_SEED),
gain=0.001)
nn.init.xavier_normal_(
self.post_attn.weight.data,
generator=torch.Generator().manual_seed(RANDOM_SEED),
gain=0.001)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pre_attn(x)
attn_output = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, attn_output)
x = attn_output
x = self.post_attn(x)
return x


@support_torch_compile
class CompiledAttention(nn.Module):

def __init__(self,
*,
mlp_size: int,
hidden_size: int,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__()
self.attn = Attention(mlp_size, hidden_size)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.attn(x)


@support_torch_compile
class CompiledAttentionTwo(CompiledAttention):

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.attn(x) + x


@ignore_torch_compile
class SimpleModelWithTwoGraphs(ParentModel):

def __init__(self,
*,
mlp_size: int,
hidden_size: int,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix)
# Test will fail without `set_model_tag`` here with error:
# "ValueError: too many values to unpack (expected 3)"
# This is because CompiledAttention and CompiledAttentionTwo
# have different implmentations but the same torch.compile
# cache dir will be used by default
with set_model_tag("attn_one"):
self.attn_one = CompiledAttention(
mlp_size=mlp_size,
hidden_size=hidden_size,
vllm_config=vllm_config,
prefix=f"{prefix}.attn_one",
)
with set_model_tag("attn_two"):
self.attn_two = CompiledAttentionTwo(
mlp_size=mlp_size,
hidden_size=hidden_size,
vllm_config=vllm_config,
prefix=f"{prefix}.attn_two",
)

self.hidden_states = torch.zeros((BATCH_SIZE, MLP_SIZE)).cuda()

def forward(self, x: torch.Tensor) -> torch.Tensor:
bsz = x.shape[0]
# CUDAGraph expects same tensor addresses for each run
self.hidden_states[:bsz].copy_(x)
x = self.attn_one(self.hidden_states[:bsz])
self.hidden_states[:bsz].copy_(x)
x = self.attn_two(self.hidden_states[:bsz])
return x


def test_ignore_torch_compile_decorator(monkeypatch):
assert VLLM_USE_V1

monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

# piecewise
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
))

@support_torch_compile
class A(nn.Module):

def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + x
attn_output = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, attn_output)
x = attn_output
x = x * 3
return x

@ignore_torch_compile
class B(A):
...

@support_torch_compile
class C(B):
...

with set_current_vllm_config(vllm_config):
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()

# A has support_torch_compile
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2,
num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
), set_forward_context({}, vllm_config=vllm_config):
# first run is for compile
mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
# run cudagraph captured sizes
mod_A(torch.randn(2, MLP_SIZE).cuda())
mod_A(torch.randn(1, MLP_SIZE).cuda())

with set_current_vllm_config(vllm_config):
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()

# B's ignore_torch_compile should override A's support_torch_compile
with compilation_counter.expect(
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
), set_forward_context({}, vllm_config=vllm_config):
mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
mod_B(torch.randn(2, MLP_SIZE).cuda())
mod_B(torch.randn(1, MLP_SIZE).cuda())

with set_current_vllm_config(vllm_config):
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()

# C's support_torch_compile should override B's ignore_torch_compile
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=3,
num_piecewise_capturable_graphs_seen=2,
num_backend_compilations=2,
num_cudagraph_captured=4,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
), set_forward_context({}, vllm_config=vllm_config):
mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
mod_C(torch.randn(2, MLP_SIZE).cuda())
mod_C(torch.randn(1, MLP_SIZE).cuda())


@torch.inference_mode
def run_model(vllm_config, model: nn.Module):
with set_forward_context({}, vllm_config=vllm_config):
# Pre-allocate memory for CUDAGraph which expects
# static tensor addresses
inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()

# First run is for compile
model(inputs)

# Run CUDAGraph captured sizes
model(inputs[:2])
model(inputs[:1])

inputs[:2].fill_(1.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For test, should we random initialize inputs and check bitwise equivalence?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to randomize test input each test run. By bitwise equivalence do you mean torch.equal() on final output or something else? Using torch.equal (rtol=0,atol=0) fails currently, but that might be expected?

output = model(inputs[:2])

output = output.cpu()
return output.cpu()


def test_multi_graph_piecewise_compile(monkeypatch):
assert VLLM_USE_V1

monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zou3519 I noticed that when I remove this and enable compile caching, the compile for CompiledAttentionTwo fails when running pytest tests/compile/piecewise/test_multiple_graphs.py::test_multi_graph_piecewise_compile. The error is:

E               torch._dynamo.exc.BackendCompilerFailed: backend='<vllm.compilation.backends.VllmBackend object at 0x7f8addaf6cf0>' raised:
E               RuntimeError: vLLM failed to compile the model. The most likely reason for this is that a previous compilation failed, leading to a corrupted compilation artifact. We recommend trying to remove ~/.cache/vllm/torch_compile_cache and try again to see the real issue. 
E               
E               Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

I found that when I keep the prefix in inductor compile cache dir here, the test passes. Before this change, inductor cache dir would be /home/yhshin/.cache/vllm/torch_compile_cache/bbbe1d91c9/rank_0_0 for both attn_one and attn_two, but after the change these are separate (i.e. /home/yhshin/.cache/vllm/torch_compile_cache/bbbe1d91c9/rank_0_0/{prefix}).

For compiling multiple graphs in a single model, should we be setting inductor cache dir to separate directories?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we're compiling multiple graphs in a single model, each graph should be tagged with a different model tag. That will make it so that their intermediate artifacts are stored in a separate subfolder of the cache directory.

Are you saying that even with the set_model_tags, you need to set VLLM_DISABLE_COMPILE_CACHE=1?

Copy link
Collaborator Author

@sarckk sarckk Jul 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you saying that even with the set_model_tags, you need to set VLLM_DISABLE_COMPILE_CACHE=1?

Yes, and it seems to be due to shared inductor cache dir. If I force it to be stored in separate subfolder according to the model tag, it runs fine.

Having them share inductor cache seems to be a conscious decision in #19064:

This PR re-organizes the cache directory structure, so that the same vLLM instances will use the same TORCHINDUCTOR_CACHE_DIR and TRITON_CACHE_DIR, but just different storage for vllm_compile_cache.py etc.


outputs = []

# piecewise compile
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=[1, 2],
))

with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix='').eval().cuda()

with compilation_counter.expect(
num_graphs_seen=2, # two graphs for the model
num_piecewise_graphs_seen=6,
# attn_one, attn_two each has 3 piecewise graphs
# (pre_attn, post_attn, silly_attention) each
num_piecewise_capturable_graphs_seen=4,
# attn_one, attn_two has pre_attn and post_attn each, total=4
num_backend_compilations=4, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=8,
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs.append(run_model(vllm_config, model))

# no compile or cudagraph
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.NO_COMPILATION, ))

with set_current_vllm_config(vllm_config):
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
hidden_size=HIDDEN_SIZE,
vllm_config=vllm_config,
prefix='').eval().cuda()

with compilation_counter.expect(
num_graphs_seen=0,
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
num_cudagraph_captured=0,
):
outputs.append(run_model(vllm_config, model))

assert torch.allclose(outputs[0], outputs[1])
35 changes: 34 additions & 1 deletion vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,38 @@

logger = init_logger(__name__)

IGNORE_COMPILE_KEY = "_ignore_compile_vllm"

_T = TypeVar("_T", bound=type[nn.Module])


def ignore_torch_compile(cls: _T) -> _T:
"""
A decorator to ignore support_torch_compile decorator
on the class. This is useful when a parent class has
a support_torch_compile decorator, but we don't want to
compile the class `cls` that inherits the parent class.
This only ignores compiling the forward of the class the
decorator is applied to.

If the parent has ignore_torch_compile but the child has
support_torch_compile, the child will still be compiled.

If the class has one or more submodules
that have support_torch_compile decorator applied, compile will
not be ignored for those submodules.
"""
setattr(cls, IGNORE_COMPILE_KEY, True)
return cls


def _should_ignore_torch_compile(cls) -> bool:
"""
Check if the class should be ignored for torch.compile.
"""
return getattr(cls, IGNORE_COMPILE_KEY, False)


@overload
def support_torch_compile(
*,
Expand Down Expand Up @@ -148,6 +177,8 @@ def _support_torch_compile(

old_init = cls.__init__

setattr(cls, IGNORE_COMPILE_KEY, False)

def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
self.vllm_config = vllm_config
Expand All @@ -156,9 +187,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
self.do_not_compile = \
vllm_config.compilation_config.level in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo()
] or not supports_dynamo() or _should_ignore_torch_compile(
self.__class__)
if self.do_not_compile:
return

compilation_counter.num_models_seen += 1
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)
Expand Down