+from typing import Callable, List, Optional, Tuple
+
+import torch
+from torch.library import Library
+from vllm import utils
+from vllm.utils import vllm_lib
+
+
+def ascend_direct_register_custom_op(
+    op_name: str,
+    op_func: Callable,
+    mutates_args: list[str],
+    fake_impl: Optional[Callable] = None,
+    target_lib: Optional[Library] = None,
+    dispatch_key: str = "CUDA",
+    tags: Tuple[torch.Tag, ...] = (),
+):
+    # In PyTorch 2.5.1, torch.library.infer_schema requires the input
+    # function's annotations to use types from the typing library. In
+    # PyTorch 2.7.0, which vLLM targets, infer_schema expects the Python
+    # builtin generics instead, so vLLM annotates its op functions with
+    # them. For backward compatibility with 2.5.1, convert those builtin
+    # generics (e.g. list[int]) back to their typing equivalents
+    # (e.g. List[int]).
+    for k, v in op_func.__annotations__.items():
+        if v == list[int]:
+            op_func.__annotations__[k] = List[int]
+        if v == Optional[list[int]]:
+            op_func.__annotations__[k] = Optional[List[int]]
+    # TODO: add more type conversions here if needed.
+    # infer_schema returns only the signature portion of the schema,
+    # e.g. "(Tensor x) -> Tensor", so prepending op_name yields the full
+    # definition, e.g. "my_op(Tensor x) -> Tensor". The top-level
+    # `from torch.library import Library` already loads torch.library.
+    schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
+    my_lib = target_lib or vllm_lib
+    my_lib.define(op_name + schema_str, tags=tags)
+    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
+    if fake_impl is not None:
+        my_lib._register_fake(op_name, fake_impl)
+
+
+# Monkey-patch vLLM's registration helper so that later lookups of
+# vllm.utils.direct_register_custom_op resolve to the Ascend-compatible
+# version defined above.
+utils.direct_register_custom_op = ascend_direct_register_custom_op
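For context, here is a minimal sketch of how the patched hook might be exercised. The `rotate` op, its fake implementation, and the CPU dispatch key are illustrative assumptions (CPU is chosen so the sketch runs without an NPU), not part of this patch; the patch module above must be imported first so that `utils.direct_register_custom_op` is already replaced.

import torch
from vllm import utils  # assumes the Ascend patch has been applied


def rotate(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical eager implementation of the example op.
    return torch.roll(x, shifts=1, dims=-1)


def rotate_fake(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used by fake tensors and torch.compile.
    return torch.empty_like(x)


utils.direct_register_custom_op(
    op_name="rotate",
    op_func=rotate,
    mutates_args=[],
    fake_impl=rotate_fake,
    dispatch_key="CPU",  # assumption: CPU so the sketch runs anywhere
)

# The op lands in the default "vllm" library namespace (vllm_lib):
out = torch.ops.vllm.rotate(torch.arange(4.0))

Because the registration goes through the patched function, annotations such as `list[int]` on an op would be rewritten to `List[int]` before schema inference, which is what keeps this path working on PyTorch 2.5.1.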