Commit 0c24ad3

triton cdiv
Signed-off-by: Leo Tian <leo.tian@centml.ai>
1 parent 0fd1f87 commit 0c24ad3

File tree: 1 file changed (+3, -2)
vllm/v1/spec_decode/eagle.py

Lines changed: 3 additions & 2 deletions
@@ -11,6 +11,7 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.triton_utils import triton
 from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                     FlashAttentionMetadata)
 from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -238,8 +239,8 @@ def advance_speculative_state(self, draft_token_ids: torch.Tensor,
                                   hidden_states: torch.Tensor,
                                   attn_metadata: FlashAttentionMetadata,
                                   batch_size: int):
-        grid = lambda meta: (
-            (batch_size + meta['BLOCK_SIZE']) // meta['BLOCK_SIZE'], )
+        # Calculate number of thread blocks
+        grid = lambda meta: (triton.cdiv(batch_size, meta['BLOCK_SIZE']), )
         attn_metadata.slot_mapping = torch.empty_like(positions)
         advance_state_kernel[grid](
             # === Input tensors ===
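
Note on the change: triton.cdiv(x, y) is ceiling division, (x + y - 1) // y, so the grid launches exactly enough program instances to cover batch_size. The previous expression, (batch_size + BLOCK_SIZE) // BLOCK_SIZE, launches one extra instance whenever batch_size is an exact multiple of BLOCK_SIZE. A minimal sketch of the arithmetic (illustrative only; old_grid and new_grid are not names from the commit):

# Illustrative helpers, not part of the commit.
def old_grid(batch_size: int, block_size: int) -> int:
    # Previous grid size: one block too many when batch_size % block_size == 0.
    return (batch_size + block_size) // block_size

def new_grid(batch_size: int, block_size: int) -> int:
    # Equivalent to triton.cdiv(batch_size, block_size): ceiling division.
    return (batch_size + block_size - 1) // block_size

for bs in (7, 8, 9):
    print(bs, old_grid(bs, 8), new_grid(bs, 8))
# 7 -> 1, 1
# 8 -> 2, 1   (old expression launched a redundant block)
# 9 -> 2, 2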
