Skip to content

Commit dcece51

Browse files
authored
[Serving] Apply tree structure in draft token verification (#2563)
This adds an interface to the draft token state and the sampler so that the token tree structure can be recorded and used for verification.
1 parent 873827c commit dcece51

11 files changed

+58
-39
lines changed

cpp/serve/engine_actions/batch_draft.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ class BatchDraftActionObj : public EngineActionObj {
142142
models_[model_id]->ScatterDraftProbs(probs_on_device, draft_token_slots_,
143143
&model_workspaces_[0].draft_probs_storage);
144144
for (int i = 0; i < num_rsentries; ++i) {
145-
mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]);
145+
int64_t parent_idx = static_cast<int64_t>(mstates[i]->draft_output_tokens.size()) - 1;
146+
mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i], parent_idx);
146147
}
147148

148149
auto tdraft_end = std::chrono::high_resolution_clock::now();

cpp/serve/engine_actions/batch_verify.cc

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ class BatchVerifyActionObj : public EngineActionObj {
6565
Array<GenerationConfig> generation_cfg;
6666
std::vector<RandomGenerator*> rngs;
6767
std::vector<std::vector<SampleResult>> draft_output_tokens;
68+
std::vector<int64_t> token_tree_parent_ptr;
69+
token_tree_parent_ptr.reserve(total_verify_length);
6870
request_internal_ids.reserve(num_rsentries);
6971
all_tokens_to_verify.reserve(total_verify_length);
7072
verify_request_mstates.reserve(num_rsentries);
@@ -83,9 +85,11 @@ class BatchVerifyActionObj : public EngineActionObj {
8385
// the last committed token + all the draft tokens.
8486
draft_token_slots_.push_back(0); // placeholder for the last committed token
8587
all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().GetTokenId());
88+
token_tree_parent_ptr.push_back(-1);
8689
for (int j = 0; j < static_cast<int>(draft_mstate->draft_output_tokens.size()); ++j) {
8790
all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].GetTokenId());
8891
draft_token_slots_.push_back(draft_mstate->draft_token_slots[j]);
92+
token_tree_parent_ptr.push_back(draft_mstate->draft_token_parent_idx[j] + 1);
8993
}
9094
verify_request_mstates.push_back(verify_mstate);
9195
generation_cfg.push_back(rsentries[i]->request->generation_cfg);
@@ -101,16 +105,6 @@ class BatchVerifyActionObj : public EngineActionObj {
101105
{IntTuple{all_tokens_to_verify.begin(), all_tokens_to_verify.end()}});
102106
RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding");
103107

104-
// Construct the token tree. Right now only chains are supported.
105-
std::vector<int64_t> token_tree_parent_ptr;
106-
token_tree_parent_ptr.reserve(total_verify_length);
107-
for (int i = 0; i < num_rsentries; ++i) {
108-
for (int pos = 0; pos < verify_lengths[i]; ++pos) {
109-
token_tree_parent_ptr.push_back(pos - 1);
110-
}
111-
}
112-
ICHECK_EQ(token_tree_parent_ptr.size(), total_verify_length);
113-
114108
RECORD_EVENT(trace_recorder_, request_ids, "start verify");
115109
NDArray logits = models_[verify_model_id_]->BatchVerify(embeddings, request_internal_ids,
116110
verify_lengths, token_tree_parent_ptr);
@@ -140,7 +134,7 @@ class BatchVerifyActionObj : public EngineActionObj {
140134
std::vector<std::vector<SampleResult>> sample_results_arr =
141135
sampler_->BatchVerifyDraftTokensWithProbAfterTopP(
142136
renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs,
143-
draft_output_tokens, draft_probs_on_device);
137+
draft_output_tokens, token_tree_parent_ptr, draft_probs_on_device);
144138
ICHECK_EQ(sample_results_arr.size(), num_rsentries);
145139

146140
// We collect the requests whose drafts are fully accepted.

