using json = nlohmann::json;
using namespace tensorrtllm;

+ namespace {
+ constexpr const int k200OK = 200;
+ constexpr const int k400BadRequest = 400;
+ constexpr const int k409Conflict = 409;
+ constexpr const int k500InternalServerError = 500;
+
+ // https://nvidia.github.io/TensorRT-LLM/_cpp_gen/runtime.html#generationinput-h
+ // stopWordsList
+ // 'im', '_', 'end', '</s>', '<|im_end|>'
+ const std::vector<int32_t> kOpenhermesStopWords = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1};
+ const std::string kOhUserPrompt = "<|im_end|>\n<|im_start|>user\n";
+ const std::string kOhAiPrompt = "<|im_end|>\n<|im_start|>assistant\n";
+ const std::string kOhSystemPrompt = "<|im_start|>system\n";
+ const std::unordered_map<std::string, int> kOpenhermesTemplate = {{"<|im_end|>", 32000}, {"<|im_start|>", 32001}};
+
+ // '[', 'INST', ']', '[INST]', '[', '/', 'INST', ']', '[/INST]', '</s>'
+ const std::vector<int32_t> kMistral_V0_3_StopWords
+     = {29560, 17057, 29561, 3, 29560, 29516, 17057, 29561, 4, 2, 3, 4, 8, 9, 10, -1, -1, -1, -1, -1};
+
+ enum class MistralTemplate : int32_t {
+   kBos = 1,
+   kEos = 2,
+   kBeginInst = 3,
+   kEndInst = 4
+ };

- constexpr const int k200OK = 200;
- constexpr const int k400BadRequest = 400;
- constexpr const int k409Conflict = 409;
- constexpr const int k500InternalServerError = 500;
-
+ // TODO(sang): This is fragile, just a temporary solution. Maybe we can use a config file or the model architecture, etc.
+ bool IsOpenhermes(const std::string& s) {
+   if (s.find("mistral") != std::string::npos || s.find("Mistral") != std::string::npos) {
+     return false;
+   }
+   return true;
+ }
+ }
TensorrtllmEngine::~TensorrtllmEngine() {}

void RemoveId(std::vector<int>& vec, int id) {
  vec.erase(std::remove(vec.begin(), vec.end(), id), vec.end());
}

- bool HandleMatch(std::string const& rew_text, std::shared_ptr<InferenceState> infer_state) {
-   if (infer_state->IsComplete()) {
+ bool HandleMatch(std::string const& rew_text,
+                  std::shared_ptr<InferenceState> infer_state,
+                  std::function<void(Json::Value&&, Json::Value&&)> cb,
+                  bool is_openhermes) {
+   if (infer_state->IsComplete(is_openhermes)) {
    return false;
  }
  if (infer_state->stop_word_match_len == 0) {
-     if (rew_text.find('<') != std::string::npos) { // Found "<" anywhere in the text
+     if ((is_openhermes && rew_text.find('<') != std::string::npos) ||
+         (!is_openhermes && rew_text.find('[') != std::string::npos)) {
      infer_state->stop_word_match_len++; // Move to next state
-       infer_state->prev_text = rew_text;
      return true;
    }
-   }
-   else if (rew_text == infer_state->sequence[infer_state->stop_word_match_len]) {
+   } else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
    infer_state->stop_word_match_len++; // Move to next state
-     infer_state->prev_text = rew_text;
    return true;
-   }
-   else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->sequence[0]) {
+   } else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
    infer_state->stop_word_match_len = 1; // Restart from first match if sequence breaks but matches start
-     infer_state->prev_text = rew_text;
    return true;
-   }
-   else {
+   } else {
    infer_state->Reset();
    return false; // Reset to start if sequence breaks
  }
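
Note: HandleMatch now implements a small per-model state machine. Once a streamed fragment could be the start of a stop sequence ('<' for OpenHermes' "<|im_end|>", '[' for Mistral's "[INST]"/"[/INST]"), following fragments are withheld from the client while they keep matching, and the state resets as soon as the match breaks. A minimal standalone sketch of the same idea, independent of the engine's InferenceState (all names below are illustrative, not the project's API):

    #include <string>
    #include <vector>

    // Returns true while `piece` should be withheld because it may still be
    // part of the stop sequence; false once it is safe to stream it out.
    struct StopMatcher {
      std::vector<std::string> seq;  // stop sequence split into streamed pieces
      size_t match_len = 0;

      bool Hold(const std::string& piece) {
        if (match_len < seq.size() && piece == seq[match_len]) {
          ++match_len;       // still consistent with the stop sequence
          return true;
        }
        if (match_len > 0 && piece == seq[0]) {
          match_len = 1;     // match broke, but this piece restarts it
          return true;
        }
        match_len = 0;       // not a stop-word prefix: emit normally
        return false;
      }
      bool Complete() const { return match_len == seq.size(); }
    };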
@@ -67,19 +93,21 @@ GenerationInput::TensorPtr TensorrtllmEngine::GetTensorSingleStopWordList(int st
}

GenerationInput::TensorPtr TensorrtllmEngine::GetTensorChatMLStopWordList() {
-   std::vector<int32_t> stop_words_tokens
-       = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1}; // Extend with -1 for increased length
-   return gpt_session->getBufferManager().copyFrom(stop_words_tokens, ITensor::makeShape({1, 2, 5}), MemoryType::kGPU);
+   if (is_openhermes_) {
+     return gpt_session->getBufferManager().copyFrom(kOpenhermesStopWords, ITensor::makeShape({1, 2, static_cast<int>(kOpenhermesStopWords.size() / 2)}), MemoryType::kGPU);
+   } else {
+     return gpt_session->getBufferManager().copyFrom(kMistral_V0_3_StopWords, ITensor::makeShape({1, 2, static_cast<int>(kMistral_V0_3_StopWords.size() / 2)}), MemoryType::kGPU);
+   }
}

GenerationInput TensorrtllmEngine::CreateGenerationInput(std::vector<int32_t> input_ids_host) {
  int input_len = input_ids_host.size();
-   std::vector<int32_t> input_lengths_host(batchSize, input_len);
+   std::vector<int32_t> input_lengths_host(batch_size_, input_len);
  GenerationInput::TensorPtr input_lengths
-       = gpt_session->getBufferManager().copyFrom(input_lengths_host, ITensor::makeShape({batchSize}), MemoryType::kGPU);
+       = gpt_session->getBufferManager().copyFrom(input_lengths_host, ITensor::makeShape({batch_size_}), MemoryType::kGPU);
  GenerationInput::TensorPtr input_ids = gpt_session->getBufferManager().copyFrom(
-       input_ids_host, ITensor::makeShape({batchSize, input_len}), MemoryType::kGPU);
-   GenerationInput generation_input{0, 0, input_ids, input_lengths, model_config->usePackedInput()};
+       input_ids_host, ITensor::makeShape({batch_size_, input_len}), MemoryType::kGPU);
+   GenerationInput generation_input{0, 0, input_ids, input_lengths, model_config_->usePackedInput()};
  generation_input.stopWordsList = GetTensorChatMLStopWordList();

  LOG_INFO << "Create generation input successfully";
@@ -102,27 +130,34 @@ void InferenceThread(
    TensorrtllmEngine* self,
    SamplingConfig sampling_config,
    int input_len,
-     int outputLen) {
+     int outputLen, bool is_openhermes) {

  // Input preparation
  LOG_INFO << "Inference thread started";
  GenerationInput generation_input = self->CreateGenerationInput(input_ids_host);
  GenerationOutput generation_output = self->CreateGenerationOutput();

  // Define the callback to stream each generated token
-   generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output](
+   generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output, is_openhermes](
      GenerationOutput::TensorPtr const& output_ids, SizeType step, bool finished) {
-     LOG_INFO << "Generating tokenizer in thread";
+     // LOG_INFO << "Generating tokenizer in thread";
    // Assuming the shape of output_ids tensor is (1, 1, 160), where 160 is the number of tokens
    int output_length = output_ids->getShape().d[2]; // Get the length of output IDs based on the tensor shape
    // Copy output IDs from GPU to host for printing
    std::vector<int32_t> output_idsHost(output_length);
    self->gpt_session->getBufferManager().copy(*output_ids, output_idsHost.data(), MemoryType::kCPU);
    // Find the last non-zero value in the output IDs starting from the end of the input sequence
    std::vector<int> output_idsHostDecode(output_idsHost.begin() + input_len, output_idsHost.end());
+
    RemoveId(output_idsHostDecode, 0);
-     RemoveId(output_idsHostDecode, 32000);
-     RemoveId(output_idsHostDecode, 32001);
+     if (is_openhermes) {
+       for (auto const& [_, v] : kOpenhermesTemplate) {
+         RemoveId(output_idsHostDecode, v);
+       }
+     } else {
+       RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kBeginInst));
+       RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kEndInst));
+     }
    std::string text = self->cortex_tokenizer->Decode(output_idsHostDecode);

    if (infer_state->prev_pos >= 0 && infer_state->prev_pos < text.size()) {
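
Note: the callback strips the padding id 0 and then the chat template's special tokens before decoding, so the streamed text never contains "<|im_end|>"/"<|im_start|>" (OpenHermes) or the [INST]/[/INST] control ids (Mistral). A tiny illustration of that filtering with RemoveId (the content ids are hypothetical; the template ids come from the constants above):

    std::vector<int> ids = {0, 32001, 123, 456, 32000, 0};  // hypothetical output slice
    RemoveId(ids, 0);      // drop padding
    RemoveId(ids, 32000);  // <|im_end|>
    RemoveId(ids, 32001);  // <|im_start|>
    // ids == {123, 456}; only content tokens reach cortex_tokenizer->Decode()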
@@ -192,29 +227,47 @@ bool TensorrtllmEngine::CheckModelLoaded(std::function<void(Json::Value&&, Json:

void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_body, std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
  inferences::ChatCompletionRequest request = inferences::fromJson(json_body);
-   std::string formatted_input = pre_prompt;
+   std::string formatted_input = pre_prompt_;
  nlohmann::json data;
  // data["stream"] = completion.stream;
  // data["n_predict"] = completion.max_tokens;
  data["presence_penalty"] = request.presence_penalty;
  Json::Value const& messages = request.messages;

+   // Tokens for Mistral v0.3
+   // TODO(sang): too much hard-coding here, needs a refactor soon
+   std::vector<int32_t> tokens = {static_cast<int32_t>(MistralTemplate::kBos)};
+
  // Format the input from user
+   int msg_count = 0;
  for (auto const& message : messages) {
    std::string input_role = message["role"].asString();
    std::string role;
    if (input_role == "user") {
-       role = user_prompt;
+       role = user_prompt_;
      std::string content = message["content"].asString();
      formatted_input += role + content;
+       if (!is_openhermes_) {
+         auto new_tokens = cortex_tokenizer->Encode(content);
+         new_tokens.insert(new_tokens.begin(), static_cast<int32_t>(MistralTemplate::kBeginInst));
+         new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEndInst));
+         tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+       }
    }
    else if (input_role == "assistant") {
-       role = ai_prompt;
+       role = ai_prompt_;
      std::string content = message["content"].asString();
      formatted_input += role + content;
+       if (!is_openhermes_) {
+         auto new_tokens = cortex_tokenizer->Encode(content);
+         if (msg_count == messages.size() - 1) {
+           new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEos));
+         }
+         tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+       }
    }
    else if (input_role == "system") {
-       role = system_prompt;
+       role = system_prompt_;
      std::string content = message["content"].asString();
      formatted_input = role + content + formatted_input;
    }
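
Note: for Mistral v0.3 the request is tokenized directly instead of going through the string template. The loop builds [BOS] [INST] <user tokens> [/INST] <assistant tokens> ..., and EOS is appended only when an assistant message is the final one. Roughly, a user turn "Hello" followed by a final assistant turn "Hi" becomes (content ids elided; only the control ids are real):

    std::vector<int32_t> tokens = {
        static_cast<int32_t>(MistralTemplate::kBos),        // 1
        static_cast<int32_t>(MistralTemplate::kBeginInst),  // 3
        /* token ids of "Hello" */
        static_cast<int32_t>(MistralTemplate::kEndInst),    // 4
        /* token ids of "Hi" */
        static_cast<int32_t>(MistralTemplate::kEos),        // 2
    };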
@@ -223,13 +276,21 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
      std::string content = message["content"].asString();
      formatted_input += role + content;
    }
+     msg_count++;
  }
-   formatted_input += ai_prompt;
+   formatted_input += ai_prompt_;
+   // LOG_INFO << formatted_input;
  // Format the input from user

  std::shared_ptr<InferenceState> infer_state = std::make_shared<InferenceState>();

-   std::vector<int32_t> input_ids_host = cortex_tokenizer->Encode(formatted_input);
+   std::vector<int32_t> input_ids_host;
+   if (is_openhermes_) {
+     input_ids_host = cortex_tokenizer->Encode(formatted_input);
+   } else {
+     input_ids_host = tokens;
+   }
+
  int const input_len = input_ids_host.size();
  int const outputLen = request.max_tokens - input_len;
@@ -243,23 +304,25 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
  sampling_config.repetitionPenalty = std::vector{request.frequency_penalty};
  // Input preparation

-   std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen);
+   std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen, is_openhermes_);
  inference_thread.detach(); // Detach the thread to allow it to run independently

-   q_->runTaskInQueue([cb = std::move(callback), infer_state]() {
+   q_->runTaskInQueue([this, cb = std::move(callback), infer_state]() {
+     // std::string res_str;
    LOG_INFO << "Preparing to run inference task queue...";
    while (true) { // Continuously check if the queue is not empty
      std::unique_lock<std::mutex> lock(infer_state->queue_mutex); // Lock the queue for exclusive access
      if (!infer_state->texts_to_stream.empty()) {
        std::string rew_text = infer_state->texts_to_stream.front();
+         // res_str += rew_text;
        infer_state->texts_to_stream.pop();
-         if (HandleMatch(rew_text, infer_state) && rew_text != "[DONE]") {
+         if (HandleMatch(rew_text, infer_state, cb, is_openhermes_) && rew_text != "[DONE]") {
          continue;
        };

        if (rew_text == "[DONE]") {
          const std::string str
-               = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", "", "stop")
+               = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, "", "stop")
              + "\n\n" + "data: [DONE]" + "\n\n";

          infer_state->is_finished = true;
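
Note: each queued fragment is emitted as a server-sent-events chunk and, with this change, tagged with the real model_id_ instead of the "_" placeholder. Based only on the construction shown above, the stream a client receives has this shape (the JSON produced by CreateReturnJson is shown schematically, not verbatim):

    data: <CreateReturnJson(random_id, model_id_, rew_text)>\n\n    // one chunk per streamed fragment
    ...
    data: <CreateReturnJson(random_id, model_id_, "", "stop")>\n\n  // final chunk, finish reason "stop"
    data: [DONE]\n\n                                                // end-of-stream marker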
@@ -275,10 +338,10 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
          break;
        }
        const std::string text_to_stream
-             = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", rew_text) + "\n\n";
+             = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, rew_text) + "\n\n";

        lock.unlock(); // Unlock as soon as possible
-         infer_state->prev_text = rew_text;
+         // std::cout << rew_text;

        Json::Value resp_data;
        resp_data["data"] = text_to_stream;
@@ -293,6 +356,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
        lock.unlock();
      }
    }
+     // LOG_INFO << res_str;
  });

  LOG_INFO << "Inference completed";
@@ -302,16 +366,20 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
  model::LoadModelRequest request = model::fromJson(json_body);
  std::filesystem::path model_dir = request.model_path;
+   is_openhermes_ = IsOpenhermes(request.model_path);

  int ctx_len = request.ctx_len;
-   this->user_prompt = request.user_prompt;
-   this->ai_prompt = request.ai_prompt;
-   this->system_prompt = request.system_prompt;
-   this->model_id_ = GetModelId(*json_body);
+   // We only support 2 models for now, it is ugly but it works :(
+   if (is_openhermes_) {
+     user_prompt_ = request.user_prompt.empty() ? kOhUserPrompt : request.user_prompt;
+     ai_prompt_ = request.ai_prompt.empty() ? kOhAiPrompt : request.ai_prompt;
+     system_prompt_ = request.system_prompt.empty() ? kOhSystemPrompt : request.system_prompt;
+   }
+   model_id_ = GetModelId(*json_body);

-   logger = std::make_shared<TllmLogger>();
-   logger->setLevel(nvinfer1::ILogger::Severity::kINFO);
-   initTrtLlmPlugins(logger.get());
+   logger_ = std::make_shared<TllmLogger>();
+   logger_->setLevel(nvinfer1::ILogger::Severity::kINFO);
+   initTrtLlmPlugins(logger_.get());

  std::filesystem::path tokenizer_model_name = model_dir / "tokenizer.model";
  cortex_tokenizer = std::make_unique<Tokenizer>(tokenizer_model_name.string());
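
Note: with the OpenHermes defaults above, HandleChatCompletion assembles a ChatML prompt by prepending kOhSystemPrompt + system content, appending kOhUserPrompt/kOhAiPrompt + content per turn, and finally appending ai_prompt_ so generation starts at the assistant turn. Assuming an empty pre_prompt_, a single system + user exchange yields a formatted_input of roughly (whitespace shown literally; an illustration of the constants, not captured engine output):

    <|im_start|>system
    You are a helpful assistant.<|im_end|>
    <|im_start|>user
    Hello!<|im_end|>
    <|im_start|>assistant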
@@ -320,20 +388,20 @@ void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::f
  std::filesystem::path json_file_name = model_dir / "config.json";
  auto json = GptJsonConfig::parse(json_file_name);
  auto config = json.getModelConfig();
-   model_config = std::make_unique<GptModelConfig>(config);
+   model_config_ = std::make_unique<GptModelConfig>(config);
  auto world_config = WorldConfig::mpi(1, json.getTensorParallelism(), json.getPipelineParallelism());
  LOG_INFO << "Loaded config from " << json_file_name.string();
  // auto dtype = model_config->getDataType();

  // Currently doing fixed session config
-   session_config.maxBatchSize = batchSize;
-   session_config.maxBeamWidth = 1; // Fixed for simplicity
-   session_config.maxSequenceLength = ctx_len;
-   session_config.cudaGraphMode = true; // Fixed for simplicity
+   session_config_.maxBatchSize = batch_size_;
+   session_config_.maxBeamWidth = 1; // Fixed for simplicity
+   session_config_.maxSequenceLength = ctx_len;
+   session_config_.cudaGraphMode = true; // Fixed for simplicity

  // Init gpt_session
  auto model_path = model_dir / json.engineFilename(world_config, model_id_);
-   gpt_session = std::make_unique<GptSession>(session_config, *model_config, world_config, model_path.string(), logger);
+   gpt_session = std::make_unique<GptSession>(session_config_, *model_config_, world_config, model_path.string(), logger_);

  model_loaded_ = true;
  if (q_ == nullptr) {
@@ -365,8 +433,8 @@ void TensorrtllmEngine::UnloadModel(std::shared_ptr<Json::Value> json_body, std:
  gpt_session.reset();
  cortex_tokenizer.reset();
  q_.reset();
-   model_config.reset();
-   logger.reset();
+   model_config_.reset();
+   logger_.reset();
  model_loaded_ = false;

  Json::Value json_resp;