Commit 879f69b

[Refactor] Remove duplicate ceil_div (#20023)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 7108934 commit 879f69b
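
The commit replaces several file-local ceil_div helpers with the shared cdiv utility imported from vllm.utils. A minimal sketch of the equivalence is below; the negated-floor-division form is an assumption about how vllm.utils.cdiv is written, not a copy of it:

def ceil_div(a: int, b: int) -> int:
    # the local helper pattern removed throughout this commit
    return (a + b - 1) // b

def cdiv(a: int, b: int) -> int:
    # ceiling division via negated floor division (assumed cdiv behavior)
    return -(a // -b)

# both round up for positive integers, e.g. 7168 / 128 -> 56, 130 / 128 -> 2
assert all(ceil_div(a, b) == cdiv(a, b) for a in range(1, 512) for b in range(1, 64))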

File tree

7 files changed: +20, -42 lines


benchmarks/cutlass_benchmarks/w8a8_benchmarks.py

Lines changed: 3 additions & 8 deletions
@@ -19,7 +19,7 @@
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     w8a8_block_fp8_matmul,
 )
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, cdiv

 DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
 DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
@@ -117,14 +117,9 @@ def bench_fp8(
     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

-    def ceil_div(x: int, y: int) -> int:
-        return (x + y - 1) // y
-
-    block_scale_a = torch.rand(
-        (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32
-    )
+    block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32)
     block_scale_b = torch.rand(
-        ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32
+        cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32
     )
     block_scale_a_M_major = block_scale_a.t().contiguous().t()
     block_scale_b_K_major = block_scale_b.t().contiguous().t()

tests/kernels/attention/test_mla_decode_cpu.py

Lines changed: 1 addition & 4 deletions
@@ -7,10 +7,7 @@

 import vllm._custom_ops as ops
 from vllm.platforms import current_platform
-
-
-def cdiv(a, b):
-    return (a + b - 1) // b
+from vllm.utils import cdiv


 def ref_mla(

tests/kernels/attention/test_triton_decode_attention.py

Lines changed: 1 addition & 4 deletions
@@ -5,10 +5,7 @@
 import torch

 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
-
-
-def cdiv(a, b):
-    return (a + b - 1) // b
+from vllm.utils import cdiv


 @pytest.mark.parametrize("B", [3, 5])

tests/neuron/1_core/test_prefix_prefill.py

Lines changed: 4 additions & 5 deletions
@@ -7,6 +7,8 @@
 import torch
 import torch.nn.functional as F

+from vllm.utils import cdiv
+

 class BlockDiagonalCausalFromBottomRightMask:

@@ -398,11 +400,8 @@ def test_contexted_kv_attention(
     assert (large_tile_size >= B_P_SIZE
             ), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}"

-    def ceil_div(a, b):
-        return (a + b - 1) // b
-
     def pad_to_multiple(a, b):
-        return ceil_div(a, b) * b
+        return cdiv(a, b) * b

     def pad_to_next_power_of_2(a):
         assert a > 0
@@ -411,7 +410,7 @@ def pad_to_next_power_of_2(a):
     # calculate input shapes
     max_num_queries = pad_to_next_power_of_2(sum(query_lens))
     context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
-    num_active_blocks = ceil_div(context_lens, block_size).sum().item()
+    num_active_blocks = cdiv(context_lens, block_size).sum().item()
     num_active_blocks = pad_to_multiple(num_active_blocks,
                                         large_tile_size // block_size)
     context_kv_len = num_active_blocks * block_size

vllm/attention/ops/nki_flash_attn.py

Lines changed: 6 additions & 9 deletions
@@ -8,9 +8,7 @@
 from neuronxcc import nki
 from neuronxcc.nki.language import par_dim

-
-def ceil_div(a, b):
-    return (a + b - 1) // b
+from vllm.utils import cdiv


 def is_power_of_2(x):
@@ -35,11 +33,10 @@ def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
                                     (num_tiles, num_blocks_per_tile))

     block_tables_sbuf = nl.zeros(
-        (ceil_div(num_tiles,
-                  B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
+        (cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
         dtype=nl.int32,
     )
-    for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)):
+    for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)):
         i_p = nl.arange(B_P_SIZE)[:, None]
         i_f = nl.arange(num_blocks_per_tile)[None, :]
         block_tables_sbuf[i, i_p, i_f] = nl.load(
@@ -83,7 +80,7 @@ def transform_block_tables_for_indirect_load(
     assert is_power_of_2(
         num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"

-    num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE)
+    num_loads = cdiv(num_blocks_per_tile, B_P_SIZE)
     block_tables_transposed = nl.ndarray(
         (
             num_loads,
@@ -165,7 +162,7 @@ def load_kv_tile_from_cache(
         equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
     """
     # load key cache
-    num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
+    num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
     for load_idx in nl.affine_range(num_loads):
         i_p = nl.arange(B_P_SIZE)[:, None]
         i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
@@ -605,7 +602,7 @@ def flash_paged_attention(
     )

     for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
-        num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
+        num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
         cur_k_tile = nl.ndarray(
             (par_dim(B_D_SIZE), LARGE_TILE_SZ),
             dtype=kernel_dtype,

vllm/model_executor/layers/fused_moe/moe_align_block_size.py

Lines changed: 2 additions & 6 deletions
@@ -6,11 +6,7 @@

 from vllm import _custom_ops as ops
 from vllm.triton_utils import tl, triton
-from vllm.utils import round_up
-
-
-def ceil_div(a, b):
-    return (a + b - 1) // b
+from vllm.utils import cdiv, round_up


 @triton.jit
@@ -115,7 +111,7 @@ def moe_align_block_size_triton(
     cumsum = torch.zeros((num_experts + 1, ),
                          dtype=torch.int32,
                          device=topk_ids.device)
-    tokens_per_thread = ceil_div(numel, num_experts)
+    tokens_per_thread = cdiv(numel, num_experts)

     moe_align_block_size_stage1[grid](
         topk_ids,

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 3 additions & 6 deletions
@@ -19,7 +19,7 @@
                                     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils import cdiv, direct_register_custom_op

 logger = init_logger(__name__)
 has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
@@ -158,12 +158,9 @@ def apply_w8a8_block_fp8_linear(
     if current_platform.is_cuda():
         if current_platform.has_device_capability(100):

-            def ceil_div(x: int, y: int) -> int:
-                return (x + y - 1) // y
-
             use_cutlass = cutlass_block_fp8_supported and (
-                ceil_div(weight.shape[0], 128) == weight_scale.shape[0]
-                and ceil_div(weight.shape[1], 128) == weight_scale.shape[1])
+                cdiv(weight.shape[0], 128) == weight_scale.shape[0]
+                and cdiv(weight.shape[1], 128) == weight_scale.shape[1])
         else:
             # TODO: update this after switching to public sm90 block scale gemm
             # as it also supports weight.shape % 128 != 0

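For context on the fp8_utils change above, the sketch below shows how the cdiv-based check compares a weight tensor against its per-block scales; the shapes are hypothetical and the real condition additionally requires cutlass_block_fp8_supported:

import torch

from vllm.utils import cdiv

# hypothetical 128x128 block-quantized weight and its per-block scale tensor
weight = torch.empty(2048, 7168)        # stands in for an FP8 weight (N, K)
weight_scale = torch.empty(16, 56)      # one scale per 128x128 block

# scales cover the weight exactly when each dimension rounds up to the scale shape
shapes_match = (cdiv(weight.shape[0], 128) == weight_scale.shape[0]
                and cdiv(weight.shape[1], 128) == weight_scale.shape[1])
print(shapes_match)  # True: cdiv(2048, 128) == 16 and cdiv(7168, 128) == 56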