cpp/serve/engine_actions/eagle_batch_draft.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ class EagleBatchDraftActionObj : public EngineActionObj {
160160
&model_workspaces_[0].draft_probs_storage);
161161
// No need to save hidden states as they are not used by subsequent engine actions
162162
for (int i = 0; i < num_rsentries; ++i) {
163-
mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]);
163+
int64_t parent_idx = static_cast<int64_t>(mstates[i]->draft_output_tokens.size()) - 1;
164+
mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i], parent_idx);
164165
}
165166

166167
auto tdraft_end = std::chrono::high_resolution_clock::now();

cpp/serve/engine_actions/eagle_batch_verify.cc

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
6565
Array<GenerationConfig> generation_cfg;
6666
std::vector<RandomGenerator*> rngs;
6767
std::vector<std::vector<SampleResult>> draft_output_tokens;
68+
std::vector<int64_t> token_tree_parent_ptr;
6869
request_internal_ids.reserve(num_rsentries);
6970
all_tokens_to_verify.reserve(total_draft_length);
71+
token_tree_parent_ptr.reserve(total_draft_length);
7072
verify_request_mstates.reserve(num_rsentries);
7173
rngs.reserve(num_rsentries);
7274
generation_cfg.reserve(num_rsentries);
@@ -83,9 +85,12 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
8385
// the last committed token + all the draft tokens but the last one.
8486
all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().GetTokenId());
8587
draft_token_slots_.push_back(0); // placeholder for the last committed token
88+
token_tree_parent_ptr.push_back(-1);
89+
8690
for (int j = 0; j < static_cast<int>(draft_mstate->draft_output_tokens.size()); ++j) {
8791
all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].GetTokenId());
8892
draft_token_slots_.push_back(draft_mstate->draft_token_slots[j]);
93+
token_tree_parent_ptr.push_back(draft_mstate->draft_token_parent_idx[j] + 1);
8994
}
9095
verify_request_mstates.push_back(verify_mstate);
9196
generation_cfg.push_back(rsentries[i]->request->generation_cfg);
@@ -111,16 +116,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
111116
{IntTuple{all_tokens_to_verify.begin(), all_tokens_to_verify.end()}});
112117
RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding");
113118

114-
// Construct the token tree. Right now only chains are supported.
115-
std::vector<int64_t> token_tree_parent_ptr;
116-
token_tree_parent_ptr.reserve(cum_verify_lengths.back());
117-
for (int i = 0; i < num_rsentries; ++i) {
118-
for (int pos = 0; pos < verify_lengths[i]; ++pos) {
119-
token_tree_parent_ptr.push_back(pos - 1);
120-
}
121-
}
122-
ICHECK_EQ(token_tree_parent_ptr.size(), cum_verify_lengths.back());
123-
124119
RECORD_EVENT(trace_recorder_, request_ids, "start verify");
125120
ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden(
126121
embeddings, request_internal_ids, verify_lengths, token_tree_parent_ptr);
@@ -143,7 +138,7 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
143138
std::vector<std::vector<SampleResult>> sample_results_arr =
144139
sampler_->BatchVerifyDraftTokensWithProbAfterTopP(
145140
renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs,
146-
draft_output_tokens, draft_probs_on_device);
141+
draft_output_tokens, token_tree_parent_ptr, draft_probs_on_device);
147142
ICHECK_EQ(sample_results_arr.size(), num_rsentries);
148143

