
Commit d70bc7c

[torch.compile] reorganize the cache directory to support compiling multiple models (#19064)
Signed-off-by: youkaichao <youkaichao@gmail.com>
1 parent: ce688ad

File tree: 6 files changed (+117, -27 lines)

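Taken together, the changes below reorganize vLLM's torch.compile cache so that each compiled part of a model gets its own subdirectory under the per-rank cache, while Inductor/Triton artifacts stay shared at the rank level. A minimal sketch of the resulting layout; the cache root, hash string, and rank values here are placeholders, since the real ones come from the vLLM configuration and compute_hash():

import os

# Placeholder values: the real hash string comes from compute_hash(vllm_config)
# and the rank indices come from the parallel config.
cache_root = os.path.expanduser("~/.cache/vllm/torch_compile_cache")  # assumed root
hash_str = "0123456789"  # placeholder hash
rank, dp_rank = 0, 0

base_cache_dir = os.path.join(cache_root, hash_str, f"rank_{rank}_{dp_rank}")

# One subdirectory per compiled model part (the "prefix" / model_tag):
backbone_dir = os.path.join(base_cache_dir, "backbone")    # default tag
eagle_dir = os.path.join(base_cache_dir, "eagle_head")     # e.g. set via set_model_tag

# Shared compilation artifacts live beside them in the base directory:
inductor_cache = os.path.join(base_cache_dir, "inductor_cache")
triton_cache = os.path.join(base_cache_dir, "triton_cache")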

vllm/compilation/backends.py (56 additions, 9 deletions)

@@ -7,6 +7,7 @@
 import pprint
 import time
 from collections.abc import Sequence
+from contextlib import contextmanager
 from typing import Any, Callable, Optional
 
 import torch
@@ -66,7 +67,25 @@ def __init__(self, compilation_config: CompilationConfig):
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
+        """
+        Initialize the cache directory for the compiler.
+
+        The organization of the cache directory is as follows:
+        cache_dir=/path/to/hash_str/rank_i_j/prefix/
+        inside cache_dir, there will be:
+        - vllm_compile_cache.py
+        - computation_graph.py
+        - transformed_code.py
+
+        for multiple prefixes, they can share the same
+        base cache dir of /path/to/hash_str/rank_i_j/ ,
+        to store some common compilation artifacts.
+        """
+
         self.disable_cache = disable_cache
         self.cache_dir = cache_dir
         self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
@@ -80,7 +99,8 @@ def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
             self.cache = ast.literal_eval(f.read())
 
         self.compiler.initialize_cache(cache_dir=cache_dir,
-                                       disable_cache=disable_cache)
+                                       disable_cache=disable_cache,
+                                       prefix=prefix)
 
     def save_to_file(self):
         if self.disable_cache or not self.is_cache_updated:
@@ -310,6 +330,25 @@ def call_module(self, target: torch.fx.node.Target,
         return output
 
 
+# the tag for the part of model being compiled,
+# e.g. backbone/eagle_head
+model_tag: str = "backbone"
+
+
+@contextmanager
+def set_model_tag(tag: str):
+    """Context manager to set the model tag."""
+    global model_tag
+    assert tag != model_tag, \
+        f"Model tag {tag} is the same as the current tag {model_tag}."
+    old_tag = model_tag
+    model_tag = tag
+    try:
+        yield
+    finally:
+        model_tag = old_tag
+
+
 class VllmBackend:
     """The compilation backend for `torch.compile` with vLLM.
     It is used for compilation level of `CompilationLevel.PIECEWISE`,
@@ -341,7 +380,17 @@ class VllmBackend:
     def __init__(
         self,
         vllm_config: VllmConfig,
+        prefix: str = "",
     ):
+
+        # if the model is initialized with a non-empty prefix,
+        # then usually it's enough to use that prefix,
+        # e.g. launguage_model, vision_model, etc.
+        # when multiple parts are initialized as independent
+        # models, we need to use the model_tag to distinguish
+        # them, e.g. backbone (default), eagle_head, etc.
+        self.prefix = prefix or model_tag
+
         global global_graph_pool
         if global_graph_pool is None:
             global_graph_pool = current_platform.graph_pool_handle()
@@ -441,16 +490,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         )
         self.compilation_config.cache_dir = cache_dir
 
-        if compilation_counter.num_graphs_seen > 0:
-            cache_dir = self.compilation_config.cache_dir + \
-                f'-{compilation_counter.num_graphs_seen}'
-        else:
-            cache_dir = self.compilation_config.cache_dir
+        cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
         self.compilation_config.cache_dir = cache_dir
         rank = vllm_config.parallel_config.rank
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
+        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}",
+                                       self.prefix)
         os.makedirs(local_cache_dir, exist_ok=True)
         self.compilation_config.local_cache_dir = local_cache_dir
 
@@ -462,7 +508,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         logger.info("Using cache directory: %s for vLLM's torch.compile",
                     local_cache_dir)
 
-        self.compiler_manager.initialize_cache(local_cache_dir, disable_cache)
+        self.compiler_manager.initialize_cache(local_cache_dir, disable_cache,
+                                               self.prefix)
 
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
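The set_model_tag context manager above is the hook for model parts that are built as independent models and therefore carry no distinguishing prefix of their own. A minimal usage sketch, with hypothetical builder callables standing in for vLLM's real model-loading code:

from vllm.compilation.backends import set_model_tag

def load_model_parts(build_backbone, build_draft_head):
    # build_backbone / build_draft_head are hypothetical callables standing in
    # for the real model construction (e.g. get_model in vLLM).
    backbone = build_backbone()        # compiled under the default "backbone" tag

    with set_model_tag("eagle_head"):  # artifacts land in .../rank_i_j/eagle_head/
        draft_head = build_draft_head()

    return backbone, draft_head

Note that set_model_tag asserts the new tag differs from the current one, so re-entering the already active tag (e.g. "backbone" at the top level) raises.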

vllm/compilation/compiler_interface.py (29 additions, 9 deletions)

@@ -28,11 +28,22 @@ class CompilerInterface:
     # This is a class-level attribute.
     name: str
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         """
         when the vLLM process uses `cache_dir` as the cache directory,
         the compiler should initialize itself with the cache directory,
         e.g. by re-directing its own cache directory to a sub-directory.
+
+        prefix can be used in combination with cache_dir to figure out the base
+        cache directory, e.g. there're multiple parts of model being compiled,
+        but we want to share the same cache directory for all of them.
+
+        e.g.
+        cache_dir = "/path/to/dir/backbone", prefix = "backbone"
+        cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
         """
         pass
 
@@ -166,7 +177,10 @@ def compute_hash(self, vllm_config: VllmConfig) -> str:
                             usedforsecurity=False).hexdigest()[:10]
         return hash_str
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         self.cache_dir = cache_dir
 
     def compile(
@@ -242,18 +256,23 @@ def compute_hash(self, vllm_config: VllmConfig) -> str:
                             usedforsecurity=False).hexdigest()[:10]
         return hash_str
 
-    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+    def initialize_cache(self,
+                         cache_dir: str,
+                         disable_cache: bool = False,
+                         prefix: str = ""):
         self.cache_dir = cache_dir
+        self.prefix = prefix
+        self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir
         if disable_cache:
             return
         # redirect the cache directory to a sub-directory
         # set flags so that Inductor and Triton store their cache
         # in the cache_dir, then users only need to copy the cache_dir
         # to another machine to reuse the cache.
-        inductor_cache = os.path.join(cache_dir, "inductor_cache")
+        inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
         os.makedirs(inductor_cache, exist_ok=True)
         os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
-        triton_cache = os.path.join(cache_dir, "triton_cache")
+        triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
         os.makedirs(triton_cache, exist_ok=True)
         os.environ["TRITON_CACHE_DIR"] = triton_cache
 
@@ -298,14 +317,14 @@ def hijack_load(*args, **kwargs):
                 nonlocal file_path
                 compiled_fn = inductor_compiled_graph.current_callable
                 file_path = compiled_fn.__code__.co_filename  # noqa
-                if not file_path.startswith(self.cache_dir):
+                if not file_path.startswith(self.base_cache_dir):
                     # hooked in the align_inputs_from_check_idxs function
                     # in torch/_inductor/utils.py
                     for cell in compiled_fn.__closure__:
                         if not callable(cell.cell_contents):
                             continue
                         if cell.cell_contents.__code__.co_filename.startswith(
-                                self.cache_dir):
+                                self.base_cache_dir):
                             # this is the real file path compiled from Inductor
                             file_path = cell.cell_contents.__code__.co_filename
                             break
@@ -325,14 +344,15 @@ def hijacked_compile_fx_inner(*args, **kwargs):
                 nonlocal file_path
                 compiled_fn = inductor_compiled_graph.current_callable
                 file_path = compiled_fn.__code__.co_filename  # noqa
-                if not file_path.startswith(self.cache_dir):
+                if not file_path.startswith(self.base_cache_dir):
                     # hooked in the align_inputs_from_check_idxs function
                     # in torch/_inductor/utils.py
                     for cell in compiled_fn.__closure__:
                         if not callable(cell.cell_contents):
                             continue
                         code = cell.cell_contents.__code__
-                        if code.co_filename.startswith(self.cache_dir):
+                        if code.co_filename.startswith(
+                                self.base_cache_dir):
                             # this is the real file path
                             # compiled from Inductor
                             file_path = code.co_filename
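In the Inductor adapter's initialize_cache (the third hunk above), the shared base directory is recovered by stripping the trailing prefix from cache_dir, which is what lets every compiled part point TORCHINDUCTOR_CACHE_DIR and TRITON_CACHE_DIR at the same location. A standalone sketch of that string arithmetic; the helper name and example paths are illustrative, not part of vLLM:

import os

def split_base_cache_dir(cache_dir: str, prefix: str) -> str:
    # Mirrors the logic in the diff: the per-prefix cache_dir is assumed to
    # end with the prefix, e.g. ".../rank_0_0/backbone" with prefix "backbone".
    return cache_dir[:-len(prefix)] if prefix else cache_dir

backbone_dir = "/path/to/hash_str/rank_0_0/backbone"
eagle_dir = "/path/to/hash_str/rank_0_0/eagle_head"

base_from_backbone = split_base_cache_dir(backbone_dir, "backbone")
base_from_eagle = split_base_cache_dir(eagle_dir, "eagle_head")
assert base_from_backbone == base_from_eagle == "/path/to/hash_str/rank_0_0/"

# Both parts therefore redirect Inductor and Triton to the same shared caches:
inductor_cache = os.path.join(base_from_backbone, "inductor_cache")
triton_cache = os.path.join(base_from_backbone, "triton_cache")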

vllm/config.py (17 additions, 2 deletions)

@@ -4666,23 +4666,28 @@ def __str__(self):
 
 
 _current_vllm_config: Optional[VllmConfig] = None
+_current_prefix: Optional[str] = None
 
 
 @contextmanager
-def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
+def set_current_vllm_config(vllm_config: VllmConfig,
+                            check_compile=False,
+                            prefix: Optional[str] = None):
     """
     Temporarily set the current vLLM config.
     Used during model initialization.
     We save the current vLLM config in a global variable,
     so that all modules can access it, e.g. custom ops
     can access the vLLM config to determine how to dispatch.
     """
-    global _current_vllm_config
+    global _current_vllm_config, _current_prefix
     old_vllm_config = _current_vllm_config
+    old_prefix = _current_prefix
     from vllm.compilation.counter import compilation_counter
     num_models_seen = compilation_counter.num_models_seen
     try:
         _current_vllm_config = vllm_config
+        _current_prefix = prefix
         yield
     except Exception:
         raise
@@ -4706,6 +4711,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False):
                 vllm_config.model_config.model)
     finally:
         _current_vllm_config = old_vllm_config
+        _current_prefix = old_prefix
 
 
 def get_current_vllm_config() -> VllmConfig:
@@ -4719,6 +4725,15 @@ def get_current_vllm_config() -> VllmConfig:
     return _current_vllm_config
 
 
+def get_current_model_prefix() -> str:
+    """
+    Get the prefix of the model that's currently being initialized.
+    """
+    assert _current_prefix is not None, \
+        "Current model prefix is not set. "
+    return _current_prefix
+
+
 def contains_object_print(text):
     """
     Check if the text looks like a printed Python object, e.g.
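With _current_prefix stored alongside _current_vllm_config, any code that runs inside the set_current_vllm_config context during model construction can ask which part of the model is being initialized. A minimal sketch, assuming an already constructed VllmConfig instance:

from vllm.config import (get_current_model_prefix, get_current_vllm_config,
                         set_current_vllm_config)

def initialize_part(vllm_config):
    # vllm_config is assumed to be an existing VllmConfig instance.
    with set_current_vllm_config(vllm_config, prefix="language_model"):
        cfg = get_current_vllm_config()    # same object as vllm_config
        part = get_current_model_prefix()  # "language_model"
        assert cfg is vllm_config and part == "language_model"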

vllm/model_executor/model_loader/utils.py (6 additions, 2 deletions)

@@ -58,7 +58,9 @@ def initialize_model(
     all_params = [param.name for param in signatures.parameters.values()]
     if "vllm_config" in all_params and "prefix" in all_params:
         # new-style model class
-        with set_current_vllm_config(vllm_config, check_compile=True):
+        with set_current_vllm_config(vllm_config,
+                                     check_compile=True,
+                                     prefix=prefix):
             return model_class(vllm_config=vllm_config, prefix=prefix)
 
     msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
@@ -86,7 +88,9 @@ def initialize_model(
         kwargs["lora_config"] = vllm_config.lora_config
     if "scheduler_config" in all_params:
         kwargs["scheduler_config"] = vllm_config.scheduler_config
-    with set_current_vllm_config(vllm_config, check_compile=True):
+    with set_current_vllm_config(vllm_config,
+                                 check_compile=True,
+                                 prefix=prefix):
         return model_class(**kwargs)
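Since initialize_model now forwards prefix into set_current_vllm_config, a model class constructed through it runs its __init__ inside a context where that prefix is visible via get_current_model_prefix(). A sketch with a hypothetical new-style model class (the class itself is illustrative, not part of vLLM):

import torch.nn as nn

from vllm.config import VllmConfig, get_current_model_prefix

class ToySubModel(nn.Module):
    """Hypothetical new-style model class (accepts vllm_config and prefix)."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        # When constructed via initialize_model, __init__ executes inside
        # set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix),
        # so the globally visible prefix matches the constructor argument.
        assert get_current_model_prefix() == prefix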

vllm/v1/spec_decode/eagle.py (4 additions, 2 deletions)

@@ -320,8 +320,10 @@ def load_model(self, target_model: nn.Module) -> None:
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config, Attention).keys())
 
-        self.model = get_model(vllm_config=self.vllm_config,
-                               model_config=draft_model_config)
+        from vllm.compilation.backends import set_model_tag
+        with set_model_tag("eagle_head"):
+            self.model = get_model(vllm_config=self.vllm_config,
+                                   model_config=draft_model_config)
 
         draft_attn_layer_names = (
             get_layers_from_vllm_config(self.vllm_config, Attention).keys() -

vllm/v1/spec_decode/medusa.py (5 additions, 3 deletions)

@@ -48,9 +48,11 @@ def propose(
         return [list(row) for row in zip(*draft_tokens)]
 
     def load_model(self, target_model: nn.Module) -> None:
-        self.model = get_model(vllm_config=self.vllm_config,
-                               model_config=self.vllm_config.
-                               speculative_config.draft_model_config)
+        from vllm.compilation.backends import set_model_tag
+        with set_model_tag("medusa_head"):
+            self.model = get_model(vllm_config=self.vllm_config,
+                                   model_config=self.vllm_config.
+                                   speculative_config.draft_model_config)
 
     @torch.inference_mode()
     def dummy_run(self, num_tokens: int) -> None:
