diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py
index 453ae7b6f56f..f15b37d5b663 100644
--- a/examples/offline_inference/eagle.py
+++ b/examples/offline_inference/eagle.py
@@ -45,6 +45,7 @@ def main():
     parser.add_argument("--enable_chunked_prefill", action='store_true')
     parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
     parser.add_argument("--temp", type=float, default=0)
+    parser.add_argument("--log_output_filename", type=str, default="eagle_output.txt")
     args = parser.parse_args()
 
     model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -90,6 +91,22 @@ def main():
     outputs = llm.generate(prompt_token_ids=prompt_ids,
                            sampling_params=sampling_params)
 
+    # Save the input/output text to a file for a manual quality check.
+    log_output_data = []
+    for output in outputs:
+        input_text = tokenizer.decode(output.prompt_token_ids)
+        log_output_data.append({
+            "input": input_text,
+            "output": output.outputs[0].text
+        })
+
+    with open(args.log_output_filename, "w") as f:
+        f.write(
+            json.dumps(log_output_data, indent=4, ensure_ascii=False))
+    print("-" * 50)
+    print(f"Output texts saved to {args.log_output_filename}")
+    print("-" * 50)
+
     # calculate the average number of accepted tokens per forward pass, +1 is
     # to account for the token from the target model that's always going to be
     # accepted
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 673714980592..409b31f47970 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -44,8 +44,13 @@ def test_prompts():
 
 @pytest.fixture
 def sampling_config():
-    # Only support greedy for now
-    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
+    # Cover greedy decoding and a few low-temperature configurations.
+    return [
+        SamplingParams(temperature=0, max_tokens=10, ignore_eos=False),
+        SamplingParams(temperature=0.1, max_tokens=10, ignore_eos=False),
+        SamplingParams(temperature=0.2, max_tokens=10, ignore_eos=False),
+        SamplingParams(temperature=0.3, max_tokens=10, ignore_eos=False),
+        # SamplingParams(temperature=1, top_p=0.75, max_tokens=10, ignore_eos=False),
+    ]
 
 
 @pytest.fixture
@@ -72,7 +77,9 @@ def test_ngram_correctness(
         m.setenv("VLLM_USE_V1", "1")
 
         ref_llm = LLM(model=model_name, max_model_len=1024)
-        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+        ref_outputs = []
+        for sampling_param in sampling_config:
+            ref_outputs.append(ref_llm.chat(test_prompts, sampling_param))
         del ref_llm
 
         spec_llm = LLM(
@@ -85,20 +92,22 @@ def test_ngram_correctness(
             },
             max_model_len=1024,
         )
-        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
-        matches = 0
-        misses = 0
-        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-            if ref_output.outputs[0].text == spec_output.outputs[0].text:
-                matches += 1
-            else:
-                misses += 1
-                print(f"ref_output: {ref_output.outputs[0].text}")
-                print(f"spec_output: {spec_output.outputs[0].text}")
-
-        # Heuristic: expect at least 70% of the prompts to match exactly
-        # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches > int(0.7 * len(ref_outputs))
+
+        for i, sampling_param in enumerate(sampling_config):
+            spec_outputs = spec_llm.chat(test_prompts, sampling_param)
+            matches = 0
+            misses = 0
+            for ref_output, spec_output in zip(ref_outputs[i], spec_outputs):
+                if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                    matches += 1
+                else:
+                    misses += 1
+                    print(f"ref_output: {ref_output.outputs[0].text}")
+                    print(f"spec_output: {spec_output.outputs[0].text}")
+
+            # Heuristic: expect at least 70% of the prompts to match exactly
+            # Upon failure, inspect the outputs to check for inaccuracy.
+            assert matches > int(0.7 * len(ref_outputs[i]))
         del spec_llm
 
 
@@ -115,9 +124,13 @@ def test_eagle_correctness(
     '''
     with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
+        ref_outputs = []
+        spec_outputs = []
 
         ref_llm = LLM(model=model_name, max_model_len=1024)
-        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+
+        for sampling_param in sampling_config:
+            ref_outputs.append(ref_llm.chat(test_prompts, sampling_param))
         del ref_llm
 
         spec_llm = LLM(
@@ -129,18 +142,22 @@ def test_eagle_correctness(
             },
             max_model_len=1024,
         )
-        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
-        matches = 0
-        misses = 0
-        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-            if ref_output.outputs[0].text == spec_output.outputs[0].text:
-                matches += 1
-            else:
-                misses += 1
-                print(f"ref_output: {ref_output.outputs[0].text}")
-                print(f"spec_output: {spec_output.outputs[0].text}")
-
-        # Heuristic: expect at least 70% of the prompts to match exactly
-        # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches > int(0.7 * len(ref_outputs))
+
+        for sampling_param in sampling_config:
+            spec_outputs.append(spec_llm.chat(test_prompts, sampling_param))
         del spec_llm
+
+        for i in range(len(sampling_config)):
+            matches = 0
+            misses = 0
+            for ref_output, spec_output in zip(ref_outputs[i], spec_outputs[i]):
+                if ref_output.outputs[0].text == spec_output.outputs[0].text:
+                    matches += 1
+                else:
+                    misses += 1
+                    print(f"ref_output: {ref_output.outputs[0].text}")
+                    print(f"spec_output: {spec_output.outputs[0].text}")
+
+            # Heuristic: expect at least 70% of the prompts to match exactly
+            # Upon failure, inspect the outputs to check for inaccuracy.
+            assert matches > int(0.7 * len(ref_outputs[i])), (
+                f"Failed for sampling params: {sampling_config[i]}")
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 2322463c0713..2cd0d230d0b1 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -31,6 +31,39 @@ def __init__(
             device=device,
             dtype=torch.int32)
 
+        max_batch_size = vllm_config.scheduler_config.max_num_seqs
+        vocab_size = vllm_config.model_config.get_vocab_size()
+
+        # Set up buffers for the draft token ids and probs that are
+        # reused across steps for rejection sampling.
+        self.curr_num_tokens = -1
+        self.curr_batch_size = -1
+
+        # [max_batch_size, num_speculative_tokens]
+        self._draft_token_ids_buffer = torch.zeros(max_batch_size,
+                                                   self.num_speculative_tokens,
+                                                   dtype=torch.long,
+                                                   device=device)
+        self._draft_token_ids_buffer_shape = self._draft_token_ids_buffer.shape
+
+        # packed tensor of shape
+        # [max_batch_size * num_speculative_tokens, vocab_size]
+        self._draft_probs_buffer = torch.zeros(
+            max_batch_size * self.num_speculative_tokens,
+            vocab_size,
+            # TODO(ekagra): pass dtype
+            dtype=torch.float32,
+            device=device)
+        self._draft_probs_buffer_shape = self._draft_probs_buffer.shape
+
+    def get_draft_token_ids(self) -> torch.Tensor:
+        # [batch_size, num_speculative_tokens]
+        assert self.curr_batch_size != -1, "EagleProposer hasn't proposed yet."
+        return self._draft_token_ids_buffer[:self.curr_batch_size]
+
+    def get_draft_probs(self) -> torch.Tensor:
+        # packed [batch_size * num_speculative_tokens, vocab_size]
+        assert self.curr_num_tokens != -1, "EagleProposer hasn't proposed yet."
+        return self._draft_probs_buffer[:self.curr_num_tokens]
+
     def propose(
         self,
         # [num_tokens]
@@ -48,9 +81,33 @@ def propose(
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ):
+        # Make sure the buffer sizes have not been changed
+        # by any operation outside this class.
+        assert self._draft_probs_buffer_shape.numel(
+        ) == self._draft_probs_buffer.numel(), (
+            "Size of self._draft_probs_buffer has been changed. "
+            "Make sure it remains the same.")
+
+        assert self._draft_token_ids_buffer_shape.numel(
+        ) == self._draft_token_ids_buffer.numel(), (
+            "Size of self._draft_token_ids_buffer has been changed. "
+            "Make sure it remains the same.")
+
+        # Restore the buffer shapes in case an outside
+        # operation reshaped them.
+        if self._draft_probs_buffer.shape != self._draft_probs_buffer_shape:
+            self._draft_probs_buffer = self._draft_probs_buffer.reshape(
+                self._draft_probs_buffer_shape)
+
+        if (self._draft_token_ids_buffer.shape
+                != self._draft_token_ids_buffer_shape):
+            self._draft_token_ids_buffer = self._draft_token_ids_buffer.reshape(
+                self._draft_token_ids_buffer_shape)
+
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
+        self.curr_batch_size = batch_size
+        self.curr_num_tokens = batch_size * self.num_speculative_tokens
         last_token_indices = cu_num_tokens[1:] - 1
 
         input_ids = torch.empty_like(target_token_ids)
@@ -91,26 +148,27 @@ def propose(
         )
         sample_hidden_states = hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
-        draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
-            logits, sampling_metadata)
+        compute_probs_and_sample_next_token(logits, sampling_metadata,
+                                            0,
+                                            self.num_speculative_tokens,
+                                            batch_size,
+                                            self.arange,
+                                            self._draft_token_ids_buffer,
+                                            self._draft_probs_buffer)
 
         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1:
-            # [batch_size, 1] and [batch_size, 1, vocab_size]
-            return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)
+            return
 
         # Generate the remaining draft tokens.
-        draft_token_ids_list = [draft_token_ids]
-        draft_probs_list = [draft_probs]
-
         positions = target_positions[last_token_indices]
         hidden_states = sample_hidden_states
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
-        for _ in range(self.num_speculative_tokens - 1):
+        for speculative_token_idx in range(self.num_speculative_tokens - 1):
             # Update the inputs.
-            input_ids = draft_token_ids_list[-1]
+            input_ids = self._draft_token_ids_buffer[:batch_size,
+                                                     speculative_token_idx]
             positions += 1
             attn_metadata.max_seq_len += 1
             attn_metadata.seq_lens += 1
@@ -130,16 +188,13 @@ def propose(
                 positions=positions,
             )
             logits = self.model.compute_logits(hidden_states, None)
-            draft_token_ids, probs = compute_probs_and_sample_next_token(
-                logits, sampling_metadata)
-            draft_token_ids_list.append(draft_token_ids)
-            draft_probs_list.append(probs)
-
-        # [batch_size, num_speculative_tokens]
-        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
-        # [batch_size, num_speculative_tokens, vocab_size]
-        draft_probs = torch.stack(draft_probs_list, dim=1)
-        return draft_token_ids, draft_probs
+            compute_probs_and_sample_next_token(logits, sampling_metadata,
+                                                speculative_token_idx + 1,
+                                                self.num_speculative_tokens,
+                                                batch_size,
+                                                self.arange,
+                                                self._draft_token_ids_buffer,
+                                                self._draft_probs_buffer)
 
     @staticmethod
     def prepare_inputs(
@@ -214,18 +269,42 @@ def load_model(self, target_model: nn.Module) -> None:
 
 def compute_probs_and_sample_next_token(
     logits: torch.Tensor,
     sampling_metadata: SamplingMetadata,
-) -> tuple[torch.Tensor, torch.Tensor]:
+    # index of the current speculative token among num_speculative_tokens
+    speculative_token_idx: int,
+    # max number of speculative tokens
+    num_speculative_tokens: int,
+    # current batch size
+    batch_size: int,
+    # [batch_size + 1]
+    arange: torch.Tensor,
+    # [max_batch_size, num_speculative_tokens]
+    draft_token_ids_buffer: torch.Tensor,
+    # packed [max_batch_size * num_speculative_tokens, vocab_size]
+    draft_probs_buffer: torch.Tensor,
+):
+    # We pass in the entire preallocated buffers draft_token_ids_buffer
+    # and draft_probs_buffer and select the portion to fill in using
+    # batch_size and speculative_token_idx. This allows us to write
+    # in place. If we passed the sliced tensors directly to this
+    # function, i.e.,
+    # draft_token_ids_buffer[:batch_size, speculative_token_idx]
+    # as draft_token_ids, then draft_token_ids = logits.argmax(dim=-1)
+    # would rebind the name to a new tensor instead of writing into
+    # the buffer.
+
+    draft_probs_buffer_indices = (arange[:batch_size] * num_speculative_tokens
+                                  + speculative_token_idx)
+
     if sampling_metadata.all_greedy:
         # For greedy requests, draft_probs is not used in rejection sampling.
-        # Therefore, we can just return the logits.
-        probs = logits
-        next_token_ids = logits.argmax(dim=-1)
-        return next_token_ids, probs
+        # Therefore, we can just store the logits as the draft probs.
+        draft_probs_buffer[draft_probs_buffer_indices] = logits.to(
+            dtype=torch.float32)
+        draft_token_ids_buffer[:batch_size,
+                               speculative_token_idx] = logits.argmax(dim=-1)
+        return
 
     is_greedy = sampling_metadata.temperature == -1
     temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
     logits.div_(temperature.view(-1, 1))
-    probs = logits.softmax(dim=-1, dtype=torch.float32)
+    draft_probs_buffer[draft_probs_buffer_indices] = logits.softmax(
+        dim=-1, dtype=torch.float32)
 
     # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
     # generating the draft tokens. We only use the temperature. While this
@@ -233,17 +312,21 @@ def compute_probs_and_sample_next_token(
     # of the generated tokens after rejection sampling.
 
     # TODO(woosuk): Consider seeds.
-    q = torch.empty_like(probs)
+    q = torch.empty_like(draft_probs_buffer[draft_probs_buffer_indices])
     q.exponential_()
-    next_token_ids = probs.div_(q).argmax(dim=-1).view(-1)
+    draft_token_ids_buffer[:batch_size, speculative_token_idx] = (
+        draft_probs_buffer[draft_probs_buffer_indices].div_(q).argmax(
+            dim=-1).view(-1))
     if not sampling_metadata.all_random:
-        greedy_token_ids = probs.argmax(dim=-1)
-        next_token_ids = torch.where(
+        greedy_token_ids = draft_probs_buffer[
+            draft_probs_buffer_indices].argmax(dim=-1)
+        draft_token_ids_buffer[:batch_size, speculative_token_idx] = \
+            torch.where(
             is_greedy,
             greedy_token_ids,
-            next_token_ids,
+            draft_token_ids_buffer[:batch_size, speculative_token_idx],
         )
-    return next_token_ids, probs
 
 
 @triton.jit
diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py
index 7e548bb48b57..2d77b81b9878 100644
--- a/vllm/v1/spec_decode/ngram_proposer.py
+++ b/vllm/v1/spec_decode/ngram_proposer.py
@@ -10,6 +10,7 @@ class NgramProposer:
 
     def __init__(self, vllm_config: VllmConfig):
+        self._draft_token_ids = None
         # Minimum length of the n-gram to match.
         self.min_n = vllm_config.speculative_config.prompt_lookup_min
         # Maximum length of the n-gram to match.
@@ -22,10 +23,16 @@ def __init__(self, vllm_config: VllmConfig):
         # This usually takes less than 1 second.
         self.propose(np.zeros(1024, dtype=np.int32))
 
+    def get_draft_token_ids(self):
+        return self._draft_token_ids
+
+    def get_draft_probs(self):
+        # The n-gram proposer does not produce draft probabilities.
+        return None
+
     def propose(
         self,
         context_token_ids: np.ndarray,
-    ) -> Optional[np.ndarray]:
+    ):
         """Proposes the next sequence of tokens based on n-gram pattern
         matching in the context. The function finds matches of the last n
         tokens in the previous context, and returns k tokens that followed
@@ -54,8 +61,9 @@ def propose(
         for n in range(self.max_n, self.min_n - 1, -1):
             result = _find_subarray_kmp(context_token_ids, n, self.k)
             if result is not None:
-                return result
-        return None
+                self._draft_token_ids = result
+                return
+        # No match found: clear any draft left over from a previous call.
+        self._draft_token_ids = None
+        return
 
     def load_model(self, *args, **kwargs):
         # No model to load.
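
Note on the packed buffer layout used in eagle.py above: draft probabilities for
all speculation steps live in a single
[max_batch_size * num_speculative_tokens, vocab_size] tensor, and the row for
batch entry b at speculation step s is b * num_speculative_tokens + s. The
sketch below is illustrative only (it is not part of the patch and uses made-up
sizes); it shows the same indexing scheme and why writing through the full
buffer with computed row indices keeps the update in place, whereas rebinding a
pre-sliced argument inside the callee would not.

    import torch

    # Illustrative sizes, not taken from the patch.
    max_batch_size, num_speculative_tokens, vocab_size = 4, 3, 8
    batch_size = 2

    draft_probs_buffer = torch.zeros(
        max_batch_size * num_speculative_tokens, vocab_size)
    draft_token_ids_buffer = torch.zeros(
        max_batch_size, num_speculative_tokens, dtype=torch.long)
    arange = torch.arange(max_batch_size + 1)

    def fill_step(logits: torch.Tensor, speculative_token_idx: int) -> None:
        # Row b * num_speculative_tokens + speculative_token_idx holds the
        # probs for batch entry b at this speculation step.
        rows = (arange[:batch_size] * num_speculative_tokens
                + speculative_token_idx)
        probs = logits.softmax(dim=-1)
        draft_probs_buffer[rows] = probs  # writes into the shared buffer
        draft_token_ids_buffer[:batch_size,
                               speculative_token_idx] = probs.argmax(dim=-1)

    for step in range(num_speculative_tokens):
        fill_step(torch.randn(batch_size, vocab_size), step)

    # The drafts for the current batch occupy the leading rows of each buffer.
    print(draft_token_ids_buffer[:batch_size])
    print(draft_probs_buffer[:batch_size * num_speculative_tokens].shape)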
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 0e70d77e1b7e..5597c999c8c3 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1108,7 +1108,7 @@ def execute_model(
             target_logits = logits[spec_decode_metadata.target_logits_indices]
             output_token_ids = self.rejection_sampler(
                 spec_decode_metadata,
-                None,  # draft_probs
+                self.drafter.get_draft_probs(),
                 target_logits,
                 bonus_token_ids,
                 sampling_metadata,
@@ -1220,7 +1220,7 @@ def execute_model(
                 target_hidden_states = hidden_states[token_indices]
                 target_slot_mapping = attn_metadata.slot_mapping[token_indices]
 
-            draft_token_ids, draft_probs = self.drafter.propose(
+            self.drafter.propose(
                 target_token_ids=target_token_ids,
                 target_positions=target_positions,
                 target_hidden_states=target_hidden_states,
@@ -1230,10 +1230,7 @@ def execute_model(
                 block_table=attn_metadata.block_table,
                 sampling_metadata=sampling_metadata,
             )
-            spec_token_ids = draft_token_ids.tolist()
-            # TODO(woosuk): Cache draft_probs and use it for rejection sampling
-            # in the next step.
-            del draft_probs
+            spec_token_ids = self.drafter.get_draft_token_ids().tolist()
 
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
@@ -1268,8 +1265,8 @@ def generate_draft_token_ids(
             start_idx = self.input_batch.num_tokens_no_spec[i]
             end_idx = start_idx + num_sampled_ids
             self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
-            drafter_output = self.drafter.propose(
-                self.input_batch.token_ids_cpu[i, :end_idx])
+            self.drafter.propose(self.input_batch.token_ids_cpu[i, :end_idx])
+            drafter_output = self.drafter.get_draft_token_ids()
            if drafter_output is None or len(drafter_output) == 0:
                 draft_token_ids.append([])
             else:
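
For review purposes, here is a minimal, self-contained sketch of the
propose-then-accessor contract that gpu_model_runner.py now relies on. It is a
mock, not vLLM code: the class name and the dummy "proposal" logic are invented,
and only the accessor pattern (propose() records the batch size, the getters
slice the preallocated buffers) mirrors the patch.

    import torch

    class MockDrafter:
        """Mock proposer exposing the same accessors as the patched drafters."""

        def __init__(self, max_batch_size: int, num_speculative_tokens: int,
                     vocab_size: int):
            self.num_speculative_tokens = num_speculative_tokens
            self._draft_token_ids_buffer = torch.zeros(
                max_batch_size, num_speculative_tokens, dtype=torch.long)
            self._draft_probs_buffer = torch.zeros(
                max_batch_size * num_speculative_tokens, vocab_size)
            self.curr_batch_size = -1
            self.curr_num_tokens = -1

        def propose(self, next_token_ids: torch.Tensor) -> None:
            batch_size = next_token_ids.shape[0]
            self.curr_batch_size = batch_size
            self.curr_num_tokens = batch_size * self.num_speculative_tokens
            # A real proposer would run the draft model here; we just repeat
            # the last sampled token as a dummy draft.
            self._draft_token_ids_buffer[:batch_size] = next_token_ids.view(-1, 1)

        def get_draft_token_ids(self) -> torch.Tensor:
            assert self.curr_batch_size != -1, "propose() has not been called"
            return self._draft_token_ids_buffer[:self.curr_batch_size]

        def get_draft_probs(self) -> torch.Tensor:
            assert self.curr_num_tokens != -1, "propose() has not been called"
            return self._draft_probs_buffer[:self.curr_num_tokens]

    drafter = MockDrafter(max_batch_size=8, num_speculative_tokens=2,
                          vocab_size=16)
    drafter.propose(torch.tensor([11, 42, 7]))
    spec_token_ids = drafter.get_draft_token_ids().tolist()  # fed to scheduler
    draft_probs = drafter.get_draft_probs()  # consumed by rejection sampling
    print(spec_token_ids, draft_probs.shape)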