149144
// We collect the requests whose drafts are fully accepted.
@@ -398,7 +393,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
398393
&model_workspaces_[0].draft_hidden_states_storage);
399394
}
400395
for (int i = 0; i < static_cast<int>(mstates.size()); ++i) {
401-
mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]);
396+
int64_t parent_idx = static_cast<int64_t>(mstates[i]->draft_output_tokens.size()) - 1;
397+
mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i], parent_idx);
402398
}
403399
}
404400
/*!

cpp/serve/engine_actions/eagle_new_request_prefill.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,8 +355,12 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
355355
&model_workspaces_[0].draft_hidden_states_storage);
356356
}
357357
for (int i = 0; i < static_cast<int>(rsentries_for_sample.size()); ++i) {
358+
int parent_idx =
359+
rsentries_for_sample[i]->mstates[model_id]->draft_output_tokens.empty()
360+
? -1
361+
: rsentries_for_sample[i]->mstates[model_id]->draft_output_tokens.size() - 1;
358362
rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(
359-
sample_results[i], draft_token_slots_[sample_indices[i]]);
363+
sample_results[i], draft_token_slots_[sample_indices[i]], parent_idx);
360364
}
361365
}
362366

cpp/serve/logit_processor.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,9 @@ class LogitProcessorImpl : public LogitProcessorObj {
299299
p_penalties[num_token_for_penalty * 3 + 2] = generation_cfg[i]->repetition_penalty;
300300
++num_token_for_penalty;
301301
if (j > 0) {
302-
mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1);
302+
// Assume chain-style token tree.
303+
mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1,
304+
j - 1 - 1);
303305
}
304306
}
305307
if (num_token_to_process != 1) {
@@ -379,7 +381,8 @@ class LogitProcessorImpl : public LogitProcessorObj {
379381
p_seq_ids[token_start_offset + j] = 1;
380382
}
381383
if (j > 0) {
382-
mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1);
384+
// Assume chain-style token tree.
385+
mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1, j - 1 - 1);
383386
}
384387
}
385388
if (token_number != 1) {

cpp/serve/request_state.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,19 @@ void RequestModelStateNode::RollbackTokens(int count) {
7575
}
7676
}
7777

78-
void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, int draft_token_slot) {
78+
void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, int draft_token_slot,
79+
int64_t parent_idx) {
7980
draft_output_tokens.push_back(std::move(sampled_token));
8081
draft_token_slots.push_back(draft_token_slot);
82+
draft_token_parent_idx.push_back(parent_idx);
8183
appeared_token_ids[sampled_token.GetTokenId()] += 1;
8284
}
8385

8486
void RequestModelStateNode::RemoveLastDraftToken() {
8587
ICHECK(!draft_output_tokens.empty());
8688
auto it = appeared_token_ids.find(draft_output_tokens.back().GetTokenId());
8789
draft_output_tokens.pop_back();
90+
draft_token_parent_idx.pop_back();
8891
CHECK(it != appeared_token_ids.end());
8992
if (--it->second == 0) {
9093
appeared_token_ids.erase(it);

cpp/serve/request_state.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ class RequestModelStateNode : public Object {
7676
std::vector<SampleResult> draft_output_tokens;
7777
/*! \brief The storage slots for the associated states of draft tokens. */
7878
std::vector<int> draft_token_slots;
79+
/*! \brief The parent indices of the draft tokens. */
80+
std::vector<int64_t> draft_token_parent_idx;
7981
/*! \brief The appeared committed and draft tokens and their occurrence times. */
8082
std::unordered_map<int32_t, int32_t> appeared_token_ids;
8183

