Skip to content

Commit ebed81f

Browse files
elaineyz, sssrijan-amazon, AakashShetty-aws
authored
Update default neuron config for speculation (#18274)
Signed-off-by: Elaine Zhao <elaineyz@amazon.com>
Co-authored-by: Shashwat Srijan <sssrijan@amazon.com>
Co-authored-by: Aakash Shetty <sheaak@amazon.com>
1 parent e2d7d31 commit ebed81f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

vllm/model_executor/model_loader/neuronx_distributed.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -502,7 +502,7 @@ def _get_default_neuron_config(model_config: ModelConfig,
502502
max_context_length=scheduler_config.max_model_len,
503503
seq_len=scheduler_config.max_model_len,
504504
enable_bucketing=True,
505-
is_continuous_batching=(batch_size > 1),
505+
is_continuous_batching=True,
506506
quantized=False,
507507
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
508508
padding_side="right",
@@ -520,13 +520,15 @@ def _get_default_speculation_config(model_config: ModelConfig,
520520
args."""
521521
neuron_config = dict(
522522
tp_degree=parallel_config.tensor_parallel_size,
523+
ctx_batch_size=1,
523524
batch_size=scheduler_config.max_num_seqs,
524525
max_context_length=scheduler_config.max_model_len,
525526
seq_len=scheduler_config.max_model_len,
526527
speculation_length=speculation_config.num_speculative_tokens,
527528
trace_tokengen_model=False,
528529
enable_fused_speculation=True,
529530
enable_bucketing=True,
531+
is_continuous_batching=True,
530532
quantized=False,
531533
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
532534
on_device_sampling_config=dict(

0 commit comments

Comments (0)