Skip to content

Commit 857f489

Browse files
authored
[CI] Patch torch.library.infer_schema for torch 2.5 backward compatibility (#837)
Patch torch.library.infer_schema for torch 2.5 backward compatibility - Introduced a new module `patch_utils` under `vllm_ascend/patch/worker/patch_common/`. - Added a function `ascend_direct_register_custom_op` to handle custom operator registration with backward compatibility for PyTorch < 2.7 (such as torch 2.5.1). - Implemented type conversion logic for annotations to ensure compatibility across different PyTorch versions. - Registered the function `ascend_direct_register_custom_op` to `utils.direct_register_custom_op`. - Updated `__init__.py` to include `patch_utils` as the first import. - Ensured `patch_utils` is available for use in other patch files and skipped isort checks for `patch_utils` import. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent e564470 commit 857f489

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17-
17+
# patch_utils should be the first import, because it will be used by other
18+
# patch files.
19+
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
1820
import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa
1921
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
2022
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import Callable, List, Optional, Tuple
2+
3+
import torch
4+
from torch.library import Library
5+
from vllm import utils
6+
from vllm.utils import vllm_lib
7+
8+
9+
def ascend_direct_register_custom_op(
10+
op_name: str,
11+
op_func: Callable,
12+
mutates_args: list[str],
13+
fake_impl: Optional[Callable] = None,
14+
target_lib: Optional[Library] = None,
15+
dispatch_key: str = "CUDA",
16+
tags: Tuple[torch.Tag, ...] = (),
17+
):
18+
# In pytorch 2.5.1, torch.library.infer_schema require the input function to
19+
# have annotations supported by typing library. But in pytorch 2.7.0 which
20+
# vllm using, torch.library.infer_schema require the python builtin type. In
21+
# this case, we should revert built type to typing type for 2.5.1 backward
22+
# compatibility.
23+
for k, v in op_func.__annotations__.items():
24+
if v == list[int]:
25+
op_func.__annotations__[k] = List[int]
26+
if v == Optional[list[int]]:
27+
op_func.__annotations__[k] = Optional[List[int]]
28+
# TODO: add more type convert here if needed.
29+
import torch.library
30+
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
31+
my_lib = target_lib or vllm_lib
32+
my_lib.define(op_name + schema_str, tags=tags)
33+
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
34+
if fake_impl is not None:
35+
my_lib._register_fake(op_name, fake_impl)
36+
37+
38+
# Monkey-patch vllm at import time so every subsequent
# direct_register_custom_op call in this process goes through the
# torch-2.5-compatible wrapper defined in this module.
utils.direct_register_custom_op = ascend_direct_register_custom_op

0 commit comments

Comments
 (0)