Commit b2eb2b5

[Kernel] Apply torch.Tag.needs_fixed_stride_order only for torch==2.6.0 (#19346)

Signed-off-by: rzou <zou3519@gmail.com>

Parent: 21274ab

3 files changed, 19 insertions(+), 9 deletions(-)

csrc/torch_bindings.cpp (8 additions, 4 deletions)
@@ -20,13 +20,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
   //
 
-  // The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
+  // The default behavior in PyTorch 2.6 was changed to "requires_contiguous",
+  // so we need
   // to override this for many GEMMs with the following tag. Otherwise,
   // torch.compile will force all input tensors to be contiguous(), which
   // will break many custom ops that require column-major weight matrices.
-  // TODO: remove this for PyTorch 2.8, when the default is planned to switch
-  // to match exact eager-mode strides.
-  at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
+  // This was a bug and PyTorch 2.7 has since fixed this.
+#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6
+  #define stride_tag at::Tag::needs_fixed_stride_order
+#else
+  #define stride_tag
+#endif
 
   ops.def("weak_ref_tensor(Tensor input) -> Tensor");
   ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
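Note that switching from the old at::Tag variable to a #define means a braced {stride_tag} argument at the ops.def call sites that use it expands to an empty tag list on every torch version except the 2.6 line. A minimal Python-side sketch of the same version gate (not part of this commit; mylib::my_gemm is a hypothetical op):

    import torch

    # Attach the stride-order tag only on the torch 2.6 line, where
    # torch.compile's default would otherwise force contiguous inputs.
    TAGS = ((torch.Tag.needs_fixed_stride_order, )
            if torch.__version__.startswith("2.6") else ())

    # torch.library.define accepts a tags argument when declaring a schema.
    torch.library.define("mylib::my_gemm", "(Tensor a, Tensor b) -> Tensor",
                         tags=TAGS)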

vllm/attention/ops/rocm_aiter_mla.py (6 additions, 2 deletions)
@@ -6,7 +6,7 @@
 import torch
 
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
 
 
 def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
@@ -93,8 +93,12 @@ def mla_decode_fwd_fake(
 
 
 if current_platform.is_rocm():
+    if is_torch_equal_or_newer("2.7.0"):
+        tags = ()
+    else:
+        tags = (torch.Tag.needs_fixed_stride_order, )
     direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
                               op_func=mla_decode_fwd_impl,
                               mutates_args=["o"],
                               fake_impl=mla_decode_fwd_fake,
-                              tags=[torch.Tag.needs_fixed_stride_order])
+                              tags=tags)
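For reference, a rough sketch of what a predicate like is_torch_equal_or_newer can look like; vLLM's actual implementation in vllm.utils may differ:

    import torch
    from packaging import version

    def is_torch_equal_or_newer(target: str) -> bool:
        # Drop local build suffixes such as "+cu121" so that e.g.
        # "2.7.0+cu121" still compares as 2.7.0 >= target.
        installed = version.parse(torch.__version__).base_version
        return version.parse(installed) >= version.parse(target)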

vllm/model_executor/layers/fused_moe/fused_moe.py (5 additions, 3 deletions)
@@ -33,7 +33,7 @@
     dequant_mxfp4)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
 from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
 
 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
@@ -1056,7 +1056,8 @@ def inplace_fused_experts_fake(
     op_func=inplace_fused_experts,
     mutates_args=["hidden_states"],
     fake_impl=inplace_fused_experts_fake,
-    tags=(torch.Tag.needs_fixed_stride_order, ),
+    tags=(() if is_torch_equal_or_newer("2.7.0") else
+          (torch.Tag.needs_fixed_stride_order, )),
 )
 
 
@@ -1122,7 +1123,8 @@ def outplace_fused_experts_fake(
     op_func=outplace_fused_experts,
     mutates_args=[],
     fake_impl=outplace_fused_experts_fake,
-    tags=(torch.Tag.needs_fixed_stride_order, ),
+    tags=(() if is_torch_equal_or_newer("2.7.0") else
+          (torch.Tag.needs_fixed_stride_order, )),
 )
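The same version gate now appears at three Python registration sites. A hypothetical helper (not in vLLM) that would centralize it:

    import torch
    from vllm.utils import is_torch_equal_or_newer

    def fixed_stride_tags() -> tuple:
        # The tag is only needed on torch < 2.7 (i.e. the 2.6 line), where
        # torch.compile defaulted to requiring contiguous inputs.
        if is_torch_equal_or_newer("2.7.0"):
            return ()
        return (torch.Tag.needs_fixed_stride_order, )

Each direct_register_custom_op call could then pass tags=fixed_stride_tags().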

11281130
