
Commit 7e8977f

[custom_op][vllm-plugin] update custom_op class to use op_registry (#19164)
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
1 parent f1e840e commit 7e8977f

7 files changed (+120, −6 lines)

tests/plugins/vllm_add_dummy_platform/setup.py

Lines changed: 3 additions & 1 deletion
@@ -10,5 +10,7 @@
     entry_points={
         'vllm.platform_plugins': [
             "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin"  # noqa
-        ]
+        ],
+        "vllm.general_plugins":
+        ["dummy_custom_ops = vllm_add_dummy_platform:register_ops"],
     })
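
Note: an out-of-tree plugin package wires the same two entry-point groups in its own setup.py. A minimal sketch, assuming a hypothetical package named my_vllm_plugin; the group names 'vllm.platform_plugins' and 'vllm.general_plugins' are the ones used above, everything else here is illustrative:

from setuptools import setup

setup(
    name="my_vllm_plugin",            # hypothetical package name
    packages=["my_vllm_plugin"],
    entry_points={
        # Resolved when vLLM selects the active platform.
        "vllm.platform_plugins": [
            "my_platform = my_vllm_plugin:register_platform",
        ],
        # Resolved by load_general_plugins(); the callable should import
        # the modules that apply @CustomOp.register_oot decorators.
        "vllm.general_plugins": [
            "my_custom_ops = my_vllm_plugin:register_ops",
        ],
    },
)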

tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -6,3 +6,7 @@
 
 def dummy_platform_plugin() -> Optional[str]:
     return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
+
+
+def register_ops():
+    import vllm_add_dummy_platform.dummy_custom_ops  # noqa

tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py

Lines changed: 3 additions & 2 deletions
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from vllm.attention.backends.flash_attn import FlashAttentionBackend
+from vllm.attention.backends.placeholder_attn import (
+    PlaceholderAttentionBackend)
 
 
-class DummyAttentionBackend(FlashAttentionBackend):
+class DummyAttentionBackend(PlaceholderAttentionBackend):
 
     @staticmethod
     def get_name() -> str:

tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+
+
+# Register DummyRotaryEmbedding as the out-of-tree CustomOp for RotaryEmbedding.
+@RotaryEmbedding.register_oot
+class DummyRotaryEmbedding(RotaryEmbedding):
+    """Rotary positional embedding with a marker attribute for the tests."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.addition_config = True
+
+    def forward_oot(self, *args,
+                    **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
+        return super().forward_oot(*args, **kwargs)
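
The decorator above is applied directly to the in-tree class being overridden, so the registration key defaults to that class's name. As the register_oot implementation in vllm/model_executor/custom_op.py (further below) shows, the key can also be passed explicitly. A minimal sketch of that variant; MyPlatformRotaryEmbedding is a hypothetical platform-specific class, not part of this commit:

from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding


# Same effect as @RotaryEmbedding.register_oot, but the target op name
# is spelled out explicitly.
@CustomOp.register_oot(name="RotaryEmbedding")
class MyPlatformRotaryEmbedding(RotaryEmbedding):

    def forward_oot(self, *args, **kwargs):
        # A real plugin would provide its platform-specific kernel here.
        return super().forward_oot(*args, **kwargs)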

tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py

Lines changed: 20 additions & 3 deletions
@@ -1,12 +1,29 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import TYPE_CHECKING
 
-from vllm.platforms.cuda import CudaPlatform
+from vllm.platforms.interface import Platform, PlatformEnum
 
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+else:
+    VllmConfig = None
+from vllm import envs
 
-class DummyPlatform(CudaPlatform):
+
+class DummyPlatform(Platform):
+    _enum = PlatformEnum.OOT
     device_name = "DummyDevice"
+    device_type: str = "privateuseone"
+    dispatch_key: str = "PrivateUse1"
+
+    @classmethod
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        if envs.VLLM_USE_V1:
+            compilation_config = vllm_config.compilation_config
+            # Activate custom ops for v1.
+            compilation_config.custom_ops = ["all"]
 
     def get_attn_backend_cls(self, backend_name, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla):
-        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
+        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501

tests/plugins_tests/test_platform_plugins.py

Lines changed: 14 additions & 0 deletions
@@ -5,6 +5,7 @@
 import torch
 
 from vllm.attention.selector import get_attn_backend
+from vllm.plugins import load_general_plugins
 from vllm.utils import STR_BACKEND_ENV_VAR, STR_INVALID_VAL
 
 
@@ -32,3 +33,16 @@ def test_oot_attention_backend(monkeypatch: pytest.MonkeyPatch):
         m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
         backend = get_attn_backend(16, torch.float16, "auto", 16, False)
         assert backend.get_name() == "Dummy_Backend"
+
+
+def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch):
+    # simulate workload by running an example
+    load_general_plugins()
+    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+    layer = RotaryEmbedding(16, 16, 16, 16, True, torch.float16)
+    assert layer.__class__.__name__ == "DummyRotaryEmbedding", (
+        f"Expected DummyRotaryEmbedding, got {layer.__class__.__name__}, "
+        "possibly because the custom op is not registered correctly.")
+    assert hasattr(layer, "addition_config"), (
+        "Expected DummyRotaryEmbedding to have an 'addition_config' attribute, "
+        "which is set by the custom op.")

vllm/model_executor/custom_op.py

Lines changed: 56 additions & 0 deletions
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Optional
+
 import torch.nn as nn
 
 from vllm.config import get_current_vllm_config
@@ -16,6 +18,24 @@ class CustomOp(nn.Module):
     Dispatches the forward method to the appropriate backend.
     """
 
+    def __new__(cls, *args, **kwargs):
+        try:
+            op_name = cls.__name__
+        except AttributeError:
+            raise TypeError(
+                f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
+                f"was not set, possibly because it was not decorated with "
+                f"@CustomOp.register, or it's the CustomOp base class itself."
+            ) from None
+
+        if op_name not in cls.op_registry_oot:
+            op_cls_to_instantiate = cls
+        else:
+            op_cls_to_instantiate = cls.op_registry_oot[op_name]
+            logger.debug("Instantiating custom op: %s using %s", op_name,
+                         str(op_cls_to_instantiate))
+        return super().__new__(op_cls_to_instantiate)
+
     def __init__(self):
         super().__init__()
         self._forward_method = self.dispatch_forward()
@@ -138,6 +158,7 @@ def default_on() -> bool:
     # - MyOp.enabled()
     # - op_registry["my_op"].enabled()
     op_registry: dict[str, type['CustomOp']] = {}
+    op_registry_oot: dict[str, type['CustomOp']] = {}
 
     # Decorator to register custom ops.
     @classmethod
@@ -150,3 +171,38 @@ def decorator(op_cls):
             return op_cls
 
         return decorator
+
+    # Decorator to register out-of-tree (OOT) custom ops.
+    # For OOT custom ops:
+    #   if an in-tree layer class has an OOT custom-op layer registered,
+    #   the OOT custom-op layer will be used instead.
+    # Example:
+    #  - @UnquantizedFusedMoEMethod.register_oot
+    #    class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
+    #  or
+    #  - @CustomOp.register_oot(name="UnquantizedFusedMoEMethod")
+    @classmethod
+    def register_oot(cls, _decorated_op_cls=None, name: Optional[str] = None):
+
+        def decorator(op_cls):
+            reg_name = name if name is not None else cls.__name__
+            assert reg_name not in cls.op_registry_oot, \
+                f"Duplicate op name: {reg_name}"
+            op_cls.name = reg_name
+            cls.op_registry_oot[reg_name] = op_cls
+            return op_cls
+
+        if _decorated_op_cls is None:
+            # Called with parentheses: @CustomOp.register_oot()
+            # or @CustomOp.register_oot(name="..."),
+            # so _decorated_op_cls is None.
+            # Return the actual decorator function.
+            return decorator
+        elif isinstance(_decorated_op_cls, type):  # Check if it's a class.
+            # Called without parentheses: @CustomOp.register_oot.
+            # The first argument is the class itself,
+            # so apply the 'decorator' function to it immediately.
+            return decorator(_decorated_op_cls)
+        else:
+            # Handle other unexpected cases.
+            raise TypeError("Decorator can only be applied to classes.")
