Commit aafabaa

[Fix][torch.compile] Enable custom ops by default when Inductor off (#20102)
Signed-off-by: luka <luka@neuralmagic.com>
1 parent 94a55c7 commit aafabaa

3 files changed: +41 −43 lines changed

tests/model_executor/test_enabled_custom_ops.py

Lines changed: 26 additions & 19 deletions
@@ -28,42 +28,49 @@ class Relu3(ReLUSquaredActivation):
 
 
 @pytest.mark.parametrize(
-    "env, torch_level, ops_enabled, default_on",
+    "env, torch_level, use_inductor, ops_enabled, default_on",
     [
         # Default values based on compile level
-        ("", 0, [True] * 4, True),
-        ("", 1, [True] * 4, True),
-        ("", 2, [True] * 4, True),  # All by default
-        ("", 3, [False] * 4, False),
-        ("", 4, [False] * 4, False),  # None by default
+        # - All by default (no Inductor compilation)
+        ("", 0, False, [True] * 4, True),
+        ("", 1, True, [True] * 4, True),
+        ("", 2, False, [True] * 4, True),
+        # - None by default (with Inductor)
+        ("", 3, True, [False] * 4, False),
+        ("", 4, True, [False] * 4, False),
+        # - All by default (without Inductor)
+        ("", 3, False, [True] * 4, True),
+        ("", 4, False, [True] * 4, True),
         # Explicitly enabling/disabling
         #
         # Default: all
         #
         # All but SiluAndMul
-        ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
+        ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
         # Only ReLU3
-        ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
+        ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
         # All but SiluAndMul
-        ("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
+        ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
         # All but ReLU3 (even if ReLU2 is on)
-        ("-relu3,relu2", 1, [1, 1, 1, 0], True),
-        # GeluAndMul and SiluAndMul
-        ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
+        ("-relu3,relu2", 3, False, [1, 1, 1, 0], True),
+        # RMSNorm and SiluAndMul
+        ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
         # All but RMSNorm
-        ("-rms_norm", 2, [0, 1, 1, 1], True),
+        ("-rms_norm", 3, False, [0, 1, 1, 1], True),
         #
         # Default: none
        #
         # Only ReLU3
-        ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
+        ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
         # All but RMSNorm
-        ("all,-rms_norm", 4, [0, 1, 1, 1], True),
+        ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
     ])
-def test_enabled_ops(env: str, torch_level: int, ops_enabled: list[int],
-                     default_on: bool):
-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=torch_level, custom_ops=env.split(",")))
+def test_enabled_ops(env: str, torch_level: int, use_inductor: bool,
+                     ops_enabled: list[int], default_on: bool):
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
+                                             level=torch_level,
+                                             custom_ops=env.split(",")))
     with set_current_vllm_config(vllm_config):
         assert CustomOp.default_on() == default_on
 

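The scenario added by the new use_inductor parameter can also be reproduced outside pytest with the same objects the test uses. A minimal sketch, assuming the import paths match those of the test module (VllmConfig, CompilationConfig and set_current_vllm_config from vllm.config, CustomOp from vllm.model_executor.custom_op) and that level 3 corresponds to CompilationLevel.PIECEWISE:

    # Sketch only, not part of the commit: mirrors the parametrized case
    # ("", 3, False, [True] * 4, True) from the table above.
    from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
    from vllm.model_executor.custom_op import CustomOp

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=3,                   # assumed to equal CompilationLevel.PIECEWISE
        use_inductor=False,        # Inductor off, so custom ops default to enabled
        custom_ops="".split(","),  # empty env string, i.e. no explicit overrides
    ))

    with set_current_vllm_config(vllm_config):
        # The behavior this commit fixes: even at level 3, custom ops stay on
        # when Inductor is disabled.
        assert CustomOp.default_on()
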
vllm/config.py

Lines changed: 9 additions & 18 deletions
@@ -3994,7 +3994,8 @@ class CompilationConfig:
     - 'none,+op1,+op2' to enable only op1 and op2
 
     By default, all custom ops are enabled when running without Inductor and
-    disabled when running with Inductor (compile_level >= Inductor)."""
+    disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
+    Inductor generates (fused) Triton kernels for disabled custom ops."""
     splitting_ops: list[str] = field(default_factory=list)
     """A list of ops to split the full graph into subgraphs, used in piecewise
     compilation."""
@@ -4003,10 +4004,13 @@ class CompilationConfig:
     use_inductor: bool = True
     """Whether to use inductor compilation:
 
-    - False: inductor compilation is not used. graph runs in eager.
-    - True: inductor compilation is used. one graph for symbolic shape
-        is compiled. In addition, compile for compile_sizes,
-        using configurations in inductor_compile_config."""
+    - False: inductor compilation is not used. graph runs in eager
+        (custom_ops enabled by default).
+    - True: inductor compilation is used (custom_ops disabled by default).
+        One graph for symbolic shape and one graph per size in compile_sizes
+        are compiled using configurations in inductor_compile_config.
+
+    This setting is ignored if level<PIECEWISE."""
     compile_sizes: Optional[list[Union[int, str]]] = None
     """Sizes to compile for inductor. In addition
     to integers, it also supports "cudagraph_capture_sizes" to
@@ -4537,19 +4541,6 @@ def __post_init__(self):
             self.compilation_config.level = CompilationLevel.PIECEWISE
             self.compilation_config.set_splitting_ops_for_v1()
 
-        # The behavior of custom ops with inductor depends on the config:
-        # - If use_inductor=True and custom_ops is empty:
-        #   Inductor generates Triton kernels for all registered custom ops
-        #   (default behavior)
-        # - If use_inductor=True and custom_ops is non-empty:
-        #   Custom CUDA kernels are used for specified ops while inductor
-        #   generates Triton kernels for remaining ops, including misc torch
-        #   ops in the model.
-        if (not self.compilation_config.custom_ops
-                and self.compilation_config.use_inductor):
-            # Let inductor generate Triton kernels for the custom ops.
-            self.compilation_config.custom_ops = ["none"]
-
         self._set_cudagraph_sizes()
 
         if self.cache_config.cpu_offload_gb > 0 and \

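Read together with the test table, the updated docstrings imply the defaults below. A hedged sketch, assuming only that CompilationConfig accepts these keyword arguments as in the test above; the comments state the expected behavior rather than anything the snippet prints:

    # Illustration only, not from the commit.
    from vllm.config import CompilationConfig

    # Inductor actually compiling (level >= PIECEWISE and use_inductor=True):
    # custom ops are off by default, so only the opted-in op stays custom and
    # Inductor generates (fused) Triton kernels for the rest.
    with_inductor = CompilationConfig(level=3, use_inductor=True,
                                      custom_ops=["+rms_norm"])

    # Inductor off: the graph runs in eager mode and custom ops are on by
    # default, so only the opted-out op uses its non-custom implementation.
    without_inductor = CompilationConfig(level=3, use_inductor=False,
                                         custom_ops=["-silu_and_mul"])
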
vllm/model_executor/custom_op.py

Lines changed: 6 additions & 6 deletions
@@ -141,16 +141,16 @@ def enabled(cls) -> bool:
     @staticmethod
     def default_on() -> bool:
         """
-        On by default if level < CompilationLevel.PIECEWISE
+        On by default if PyTorch Inductor is not used.
         Specifying 'all' or 'none' in custom_op takes precedence.
         """
         from vllm.config import CompilationLevel
         compilation_config = get_current_vllm_config().compilation_config
-        custom_ops = compilation_config.custom_ops
-        count_none = custom_ops.count("none")
-        count_all = custom_ops.count("all")
-        return compilation_config.level < CompilationLevel.PIECEWISE and \
-            not count_none > 0 or count_all > 0
+        default_on = (compilation_config.level < CompilationLevel.PIECEWISE
+                      or not compilation_config.use_inductor)
+        count_none = compilation_config.custom_ops.count("none")
+        count_all = compilation_config.custom_ops.count("all")
+        return default_on and not count_none > 0 or count_all > 0
 
 # Dictionary of all custom ops (classes, indexed by registered name).
 # To check if an op with a name is enabled, call .enabled() on the class.

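One subtlety in the returned expression: Python's and binds tighter than or, so it evaluates as (default_on and count_none == 0) or count_all > 0, meaning an explicit 'all' always wins while an explicit 'none' only cancels the default. A standalone sketch of that decision with a hypothetical helper name, checked against rows of the test table:

    # Hypothetical helper, for illustration only; PIECEWISE is assumed to be 3.
    def custom_ops_default_on(level: int, use_inductor: bool,
                              custom_ops: list[str],
                              piecewise_level: int = 3) -> bool:
        # Defaults to on unless Inductor will actually compile the graph,
        # which requires both level >= PIECEWISE and use_inductor=True.
        default_on = level < piecewise_level or not use_inductor
        # 'none'/'all' entries override the default; 'all' takes precedence.
        return (default_on and custom_ops.count("none") == 0) \
            or custom_ops.count("all") > 0

    assert custom_ops_default_on(3, False, [""]) is True   # ("", 3, False, ...)
    assert custom_ops_default_on(3, True, [""]) is False   # ("", 3, True, ...)
    assert custom_ops_default_on(4, True, ["all", "-rms_norm"]) is True
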