@@ -106,7 +108,7 @@ class RequestModelStateNode : public Object {
106108
void RollbackTokens(int count);
107109

108110
/*! \brief Add a draft token into draft_output_tokens. Update appeared_token_ids. */
109-
void AddDraftToken(SampleResult sampled_token, int draft_token_slot);
111+
void AddDraftToken(SampleResult sampled_token, int draft_token_slot, int64_t parent_idx);
110112
/*! \brief Remove all draft tokens from draft_output_tokens. Update appeared_token_ids. */
111113
void RemoveAllDraftTokens(std::vector<int>* removed_draft_token_slots = nullptr);
112114

cpp/serve/sampler/cpu_sampler.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ class CPUSampler : public SamplerObj {
413413
const std::vector<int>& cum_verify_lengths, const Array<GenerationConfig>& generation_cfg,
414414
const std::vector<RandomGenerator*>& rngs,
415415
const std::vector<std::vector<SampleResult>>& draft_output_tokens,
416-
NDArray draft_probs_on_device) final {
416+
const std::vector<int64_t>& token_tree_parent_ptr, NDArray draft_probs_on_device) final {
417417
// probs_on_host: (n, v)
418418
RECORD_EVENT(trace_recorder_, request_ids, "start draft verification");
419419
CHECK_EQ(probs_on_host->ndim, 2);
@@ -435,6 +435,12 @@ class CPUSampler : public SamplerObj {
435435
int verify_start = cum_verify_lengths[i];
436436
int verify_end = cum_verify_lengths[i + 1];
437437

438+
CHECK_EQ(token_tree_parent_ptr[verify_start], -1);
439+
for (int j = verify_start + 1; j < verify_end; ++j) {
440+
CHECK_EQ(token_tree_parent_ptr[j], j - verify_start)
441+
<< "CPU sampler only supports chain-style draft tokens.";
442+
}
443+
438444
int cur_token_idx = 0;
439445
// Sub 1 to ignore the last prediction.
440446
for (; cur_token_idx < verify_end - verify_start - 1; ++cur_token_idx) {

cpp/serve/sampler/gpu_sampler.cc

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ class GPUSampler : public SamplerObj {
203203
const std::vector<int>& cum_verify_lengths, const Array<GenerationConfig>& generation_cfg,
204204
const std::vector<RandomGenerator*>& rngs,
205205
const std::vector<std::vector<SampleResult>>& draft_output_tokens,
206-
NDArray draft_probs_on_device) final {
206+
const std::vector<int64_t>& token_tree_parent_ptr, NDArray draft_probs_on_device) final {
207207
NVTXScopedRange nvtx_scope("BatchVerifyDraftTokensWithProbAfterTopP");
208208
std::vector<std::vector<SampleResult>> sample_results;
209209
// probs_on_device: (n, v)
@@ -252,21 +252,29 @@ class GPUSampler : public SamplerObj {
252252
token_tree_parent_ptr_device_.CreateView({num_sequence}, dtype_i32_);
253253
std::vector<int> token_tree_child_to_parent(/*n=*/num_nodes);
254254

255+
int* token_tree_first_child_ptr_host = static_cast<int*>(token_tree_first_child_host->data);
256+
int* token_tree_next_sibling_ptr_host = static_cast<int*>(token_tree_next_sibling_host->data);
255257
// Build the tree structure on CPU
256258
for (int i = 0; i < num_sequence; i++) {
257259
// Assuming no tree structure for now
258260
int start = cum_verify_lengths[i];
259261
int end = cum_verify_lengths[i + 1];
260262
ICHECK_GE(end - start, 2);
261-
token_tree_child_to_parent[start] = -1; // root has no parent
262263
for (int j = 0; j < end - start; j++) {
263264
int cur_node = j + start;
264-
int child_node = j + 1 >= end - start ? -1 : cur_node + 1;
265-
static_cast<int*>(token_tree_first_child_host->data)[cur_node] = child_node;
266-
if (child_node != -1) {
267-
token_tree_child_to_parent[child_node] = cur_node;
265+
int parent_node =
266+
token_tree_parent_ptr[cur_node] != -1 ? token_tree_parent_ptr[cur_node] + start : -1;
267+
token_tree_first_child_ptr_host[cur_node] = -1;
268+
if (parent_node != -1 && token_tree_first_child_ptr_host[parent_node] == -1) {
269+
token_tree_first_child_ptr_host[parent_node] = cur_node;
270+
}
271+
token_tree_child_to_parent[cur_node] = parent_node;
272+
if (cur_node + 1 < end && token_tree_parent_ptr[cur_node - start + 1] ==
273+
token_tree_parent_ptr[cur_node - start]) {
274+
token_tree_next_sibling_ptr_host[cur_node] = cur_node + 1;
275+
} else {
276+
token_tree_next_sibling_ptr_host[cur_node] = -1;
268277
}
269-
static_cast<int*>(token_tree_next_sibling_host->data)[cur_node] = -1;
270278
}
271279
static_cast<int*>(token_tree_parent_ptr_host->data)[i] = start; // point to the root
272280
}

0 commit comments

Comments
 (0)