
Commit d0c5ac5

Merge pull request #52 from janhq/fix/mistral-template
fix: support Mistral v0.3
2 parents c7703f1 + 9303456 commit d0c5ac5

3 files changed (+157 lines, -77 lines)


cpp/tensorrt_llm/cortex.tensorrt-llm/src/models/load_model_request.h

Lines changed: 6 additions & 6 deletions
@@ -8,9 +8,9 @@ struct LoadModelRequest {
     int ctx_len = 2048;
     int n_parallel = 1;
     std::string model_path;
-    std::string user_prompt = "<|im_end|>\n<|im_start|>user\n";
-    std::string ai_prompt = "<|im_end|>\n<|im_start|>user\n";
-    std::string system_prompt = "<|im_end|>\n<|im_start|>user\n";
+    std::string user_prompt = "";
+    std::string ai_prompt = "";
+    std::string system_prompt = "";
 };
 
 inline LoadModelRequest fromJson(std::shared_ptr<Json::Value> json_body) {
@@ -19,9 +19,9 @@ inline LoadModelRequest fromJson(std::shared_ptr<Json::Value> json_body) {
         request.ctx_len = json_body->get("ctx_len", 2048).asInt();
         request.n_parallel = json_body->get("n_parallel", 1).asInt();
         request.model_path = json_body->get("model_path", "").asString();
-        request.user_prompt = json_body->get("user_prompt", "<|im_end|>\n<|im_start|>user\n").asString();
-        request.ai_prompt = json_body->get("ai_prompt", "<|im_end|>\n<|im_start|>assistant\n").asString();
-        request.system_prompt = json_body->get("system_prompt", "<|im_start|>system\n").asString();
+        request.user_prompt = json_body->get("user_prompt", "").asString();
+        request.ai_prompt = json_body->get("ai_prompt", "").asString();
+        request.system_prompt = json_body->get("system_prompt", "").asString();
     }
     return request;
 }
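
Note: a hedged usage sketch of the new defaults (the model path and includes are illustrative, not from this commit). Callers that omit the prompt fields now get empty strings, and the engine, rather than this struct, picks the chat template at load time (see the LoadModel hunk further down).

#include <json/json.h>  // jsoncpp, assumed to be the Json::Value used here
#include <memory>
// #include "load_model_request.h"  // hypothetical include for model::fromJson

model::LoadModelRequest MakeLoadRequest() {
    auto body = std::make_shared<Json::Value>();
    (*body)["ctx_len"] = 2048;
    (*body)["n_parallel"] = 1;
    (*body)["model_path"] = "/models/Mistral-7B-Instruct-v0.3";  // hypothetical path
    // user_prompt / ai_prompt / system_prompt omitted -> default to ""
    return model::fromJson(body);
}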

cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc

Lines changed: 124 additions & 56 deletions
@@ -21,40 +21,66 @@
 using json = nlohmann::json;
 using namespace tensorrtllm;
 
+namespace {
+constexpr const int k200OK = 200;
+constexpr const int k400BadRequest = 400;
+constexpr const int k409Conflict = 409;
+constexpr const int k500InternalServerError = 500;
+
+// https://nvidia.github.io/TensorRT-LLM/_cpp_gen/runtime.html#generationinput-h
+// stopWordsList
+// 'im', '_' , 'end', '</s>', '<|im_end|>'
+const std::vector<int32_t> kOpenhermesStopWords = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1};
+const std::string kOhUserPrompt = "<|im_end|>\n<|im_start|>user\n";
+const std::string kOhAiPrompt = "<|im_end|>\n<|im_start|>assistant\n";
+const std::string kOhSystemPrompt = "<|im_start|>system\n";
+const std::unordered_map<std::string, int> kOpenhermesTemplate = {{"<|im_end|>", 32000} , {"<|im_start|>", 32001}};
+
+// '[', 'INST', ']', '[INST]', '[', '/' , 'INST', ']', '[/INST]', '</s>'
+const std::vector<int32_t> kMistral_V0_3_StopWords
+    = {29560, 17057, 29561, 3, 29560, 29516, 17057, 29561, 4, 2, 3, 4, 8, 9, 10, -1, -1, -1, -1, -1};
+
+enum class MistralTemplate: int32_t {
+    kBos = 1,
+    kEos = 2,
+    kBeginInst = 3,
+    kEndInst = 4
+};
 
-constexpr const int k200OK = 200;
-constexpr const int k400BadRequest = 400;
-constexpr const int k409Conflict = 409;
-constexpr const int k500InternalServerError = 500;
-
+// TODO(sang) This is fragile, just a temporary solution. Maybe can use a config file or model architect, etc...
+bool IsOpenhermes(const std::string& s) {
+    if (s.find("mistral") != std::string::npos || s.find("Mistral") != std::string::npos) {
+        return false;
+    }
+    return true;
+}
+}
 TensorrtllmEngine::~TensorrtllmEngine() {}
 
 void RemoveId(std::vector<int>& vec, int id) {
     vec.erase(std::remove(vec.begin(), vec.end(), id), vec.end());
 }
 
-bool HandleMatch(std::string const& rew_text, std::shared_ptr<InferenceState> infer_state) {
-    if (infer_state->IsComplete()) {
+bool HandleMatch(std::string const& rew_text,
+                 std::shared_ptr<InferenceState> infer_state,
+                 std::function<void(Json::Value&&, Json::Value&&)> cb,
+                 bool is_openhermes) {
+    if (infer_state->IsComplete(is_openhermes)) {
         return false;
     }
     if (infer_state->stop_word_match_len == 0) {
-        if (rew_text.find('<') != std::string::npos) { // Found "<" anywhere in the text
+        if ((is_openhermes && rew_text.find('<') != std::string::npos) ||
+            (!is_openhermes && rew_text.find('[') != std::string::npos)) {
             infer_state->stop_word_match_len++; // Move to next state
-            infer_state->prev_text = rew_text;
             return true;
         }
-    }
-    else if (rew_text == infer_state->sequence[infer_state->stop_word_match_len]) {
+    } else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
         infer_state->stop_word_match_len++; // Move to next state
-        infer_state->prev_text = rew_text;
         return true;
-    }
-    else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->sequence[0]) {
+    } else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
         infer_state->stop_word_match_len = 1; // Restart from first match if sequence breaks but matches start
-        infer_state->prev_text = rew_text;
         return true;
-    }
-    else {
+    } else {
         infer_state->Reset();
         return false; // Reset to start if sequence breaks
     }
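
Note: the stop-word constants above appear to follow TensorRT-LLM's flattened stopWordsList layout (per the runtime docs linked in the diff): the first half of the vector is the concatenated token ids of every stop sequence, the second half holds each sequence's exclusive end offset, and both halves are padded with -1 so the buffer can be reshaped to {1, 2, N/2}. A minimal sketch of that encoding (the helper name is hypothetical):

#include <cstdint>
#include <vector>

// Flatten stop sequences into [all token ids | end offsets], padded with -1.
std::vector<int32_t> FlattenStopWords(const std::vector<std::vector<int32_t>>& sequences) {
    std::vector<int32_t> tokens;
    std::vector<int32_t> offsets;
    for (const auto& seq : sequences) {
        tokens.insert(tokens.end(), seq.begin(), seq.end());
        offsets.push_back(static_cast<int32_t>(tokens.size()));  // exclusive end offset
    }
    while (offsets.size() < tokens.size()) {
        offsets.push_back(-1);  // pad so both halves have equal length
    }
    tokens.insert(tokens.end(), offsets.begin(), offsets.end());
    return tokens;
}

// FlattenStopWords({{321, 28730, 416}, {2}, {32000}}) reproduces kOpenhermesStopWords above.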
@@ -67,19 +93,21 @@ GenerationInput::TensorPtr TensorrtllmEngine::GetTensorSingleStopWordList(int st
 }
 
 GenerationInput::TensorPtr TensorrtllmEngine::GetTensorChatMLStopWordList() {
-    std::vector<int32_t> stop_words_tokens
-        = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1}; // Extend with -1 for increased length
-    return gpt_session->getBufferManager().copyFrom(stop_words_tokens, ITensor::makeShape({1, 2, 5}), MemoryType::kGPU);
+    if(is_openhermes_) {
+        return gpt_session->getBufferManager().copyFrom(kOpenhermesStopWords, ITensor::makeShape({1, 2, static_cast<int>(kOpenhermesStopWords.size()/2)}), MemoryType::kGPU);
+    } else {
+        return gpt_session->getBufferManager().copyFrom(kMistral_V0_3_StopWords, ITensor::makeShape({1, 2, static_cast<int>(kMistral_V0_3_StopWords.size()/2)}), MemoryType::kGPU);
+    }
 }
 
 GenerationInput TensorrtllmEngine::CreateGenerationInput(std::vector<int32_t> input_ids_host) {
     int input_len = input_ids_host.size();
-    std::vector<int32_t> input_lengths_host(batchSize, input_len);
+    std::vector<int32_t> input_lengths_host(batch_size_, input_len);
     GenerationInput::TensorPtr input_lengths
-        = gpt_session->getBufferManager().copyFrom(input_lengths_host, ITensor::makeShape({batchSize}), MemoryType::kGPU);
+        = gpt_session->getBufferManager().copyFrom(input_lengths_host, ITensor::makeShape({batch_size_}), MemoryType::kGPU);
     GenerationInput::TensorPtr input_ids = gpt_session->getBufferManager().copyFrom(
-        input_ids_host, ITensor::makeShape({batchSize, input_len}), MemoryType::kGPU);
-    GenerationInput generation_input{0, 0, input_ids, input_lengths, model_config->usePackedInput()};
+        input_ids_host, ITensor::makeShape({batch_size_, input_len}), MemoryType::kGPU);
+    GenerationInput generation_input{0, 0, input_ids, input_lengths, model_config_->usePackedInput()};
     generation_input.stopWordsList = GetTensorChatMLStopWordList();
 
     LOG_INFO << "Create generation input successfully";
@@ -102,27 +130,34 @@ void InferenceThread(
     TensorrtllmEngine* self,
     SamplingConfig sampling_config,
     int input_len,
-    int outputLen) {
+    int outputLen, bool is_openhermes) {
 
     // Input preparation
     LOG_INFO << "Inference thread started";
     GenerationInput generation_input = self->CreateGenerationInput(input_ids_host);
     GenerationOutput generation_output = self->CreateGenerationOutput();
 
     // Define the callback to stream each generated token
-    generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output](
+    generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output, is_openhermes](
         GenerationOutput::TensorPtr const& output_ids, SizeType step, bool finished) {
-        LOG_INFO << "Generating tokenizer in thread";
+        // LOG_INFO << "Generating tokenizer in thread";
         // Assuming the shape of output_ids tensor is (1, 1, 160), where 160 is the number of tokens
         int output_length = output_ids->getShape().d[2]; // Get the length of output IDs based on the tensor shape
         // Copy output IDs from GPU to host for printing
         std::vector<int32_t> output_idsHost(output_length);
         self->gpt_session->getBufferManager().copy(*output_ids, output_idsHost.data(), MemoryType::kCPU);
         // Find the last non-zero value in the output IDs starting from the end of the input sequence
         std::vector<int> output_idsHostDecode(output_idsHost.begin() + input_len, output_idsHost.end());
+
         RemoveId(output_idsHostDecode, 0);
-        RemoveId(output_idsHostDecode, 32000);
-        RemoveId(output_idsHostDecode, 32001);
+        if(is_openhermes) {
+            for(auto const& [_, v]: kOpenhermesTemplate) {
+                RemoveId(output_idsHostDecode, v);
+            }
+        } else {
+            RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kBeginInst));
+            RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kEndInst));
+        }
         std::string text = self->cortex_tokenizer->Decode(output_idsHostDecode);
 
         if (infer_state->prev_pos >= 0 && infer_state->prev_pos < text.size()) {
@@ -192,29 +227,47 @@ bool TensorrtllmEngine::CheckModelLoaded(std::function<void(Json::Value&&, Json:
 
 void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_body, std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
     inferences::ChatCompletionRequest request = inferences::fromJson(json_body);
-    std::string formatted_input = pre_prompt;
+    std::string formatted_input = pre_prompt_;
     nlohmann::json data;
     // data["stream"] = completion.stream;
     // data["n_predict"] = completion.max_tokens;
     data["presence_penalty"] = request.presence_penalty;
     Json::Value const& messages = request.messages;
 
+    // tokens for Mistral v0.3
+    // TODO(sang): too much hard code here, need to refactor it soon
+    std::vector<int32_t> tokens = {static_cast<int32_t>(MistralTemplate::kBos)};
+
     // Format the input from user
+    int msg_count = 0;
     for (auto const& message : messages) {
         std::string input_role = message["role"].asString();
         std::string role;
         if (input_role == "user") {
-            role = user_prompt;
+            role = user_prompt_;
             std::string content = message["content"].asString();
             formatted_input += role + content;
+            if(!is_openhermes_) {
+                auto new_tokens = cortex_tokenizer->Encode(content);
+                new_tokens.insert(new_tokens.begin(), static_cast<int32_t>(MistralTemplate::kBeginInst));
+                new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEndInst));
+                tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+            }
         }
         else if (input_role == "assistant") {
-            role = ai_prompt;
+            role = ai_prompt_;
             std::string content = message["content"].asString();
             formatted_input += role + content;
+            if(!is_openhermes_) {
+                auto new_tokens = cortex_tokenizer->Encode(content);
+                if(msg_count == messages.size() - 1) {
+                    new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEos));
+                }
+                tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+            }
         }
         else if (input_role == "system") {
-            role = system_prompt;
+            role = system_prompt_;
             std::string content = message["content"].asString();
             formatted_input = role + content + formatted_input;
         }
@@ -223,13 +276,21 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
             std::string content = message["content"].asString();
             formatted_input += role + content;
         }
+        msg_count++;
     }
-    formatted_input += ai_prompt;
+    formatted_input += ai_prompt_;
+    // LOG_INFO << formatted_input;
     // Format the input from user
 
     std::shared_ptr<InferenceState> infer_state = std::make_shared<InferenceState>();
 
-    std::vector<int32_t> input_ids_host = cortex_tokenizer->Encode(formatted_input);
+    std::vector<int32_t> input_ids_host;
+    if(is_openhermes_) {
+        input_ids_host = cortex_tokenizer->Encode(formatted_input);
+    } else {
+        input_ids_host = tokens;
+    }
+
     int const input_len = input_ids_host.size();
     int const outputLen = request.max_tokens - input_len;
 
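Note: for the Mistral branch, the loop above builds the prompt token by token instead of encoding formatted_input: BOS once, each user turn wrapped in kBeginInst/kEndInst, assistant turns appended as-is (EOS only when the assistant message is the last one), and system content left out of the token stream. A self-contained sketch of that assembly with a mock tokenizer (names here are illustrative, not from the engine):

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

enum class MistralTok : int32_t { kBos = 1, kEos = 2, kBeginInst = 3, kEndInst = 4 };

// Stand-in for cortex_tokenizer->Encode(): one fake id per character.
std::vector<int32_t> MockEncode(const std::string& text) {
    return std::vector<int32_t>(text.size(), 100);
}

// messages is a list of {role, content} pairs, oldest first.
std::vector<int32_t> BuildMistralPrompt(const std::vector<std::pair<std::string, std::string>>& messages) {
    std::vector<int32_t> tokens = {static_cast<int32_t>(MistralTok::kBos)};
    for (size_t i = 0; i < messages.size(); ++i) {
        const std::string& role = messages[i].first;
        auto ids = MockEncode(messages[i].second);
        if (role == "user") {
            ids.insert(ids.begin(), static_cast<int32_t>(MistralTok::kBeginInst));
            ids.push_back(static_cast<int32_t>(MistralTok::kEndInst));
        } else if (role == "assistant") {
            if (i == messages.size() - 1) {
                ids.push_back(static_cast<int32_t>(MistralTok::kEos));
            }
        } else {
            continue;  // system content only feeds the OpenHermes text path in the diff
        }
        tokens.insert(tokens.end(), ids.begin(), ids.end());
    }
    return tokens;
}

For a request that ends with a user turn, the result is [BOS][INST]...[/INST]...[INST]...[/INST]; no EOS is appended, leaving the sequence open for the model to generate the next assistant reply.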
@@ -243,23 +304,25 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
     sampling_config.repetitionPenalty = std::vector{request.frequency_penalty};
     // Input preparation
 
-    std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen);
+    std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen, is_openhermes_);
     inference_thread.detach(); // Detach the thread to allow it to run independently
 
-    q_->runTaskInQueue([cb = std::move(callback), infer_state]() {
+    q_->runTaskInQueue([this, cb = std::move(callback), infer_state]() {
+        // std::string res_str;
         LOG_INFO << "Preparing to run inference task queue...";
         while (true) { // Continuously check if the queue is not empty
             std::unique_lock<std::mutex> lock(infer_state->queue_mutex); // Lock the queue for exclusive access
             if (!infer_state->texts_to_stream.empty()) {
                 std::string rew_text = infer_state->texts_to_stream.front();
+                // res_str += rew_text;
                 infer_state->texts_to_stream.pop();
-                if (HandleMatch(rew_text, infer_state) && rew_text != "[DONE]") {
+                if (HandleMatch(rew_text, infer_state, cb, is_openhermes_) && rew_text != "[DONE]") {
                     continue;
                 };
 
                 if (rew_text == "[DONE]") {
                     const std::string str
-                        = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", "", "stop")
+                        = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, "", "stop")
                         + "\n\n" + "data: [DONE]" + "\n\n";
 
                     infer_state->is_finished = true;
@@ -275,10 +338,10 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
                     break;
                 }
                 const std::string text_to_stream
-                    = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", rew_text) + "\n\n";
+                    = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, rew_text) + "\n\n";
 
                 lock.unlock(); // Unlock as soon as possible
-                infer_state->prev_text = rew_text;
+                // std::cout << rew_text;
 
                 Json::Value resp_data;
                 resp_data["data"] = text_to_stream;
@@ -293,6 +356,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
                 lock.unlock();
             }
         }
+        // LOG_INFO << res_str;
     });
 
     LOG_INFO << "Inference completed";
@@ -302,16 +366,20 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
 void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
     model::LoadModelRequest request = model::fromJson(json_body);
     std::filesystem::path model_dir = request.model_path;
+    is_openhermes_ = IsOpenhermes(request.model_path);
 
     int ctx_len = request.ctx_len;
-    this->user_prompt = request.user_prompt;
-    this->ai_prompt = request.ai_prompt;
-    this->system_prompt = request.system_prompt;
-    this->model_id_ = GetModelId(*json_body);
+    // We only support 2 models for now, it is ugly but it works :(
+    if(is_openhermes_) {
+        user_prompt_ = request.user_prompt.empty() ? kOhUserPrompt : request.user_prompt;
+        ai_prompt_ = request.ai_prompt.empty() ? kOhAiPrompt : request.ai_prompt;
+        system_prompt_ = request.system_prompt.empty() ? kOhSystemPrompt : request.system_prompt;
+    }
+    model_id_ = GetModelId(*json_body);
 
-    logger = std::make_shared<TllmLogger>();
-    logger->setLevel(nvinfer1::ILogger::Severity::kINFO);
-    initTrtLlmPlugins(logger.get());
+    logger_ = std::make_shared<TllmLogger>();
+    logger_->setLevel(nvinfer1::ILogger::Severity::kINFO);
+    initTrtLlmPlugins(logger_.get());
 
     std::filesystem::path tokenizer_model_name = model_dir / "tokenizer.model";
     cortex_tokenizer = std::make_unique<Tokenizer>(tokenizer_model_name.string());
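
Note: the template switch above keys off the model path alone. A small hedged illustration of that behavior (paths are hypothetical; only the two spellings "mistral"/"Mistral" are checked):

#include <cassert>
#include <string>

// Mirrors IsOpenhermes() above: any path mentioning "mistral"/"Mistral" gets the
// Mistral v0.3 handling, everything else keeps the OpenHermes/ChatML defaults.
bool LooksLikeOpenhermes(const std::string& model_path) {
    return model_path.find("mistral") == std::string::npos &&
           model_path.find("Mistral") == std::string::npos;
}

int main() {
    assert(!LooksLikeOpenhermes("/models/Mistral-7B-Instruct-v0.3"));  // Mistral template
    assert(LooksLikeOpenhermes("/models/openhermes-2.5-7b"));          // ChatML defaults
    return 0;
}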
@@ -320,20 +388,20 @@ void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::f
     std::filesystem::path json_file_name = model_dir / "config.json";
     auto json = GptJsonConfig::parse(json_file_name);
     auto config = json.getModelConfig();
-    model_config = std::make_unique<GptModelConfig>(config);
+    model_config_ = std::make_unique<GptModelConfig>(config);
     auto world_config = WorldConfig::mpi(1, json.getTensorParallelism(), json.getPipelineParallelism());
     LOG_INFO << "Loaded config from " << json_file_name.string();
     // auto dtype = model_config->getDataType();
 
     // Currently doing fixed session config
-    session_config.maxBatchSize = batchSize;
-    session_config.maxBeamWidth = 1; // Fixed for simplicity
-    session_config.maxSequenceLength = ctx_len;
-    session_config.cudaGraphMode = true; // Fixed for simplicity
+    session_config_.maxBatchSize = batch_size_;
+    session_config_.maxBeamWidth = 1; // Fixed for simplicity
+    session_config_.maxSequenceLength = ctx_len;
+    session_config_.cudaGraphMode = true; // Fixed for simplicity
 
     // Init gpt_session
     auto model_path = model_dir / json.engineFilename(world_config, model_id_);
-    gpt_session = std::make_unique<GptSession>(session_config, *model_config, world_config, model_path.string(), logger);
+    gpt_session = std::make_unique<GptSession>(session_config_, *model_config_, world_config, model_path.string(), logger_);
 
     model_loaded_ = true;
     if (q_ == nullptr) {
@@ -365,8 +433,8 @@ void TensorrtllmEngine::UnloadModel(std::shared_ptr<Json::Value> json_body, std:
     gpt_session.reset();
     cortex_tokenizer.reset();
     q_.reset();
-    model_config.reset();
-    logger.reset();
+    model_config_.reset();
+    logger_.reset();
     model_loaded_ = false;
 
     Json::Value json_resp;
