Perf: load weights, create KV cache, initialize tokenizer in parallel #3215

Open · wants to merge 1 commit into main

80 changes: 69 additions & 11 deletions cpp/serve/engine.cc
@@ -19,8 +19,10 @@
#include <functional>
#include <numeric>
#include <optional>
#include <thread>
#include <tuple>
#include <unordered_set>
#include <vector>

#include "../support/json_parser.h"
#include "../support/result.h"
@@ -450,24 +452,80 @@ class EngineImpl : public Engine {
"enabled and not implemented with hybrid prefill yet.";
}
}

// - Concurrently load model weights, create KV cache, and initialize tokenizer.
std::vector<std::thread> workers;
std::exception_ptr load_params_exptr = nullptr;
std::exception_ptr create_kv_cache_exptr = nullptr;
std::exception_ptr tokenizer_exptr = nullptr;

// Create KV Cache threads
for (const Model& model : n->models_) {
workers.emplace_back([&, model, engine_config]() {
try {
NVTXScopedRange nvtx_scope("CreateKVCache");
model->SetMaxNumSequence(engine_config->max_num_sequence);
model->SetPrefillChunkSize(engine_config->prefill_chunk_size);
model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence,
engine_config->max_total_sequence_length,
engine_config->prefill_chunk_size, engine_config->max_history_size);
} catch (...) {
create_kv_cache_exptr = std::current_exception();
}
});
}

// Load Params threads
for (const Model& model : n->models_) {
workers.emplace_back([&, model]() {
try {
NVTXScopedRange nvtx_scope("LoadParams");
model->LoadParams();
} catch (...) {
load_params_exptr = std::current_exception();
}
});
}

// Tokenizer thread
workers.emplace_back([&]() {
try {
NVTXScopedRange nvtx_scope("Load Tokenizer");
n->tokenizer_ =
Tokenizer::FromPath(engine_config->model, GetTokenizerInfo(model_configs[0]));
} catch (...) {
tokenizer_exptr = std::current_exception();
}
});

// Wait for all tasks to complete
for (auto& worker : workers) {
worker.join();
}

// Check for exceptions
if (create_kv_cache_exptr) {
std::rethrow_exception(create_kv_cache_exptr);
}
if (load_params_exptr) {
std::rethrow_exception(load_params_exptr);
}
if (tokenizer_exptr) {
std::rethrow_exception(tokenizer_exptr);
}

// - Initialize model workspaces (needs KV cache to be ready for potential allocations)
n->model_workspaces_.clear();
for (const Model& model : n->models_) {
n->model_workspaces_.push_back(
ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()});
}

// - Initialize grammar related components (needs tokenizer)
n->token_table_ = n->tokenizer_->PostProcessedTokenTable();
n->cached_grammar_compiler_ = xgrammar::CachedGrammarCompiler(n->token_table_);

// - Create the logit processor, sampler, and DraftTokenWorkspaceManager
int max_num_tokens = engine_config->max_num_sequence;
DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr};
if (engine_config->speculative_mode != SpeculativeMode::kDisable) {
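The core pattern in this diff — fan independent initialization tasks out onto `std::thread` workers, capture any exception into a `std::exception_ptr`, join, then rethrow on the calling thread — can be shown standalone. Below is a minimal sketch of that pattern; the task bodies, variable names, and error message are placeholders for illustration, not the engine's real calls:

```cpp
#include <exception>
#include <iostream>
#include <stdexcept>
#include <thread>
#include <vector>

int main() {
  // Each worker stores its exception instead of letting it escape:
  // an exception leaving a std::thread body calls std::terminate.
  std::exception_ptr load_err = nullptr;
  std::exception_ptr cache_err = nullptr;

  std::vector<std::thread> workers;
  workers.emplace_back([&] {
    try {
      throw std::runtime_error("weight file missing");  // simulated failure
    } catch (...) {
      load_err = std::current_exception();
    }
  });
  workers.emplace_back([&] {
    try {
      // ... allocate the cache here ...
    } catch (...) {
      cache_err = std::current_exception();
    }
  });

  // Wait for all tasks before touching either result.
  for (auto& w : workers) w.join();

  // Rethrow on the calling thread so normal error handling applies.
  try {
    if (load_err) std::rethrow_exception(load_err);
    if (cache_err) std::rethrow_exception(cache_err);
  } catch (const std::exception& e) {
    std::cout << "init failed: " << e.what() << "\n";
  }
  return 0;
}
```

A note on the design choice: `std::async` would achieve the same effect with less plumbing, since `std::future::get()` automatically rethrows an exception stored by the task; the explicit `std::exception_ptr` approach used here keeps plain `std::thread` workers and makes the join-then-rethrow ordering visible at the call site.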