@@ -502,7 +502,7 @@ def _get_default_neuron_config(model_config: ModelConfig,
         max_context_length=scheduler_config.max_model_len,
         seq_len=scheduler_config.max_model_len,
         enable_bucketing=True,
-        is_continuous_batching=(batch_size > 1),
+        is_continuous_batching=True,
         quantized=False,
         torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
         padding_side="right",
@@ -520,13 +520,15 @@ def _get_default_speculation_config(model_config: ModelConfig,
     args."""
     neuron_config = dict(
         tp_degree=parallel_config.tensor_parallel_size,
+        ctx_batch_size=1,
         batch_size=scheduler_config.max_num_seqs,
         max_context_length=scheduler_config.max_model_len,
         seq_len=scheduler_config.max_model_len,
         speculation_length=speculation_config.num_speculative_tokens,
         trace_tokengen_model=False,
         enable_fused_speculation=True,
         enable_bucketing=True,
+        is_continuous_batching=True,
         quantized=False,
         torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
         on_device_sampling_config=dict(
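
For readers skimming the hunks above, here is a minimal, runnable sketch of the default config the first helper now produces, with continuous batching always enabled instead of being gated on `batch_size > 1`. The `ModelConfig`/`SchedulerConfig` stand-ins and the `TORCH_DTYPE_TO_NEURON_AMP` mapping are stubbed assumptions (the real definitions live elsewhere in the vLLM Neuron backend), and fields that sit above the visible diff context are omitted.

```python
# Illustrative sketch only: the real ModelConfig/SchedulerConfig classes and the
# TORCH_DTYPE_TO_NEURON_AMP mapping come from vLLM's Neuron backend; they are
# stubbed here so the default-config construction can run standalone.
from dataclasses import dataclass

import torch

# Assumed dtype-to-AMP mapping, stubbed for illustration.
TORCH_DTYPE_TO_NEURON_AMP = {
    torch.float32: "f32",
    torch.float16: "f16",
    torch.bfloat16: "bf16",
}


@dataclass
class ModelConfig:
    dtype: torch.dtype = torch.bfloat16


@dataclass
class SchedulerConfig:
    max_num_seqs: int = 4
    max_model_len: int = 2048


def _get_default_neuron_config(model_config: ModelConfig,
                               scheduler_config: SchedulerConfig) -> dict:
    """Default Neuron config after this patch: continuous batching is always
    on, no longer conditioned on the batch size. Only the fields visible in
    the diff context are shown here."""
    return dict(
        max_context_length=scheduler_config.max_model_len,
        seq_len=scheduler_config.max_model_len,
        enable_bucketing=True,
        is_continuous_batching=True,
        quantized=False,
        torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        padding_side="right",
    )


if __name__ == "__main__":
    cfg = _get_default_neuron_config(ModelConfig(), SchedulerConfig())
    print(cfg["is_continuous_batching"])  # True, regardless of batch size
```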