@@ -6,23 +6,21 @@
 import pprint
 import time
 from collections.abc import Sequence
-from contextlib import ExitStack
 from typing import Any, Callable, Optional
-from unittest.mock import patch

 import torch
 import torch.fx as fx

 import vllm.envs as envs
 from vllm.config import CompilationConfig, VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import weak_ref_tensors
+from vllm.platforms import current_platform
+from vllm.utils import resolve_obj_by_qualname

 from .compiler_interface import (CompilerInterface, EagerAdaptor,
                                  InductorAdaptor, InductorStandaloneAdaptor)
 from .counter import compilation_counter
 from .inductor_pass import InductorPass
-from .monitor import end_monitoring_torch_compile
 from .pass_manager import PostGradPassManager

 logger = init_logger(__name__)
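Note: `resolve_obj_by_qualname` lets the platform layer name the backend class as a dotted string and defer the import until it is needed. A minimal sketch of that resolution pattern, assuming a simple importlib-based helper (illustrative, not vLLM's exact implementation):

import importlib
from typing import Any


def resolve_obj_by_qualname_sketch(qualname: str) -> Any:
    # Split "package.module.ClassName" into the module path and the
    # attribute name, import the module, and fetch the attribute.
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


# Example: resolve_obj_by_qualname_sketch("collections.OrderedDict")
# returns the OrderedDict class without the caller importing it directly.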
|
@@ -297,7 +295,9 @@ def call_module(self, target: torch.fx.node.Target,
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None)

-            self.module.__dict__[target] = PiecewiseBackend(
+            piecewise_backend = resolve_obj_by_qualname(
+                current_platform.get_piecewise_backend_cls())
+            self.module.__dict__[target] = piecewise_backend(
                 submod, self.vllm_config, self.graph_pool, index,
                 len(self.compile_submod_names), sym_shape_indices,
                 compiled_graph_for_general_shape, self.vllm_backend)
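With this change, `call_module` no longer hard-codes `PiecewiseBackend`; each platform reports which piecewise backend class to use via its qualified name, and the class is resolved at runtime. A hedged sketch of such a platform hook is below; the class name and returned qualname are assumptions for illustration, not vLLM's actual definitions:

class ExamplePlatform:
    # Hypothetical platform hook: report the piecewise backend class to
    # instantiate on this platform as a fully qualified dotted name.
    @classmethod
    def get_piecewise_backend_cls(cls) -> str:
        # Assumed qualname, for illustration only.
        return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"


# Usage mirroring the diff (constructor arguments come from the call site
# above):
#   backend_cls = resolve_obj_by_qualname(
#       ExamplePlatform.get_piecewise_backend_cls())
#   backend = backend_cls(submod, vllm_config, graph_pool, index, ...)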
|
@@ -341,7 +341,7 @@ def __init__(
     ):
         global global_graph_pool
         if global_graph_pool is None:
-            global_graph_pool = torch.cuda.graph_pool_handle()
+            global_graph_pool = current_platform.graph_pool_handle()

         # TODO: in the future, if we want to use multiple
         # streams, it might not be safe to share a global pool.
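Routing the pool handle through `current_platform` keeps the shared graph pool usable on devices that do not go through the `torch.cuda` namespace directly. A rough sketch of what the hook presumably does on a CUDA-like platform (the class below is an assumption; only `torch.cuda.graph_pool_handle` is a real PyTorch API):

import torch


class CudaLikePlatformSketch:
    # Hypothetical platform method: on CUDA (and ROCm, which reuses the
    # torch.cuda namespace) the hook can delegate straight to PyTorch's
    # graph-pool-handle factory; other platforms would return their own
    # pool handle type.
    @classmethod
    def graph_pool_handle(cls):
        return torch.cuda.graph_pool_handle()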
|
@@ -558,197 +558,3 @@ def copy_and_call(*args):
             return self.split_gm(*list_args)

         return copy_and_call
-
-
-@dataclasses.dataclass
-class ConcreteSizeEntry:
-    runtime_shape: int
-    need_to_compile: bool  # the size is in compile_sizes
-    use_cudagraph: bool  # the size is in cudagraph_capture_sizes
-
-    compiled: bool = False
-    runnable: Callable = None  # type: ignore
-    num_finished_warmup: int = 0
-    cudagraph: Optional[torch.cuda.CUDAGraph] = None
-    output: Optional[Any] = None
-
-    # for cudagraph debugging, track the input addresses
-    # during capture, and check if they are the same during replay
-    input_addresses: Optional[list[int]] = None
-
-
-class PiecewiseBackend:
-
-    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
-                 graph_pool: Any, piecewise_compile_index: int,
-                 total_piecewise_compiles: int, sym_shape_indices: list[int],
-                 compiled_graph_for_general_shape: Callable,
-                 vllm_backend: VllmBackend):
-        """
-        The backend for piecewise compilation.
-        It mainly handles the compilation and cudagraph capturing.
-
-        We will compile `self.graph` once for the general shape,
-        and then compile for different shapes specified in
-        `compilation_config.compile_sizes`.
-
-        Independently, we will capture cudagraph for different shapes.
-
-        If a shape needs both compilation and cudagraph, we will
-        compile it first, and then capture cudagraph.
-        """
-        self.graph = graph
-        self.vllm_config = vllm_config
-        self.compilation_config = vllm_config.compilation_config
-        self.graph_pool = graph_pool
-        self.piecewise_compile_index = piecewise_compile_index
-        self.total_piecewise_compiles = total_piecewise_compiles
-        self.vllm_backend = vllm_backend
-
-        self.is_first_graph = piecewise_compile_index == 0
-        self.is_last_graph = (
-            piecewise_compile_index == total_piecewise_compiles - 1)
-
-        self.compile_sizes: set[int] = set(
-            self.compilation_config.compile_sizes)
-        self.cudagraph_capture_sizes: set[int] = set(
-            self.compilation_config.cudagraph_capture_sizes
-        ) if self.compilation_config.use_cudagraph else set()
-
-        self.first_run_finished = False
-
-        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa
-
-        self.sym_shape_indices = sym_shape_indices
-
-        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
-
-        # the entries for different shapes that we need to either
-        # compile or capture cudagraph
-        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
-
-        # to_be_compiled_sizes tracks the remaining sizes to compile,
-        # and updates during the compilation process, so we need to copy it
-        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
-        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
-            self.concrete_size_entries[shape] = ConcreteSizeEntry(
-                runtime_shape=shape,
-                need_to_compile=shape in self.compile_sizes,
-                use_cudagraph=shape in self.cudagraph_capture_sizes,
-            )
-
-    def check_for_ending_compilation(self):
-        if self.is_last_graph and not self.to_be_compiled_sizes:
-            # no specific sizes to compile
-            # save the hash of the inductor graph for the next run
-            self.vllm_backend.compiler_manager.save_to_file()
-            end_monitoring_torch_compile(self.vllm_config)
-
-    def __call__(self, *args) -> Any:
-        if not self.first_run_finished:
-            self.first_run_finished = True
-            self.check_for_ending_compilation()
-            return self.compiled_graph_for_general_shape(*args)
-
-        runtime_shape = args[self.sym_shape_indices[0]]
-        if runtime_shape not in self.concrete_size_entries:
-            # we don't need to do anything for this shape
-            return self.compiled_graph_for_general_shape(*args)
-
-        entry = self.concrete_size_entries[runtime_shape]
-
-        if entry.runnable is None:
-            entry.runnable = self.compiled_graph_for_general_shape
-
-        if entry.need_to_compile and not entry.compiled:
-            entry.compiled = True
-            self.to_be_compiled_sizes.remove(runtime_shape)
-            # args are real arguments
-            entry.runnable = self.vllm_backend.compiler_manager.compile(
-                self.graph,
-                args,
-                self.compilation_config.inductor_compile_config,
-                self.compilation_config,
-                graph_index=self.piecewise_compile_index,
-                num_graphs=self.total_piecewise_compiles,
-                runtime_shape=runtime_shape)
-
-            # finished compilations for all required shapes
-            if self.is_last_graph and not self.to_be_compiled_sizes:
-                self.check_for_ending_compilation()
-
-        if not entry.use_cudagraph:
-            return entry.runnable(*args)
-
-        if entry.cudagraph is None:
-            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
-                entry.num_finished_warmup += 1
-                if self.is_first_graph:
-                    logger.debug(
-                        "Warming up %s/%s for shape %s",
-                        entry.num_finished_warmup,
-                        self.compilation_config.cudagraph_num_of_warmups,
-                        runtime_shape)
-                return entry.runnable(*args)
-
-            if self.is_first_graph:
-                # Since we capture cudagraph for many different shapes and
-                # capturing is fast, we don't need to log it for every shape.
-                # We only log it in the debug mode.
-                logger.debug("Capturing a cudagraph for shape %s",
-                             runtime_shape)
-
-            input_addresses = [
-                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
-            ]
-            entry.input_addresses = input_addresses
-            cudagraph = torch.cuda.CUDAGraph()
-
-            with ExitStack() as stack:
-                if not self.is_first_graph:
-                    # during every model forward, we will capture
-                    # many pieces of cudagraphs (roughly one per layer).
-                    # running gc again and again across layers will
-                    # make the cudagraph capture very slow.
-                    # therefore, we only run gc for the first graph,
-                    # and disable gc for the rest of the graphs.
-                    stack.enter_context(patch("gc.collect", lambda: None))
-                    stack.enter_context(
-                        patch("torch.cuda.empty_cache", lambda: None))
-
-                # mind-exploding: carefully manage the reference and memory.
-                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
-                    # `output` is managed by pytorch's cudagraph pool
-                    output = entry.runnable(*args)
-                    if self.is_last_graph:
-                        # by converting it to weak ref,
-                        # the original `output` will immediately be released
-                        # to save memory. It is only safe to do this for
-                        # the last graph, because the output of the last graph
-                        # will not be used by any other cuda graph.
-                        output = weak_ref_tensors(output)
-
-            # here we always use weak ref for the output
-            # to save memory
-            entry.output = weak_ref_tensors(output)
-            entry.cudagraph = cudagraph
-
-            compilation_counter.num_cudagraph_caputured += 1
-
-            # important: we need to return the output, rather than
-            # the weak ref of the output, so that pytorch can correctly
-            # manage the memory during cuda graph capture
-            return output
-
-        if self.is_debugging_mode:
-            # check if the input addresses are the same
-            new_input_addresses = [
-                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
-            ]
-            assert new_input_addresses == entry.input_addresses, (
-                "Input addresses for cudagraphs are different during replay."
-                f" Expected {entry.input_addresses}, got {new_input_addresses}"
-            )
-
-        entry.cudagraph.replay()
-        return entry.output
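For reference, the `PiecewiseBackend.__call__` removed above follows the usual CUDA graph capture-then-replay pattern: for each tracked shape, the compiled callable is captured once into a graph that shares the global memory pool, and later calls for that shape replay the captured graph against the same input buffers. A minimal standalone sketch of that pattern, with illustrative names and the per-shape bookkeeping omitted:

import torch


def capture_then_replay_sketch(runnable, args, pool):
    # Warm-up run outside the graph so lazy initialization (allocator
    # warm-up, autotuning, etc.) is not recorded into the capture.
    runnable(*args)
    torch.cuda.synchronize()

    # Capture: kernels launched inside this context are recorded into the
    # CUDA graph instead of being executed eagerly; `pool` lets multiple
    # captures share one memory pool, as the removed code does with
    # self.graph_pool.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=pool):
        output = runnable(*args)

    # Replay: re-issues the recorded kernels. Inputs must occupy the same
    # device addresses as during capture (hence the input-address assertion
    # in the removed debugging check), and `output` is rewritten in place
    # on every replay.
    graph.replay()
    return output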