@@ -384,46 +384,3 @@ def prepare_eagle_input_sequential(out_tensor: torch.Tensor,
                    (target_indices < end_pos) & \
                    (offset_tensor < num_tokens)
     out_tensor[target_indices[mask]] = values_to_store[mask]
-
-
-# NOTE(woosuk): Currently, the below code is not used and we always use argmax
-# to sample the draft tokens. We will use this after we find a way to manage
-# the draft prob tensor.
-# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
-# FIXME(woosuk): The logic here is duplicated with the main sampling code.
-# We should refactor this to reuse the same sampling implementation.
-def compute_probs_and_sample_next_token(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    if sampling_metadata.all_greedy:
-        # For greedy requests, draft_probs is not used in rejection sampling.
-        # Therefore, we can just return the logits.
-        probs = logits
-        next_token_ids = logits.argmax(dim=-1)
-        return next_token_ids, probs
-
-    is_greedy = sampling_metadata.temperature == -1
-    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
-    logits.div_(temperature.view(-1, 1))
-    probs = logits.softmax(dim=-1, dtype=torch.float32)
-
-    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
-    # generating the draft tokens. We only use the temperature. While this
-    # could degrade the acceptance rate, it does not affect the distribution
-    # of the generated tokens after rejection sampling.
-
-    # TODO(woosuk): Consider seeds.
-    q = torch.empty_like(probs)
-    q.exponential_()
-    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
-    # will be used later for rejection sampling.
-    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
-    if not sampling_metadata.all_random:
-        greedy_token_ids = probs.argmax(dim=-1)
-        next_token_ids = torch.where(
-            is_greedy,
-            greedy_token_ids,
-            next_token_ids,
-        )
-    return next_token_ids, probs
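The interesting piece of the removed helper is `probs.div(q).argmax(dim=-1)` with `q.exponential_()`: dividing the probabilities by i.i.d. Exp(1) noise and taking the argmax is an exponential-race (Gumbel-max-style) trick that draws an exact sample from the categorical distribution, and the out-of-place `div` leaves `probs` untouched for later rejection sampling. A minimal standalone sketch of that trick (not part of the PR; the seed, sample count, and printed check are illustrative):

```python
# Why `probs.div(q).argmax(dim=-1)` with q ~ Exp(1) samples from `probs`:
# if q_i ~ Exp(1) i.i.d., then q_i / p_i ~ Exp(rate=p_i), and the minimum of
# independent exponentials with rates p_i lands on index i with probability
# p_i / sum_j p_j. Minimizing q_i / p_i is maximizing p_i / q_i.
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])

num_samples = 200_000
# Exp(1) noise for every (sample, category) pair.
q = torch.empty(num_samples, probs.numel()).exponential_()
# Out-of-place div keeps `probs` intact, mirroring the NOTE in the removed
# function about preserving draft_probs for rejection sampling.
samples = probs.div(q).argmax(dim=-1)

empirical = torch.bincount(samples, minlength=probs.numel()) / num_samples
print(empirical)  # ~= tensor([0.1, 0.2, 0.3, 0.4]) up to sampling noise
```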
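The helper also mixed greedy and random requests in one batch, using `temperature == -1` as the greedy sentinel: the sentinel is swapped for 1.0 before the division so greedy rows stay usable, and their sampled tokens are overwritten with the argmax afterwards. A small sketch of that masking pattern (assumed shapes and temperatures; not vLLM code):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(3, 8)                    # batch of 3 requests, vocab of 8
temperature = torch.tensor([-1.0, 0.7, 1.3])  # request 0 is greedy (sentinel)

is_greedy = temperature == -1
# Swap the -1 sentinel for 1.0 so the division is a no-op on greedy rows.
safe_temp = torch.where(is_greedy, torch.ones_like(temperature), temperature)
probs = logits.div(safe_temp.view(-1, 1)).softmax(dim=-1)

# Random draw for every row, then overwrite the greedy rows with argmax.
sampled = probs.div(torch.empty_like(probs).exponential_()).argmax(dim=-1)
next_ids = torch.where(is_greedy, probs.argmax(dim=-1), sampled)
print(next_ids)
```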