Commit 0b2cbb2
[REFACTOR] Support latest include_usage and DebugOptions (#2417)

This PR refactors the request-end detection mechanism and attaches request metrics to the response usage field.

Response usage field:
- include_usage can be passed to the API. When include_usage is on, metrics are streamed back in usage.extra.
- Changed the debug_options parameter to go through extra_body, so requests are fully compatible with the OpenAI client.
- Support special requests in debug options; engine metrics are now streamed back via a special request.

We also changed the FFI mechanism for detecting response finish. Previously we kept track of the number of stopped streams. Now the FFI always streams back a final chunk that has no choices and contains usage, and we use the usage field to detect the final chunk. Code paths are updated accordingly.

We also made the Chat CLI a reusable helper class. The iOS app now comes with stats support.
1 parent b18284b commit 0b2cbb2
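
As a quick illustration of the new request surface, here is a minimal client-side sketch using the OpenAI Python client (the base URL, API key, and model name are placeholder assumptions; include_usage, extra_body, and the debug_config shape come from this commit's diffs below):

from openai import OpenAI

# Placeholder endpoint and model; point these at a local MLC serve instance.
client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="dummy")

stream = client.chat.completions.create(
    model="my-mlc-model",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
    # Ask the server to append a final usage-only chunk to the stream.
    stream_options={"include_usage": True},
    # debug_config is not part of the OpenAI spec, so it is passed via extra_body.
    extra_body={"debug_config": {"ignore_eos": False}},
)
for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage is not None:
        # Final chunk: no choices, usage attached; request metrics ride in usage.extra.
        print("\n", chunk.usage)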


43 files changed (+1168, -701 lines)

cpp/json_ffi/json_ffi_engine.cc

Lines changed: 51 additions & 21 deletions

@@ -57,28 +57,33 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
     return false;
   }
   ChatCompletionRequest request = request_res.Unwrap();
-  // get prompt: note, assistant was appended in the end.
-  Result<std::vector<Data>> inputs_obj =
-      CreatePrompt(this->conv_template_, request, this->model_config_, this->device_);
-  if (inputs_obj.IsErr()) {
-    err_ = inputs_obj.UnwrapErr();
-    return false;
-  }
-  Array<Data> inputs = inputs_obj.Unwrap();
-
-  // generation_cfg
+  Array<Data> inputs;
   Array<String> stop_strs;
-  stop_strs.reserve(this->conv_template_.stop_str.size());
-  for (const std::string& stop_str : this->conv_template_.stop_str) {
-    stop_strs.push_back(stop_str);
-  }
-  if (request.stop.has_value()) {
-    stop_strs.reserve(stop_strs.size() + request.stop.value().size());
-    for (const std::string& stop_str : request.stop.value()) {
+  bool is_special_request =
+      (request.debug_config.has_value() &&
+       request.debug_config.value().special_request != SpecialRequestKind::kNone);
+  // special request does not have to go through prompt construction
+  if (!is_special_request) {
+    // get prompt: note, assistant was appended in the end.
+    Result<std::vector<Data>> inputs_obj =
+        CreatePrompt(this->conv_template_, request, this->model_config_, this->device_);
+    if (inputs_obj.IsErr()) {
+      err_ = inputs_obj.UnwrapErr();
+      return false;
+    }
+    inputs = inputs_obj.Unwrap();
+
+    stop_strs.reserve(this->conv_template_.stop_str.size());
+    for (const std::string& stop_str : this->conv_template_.stop_str) {
       stop_strs.push_back(stop_str);
     }
+    if (request.stop.has_value()) {
+      stop_strs.reserve(stop_strs.size() + request.stop.value().size());
+      for (const std::string& stop_str : request.stop.value()) {
+        stop_strs.push_back(stop_str);
+      }
+    }
   }
-
   // create a generation config from request
   const auto& default_gen_cfg = default_generation_config_;
   auto gen_cfg = tvm::runtime::make_object<GenerationConfigNode>();

@@ -115,8 +120,6 @@ bool JSONFFIEngine::Abort(std::string request_id) {
 
 std::string JSONFFIEngine::GetLastError() { return err_; }
 
-std::string JSONFFIEngine::JSONMetrics() { return this->engine_->JSONMetrics(); }
-
 void JSONFFIEngine::ExitBackgroundLoop() { this->engine_->ExitBackgroundLoop(); }
 
 JSONFFIEngine::~JSONFFIEngine() { this->ExitBackgroundLoop(); }

@@ -131,7 +134,6 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
   TVM_MODULE_VTABLE_ENTRY("chat_completion", &JSONFFIEngineImpl::ChatCompletion);
   TVM_MODULE_VTABLE_ENTRY("abort", &JSONFFIEngineImpl::Abort);
   TVM_MODULE_VTABLE_ENTRY("get_last_error", &JSONFFIEngineImpl::GetLastError);
-  TVM_MODULE_VTABLE_ENTRY("json_metrics", &JSONFFIEngineImpl::JSONMetrics);
   TVM_MODULE_VTABLE_ENTRY("run_background_loop", &JSONFFIEngineImpl::RunBackgroundLoop);
   TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop",
                           &JSONFFIEngineImpl::RunBackgroundStreamBackLoop);

@@ -190,11 +192,35 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
 
   String GetResponseFromStreamOutput(Array<RequestStreamOutput> delta_outputs) {
     std::unordered_map<std::string, std::vector<ChatCompletionStreamResponseChoice>> response_map;
+    std::vector<picojson::value> request_final_usage_messages;
+    std::string model = "json_ffi";
+
     for (const auto& delta_output : delta_outputs) {
       std::string request_id = delta_output->request_id;
       if (response_map.find(request_id) == response_map.end()) {
         response_map[request_id] = std::vector<ChatCompletionStreamResponseChoice>();
       }
+
+      // build the final usage messages
+      // invariant, we can always let other messages to come first
+      // then the final usage messages, as final usage is always last
+      if (delta_output->request_final_usage_json_str.defined()) {
+        ChatCompletionStreamResponse response;
+        response.id = request_id;
+        response.model = model;
+        response.system_fingerprint = "";
+        std::string usage_json_str = delta_output->request_final_usage_json_str.value();
+        picojson::value usage_json;
+        std::string err = picojson::parse(usage_json, usage_json_str);
+        if (!err.empty()) {
+          err_ = err;
+        } else {
+          response.usage = usage_json;
+        }
+        request_final_usage_messages.push_back(picojson::value(response.AsJSON()));
+        continue;
+      }
+      ICHECK_NE(delta_output->group_finish_reason.size(), 0);
       ChatCompletionStreamResponseChoice choice;
 
       if (delta_output->group_finish_reason.size() != 1) {

@@ -232,13 +258,17 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
 
     picojson::array response_arr;
     for (const auto& [request_id, choices] : response_map) {
+      if (choices.size() == 0) continue;
       ChatCompletionStreamResponse response;
       response.id = request_id;
       response.choices = choices;
       response.model = "json_ffi"; // TODO: Return model name from engine (or from args)
       response.system_fingerprint = "";
       response_arr.push_back(picojson::value(response.AsJSON()));
     }
+    for (auto&& item : request_final_usage_messages) {
+      response_arr.emplace_back(std::move(item));
+    }
     return picojson::value(response_arr).serialize();
   }
 };
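
Since the FFI now signals end-of-request with a dedicated usage-only chunk instead of counting stopped streams, a consumer of the serialized output above can detect completion from the chunk shape alone. A minimal sketch (the transport is the JSON array that GetResponseFromStreamOutput serializes above; the usage numbers are illustrative, and engine/request metrics land under usage.extra):

import json

def is_final_chunk(chunk: dict) -> bool:
    """The final chunk of a request has no choices and carries usage."""
    return not chunk.get("choices") and "usage" in chunk

def split_stream(batch_json: str):
    """Split one stream-back batch into delta chunks and final usage chunks."""
    deltas, finals = [], []
    for chunk in json.loads(batch_json):
        (finals if is_final_chunk(chunk) else deltas).append(chunk)
    return deltas, finals

batch = json.dumps([
    {"id": "req-0", "object": "chat.completion.chunk", "model": "json_ffi",
     "choices": [{"index": 0, "delta": {"content": "Hi"}}]},
    {"id": "req-0", "object": "chat.completion.chunk", "model": "json_ffi",
     "choices": [],
     "usage": {"prompt_tokens": 5, "completion_tokens": 1, "extra": {}}},
])
deltas, finals = split_stream(batch)
assert is_final_chunk(finals[0])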

cpp/json_ffi/json_ffi_engine.h

Lines changed: 0 additions & 2 deletions

@@ -41,8 +41,6 @@ class JSONFFIEngine {
 
   std::string GetLastError();
 
-  std::string JSONMetrics();
-
   void ExitBackgroundLoop();
 
  protected:

cpp/json_ffi/openai_api_protocol.cc

Lines changed: 20 additions & 0 deletions

@@ -387,6 +387,21 @@ Result<ChatCompletionRequest> ChatCompletionRequest::FromJSON(const std::string&
     request.tools = tools;
   }
 
+  // debug_config
+  Result<std::optional<picojson::object>> debug_config_opt_res =
+      json::LookupOptionalWithResultReturn<picojson::object>(json_obj, "debug_config");
+  if (debug_config_opt_res.IsErr()) {
+    return TResult::Error(debug_config_opt_res.UnwrapErr());
+  }
+  auto debug_config_opt = debug_config_opt_res.Unwrap();
+  if (debug_config_opt.has_value()) {
+    Result<DebugConfig> debug_config_res = DebugConfig::FromJSON(debug_config_opt.value());
+    if (debug_config_res.IsErr()) {
+      return TResult::Error(debug_config_res.UnwrapErr());
+    }
+    request.debug_config = debug_config_res.Unwrap();
+  }
+
   // TODO: Other parameters
   return TResult::Ok(request);
 }

@@ -485,15 +500,20 @@ picojson::object ChatCompletionResponse::AsJSON() const {
 picojson::object ChatCompletionStreamResponse::AsJSON() const {
   picojson::object obj;
   obj["id"] = picojson::value(this->id);
+
   picojson::array choices_arr;
   for (const auto& choice : this->choices) {
     choices_arr.push_back(picojson::value(choice.AsJSON()));
   }
   obj["choices"] = picojson::value(choices_arr);
+
   obj["created"] = picojson::value((int64_t)this->created);
   obj["model"] = picojson::value(this->model);
   obj["system_fingerprint"] = picojson::value(this->system_fingerprint);
   obj["object"] = picojson::value(this->object);
+  if (usage.has_value()) {
+    obj["usage"] = usage.value();
+  }
   return obj;
 }

cpp/json_ffi/openai_api_protocol.h

Lines changed: 1 addition & 0 deletions

@@ -200,6 +200,7 @@ class ChatCompletionStreamResponse {
   std::string model;
   std::string system_fingerprint;
   std::string object = "chat.completion.chunk";
+  std::optional<picojson::value> usage;
 
   picojson::object AsJSON() const;
 };

cpp/serve/config.cc

Lines changed: 50 additions & 24 deletions

@@ -19,6 +19,42 @@ namespace mlc {
 namespace llm {
 namespace serve {
 
+/****************** DebugConfig ******************/
+
+Result<DebugConfig> DebugConfig::FromJSON(const picojson::object& config) {
+  using TResult = Result<DebugConfig>;
+  DebugConfig res;
+  res.ignore_eos = json::LookupOrDefault<bool>(config, "ignore_eos", false);
+  res.pinned_system_prompt = json::LookupOrDefault<bool>(config, "pinned_system_prompt", false);
+  std::string special_request = json::LookupOrDefault<std::string>(config, "special_request", "");
+  if (special_request.length() != 0) {
+    if (special_request == "query_engine_metrics") {
+      res.special_request = SpecialRequestKind::kQueryEngineMetrics;
+    } else {
+      return TResult::Error("Uknown special request " + special_request);
+    }
+  }
+  return TResult::Ok(res);
+}
+
+/**
+ * \return serialized json value of the config.
+ */
+picojson::object DebugConfig::AsJSON() const {
+  picojson::object config;
+  config["ignore_eos"] = picojson::value(ignore_eos);
+  config["pinned_system_prompt"] = picojson::value(pinned_system_prompt);
+  switch (special_request) {
+    case SpecialRequestKind::kQueryEngineMetrics: {
+      config["special_request"] = picojson::value("query_engine_metrics");
+      break;
+    }
+    case SpecialRequestKind::kNone:
+      break;
+  }
+  return config;
+}
+
 /****************** GenerationConfig ******************/
 
 TVM_REGISTER_OBJECT_TYPE(GenerationConfigNode);

@@ -55,12 +91,10 @@ Result<GenerationConfig> GenerationConfig::Validate(GenerationConfig cfg) {
   return TResult::Ok(cfg);
 }
 
-Result<GenerationConfig> GenerationConfig::FromJSON(String config_json_str,
+Result<GenerationConfig> GenerationConfig::FromJSON(const picojson::object& config,
                                                     const GenerationConfig& default_config) {
   using TResult = Result<GenerationConfig>;
-  picojson::object config = json::ParseToJSONObject(config_json_str);
   ObjectPtr<GenerationConfigNode> n = make_object<GenerationConfigNode>();
-
   n->n = json::LookupOrDefault<int64_t>(config, "n", default_config->n);
   n->temperature =
       json::LookupOrDefault<double>(config, "temperature", default_config->temperature);

@@ -144,18 +178,21 @@ Result<GenerationConfig> GenerationConfig::FromJSON(String config_json_str,
   // "debug_config" is for internal usage. Not the part of OpenAI API spec.
   std::optional<picojson::object> debug_config_obj =
       json::LookupOptional<picojson::object>(config, "debug_config");
+
   if (debug_config_obj.has_value()) {
-    n->debug_config.pinned_system_prompt =
-        json::LookupOrDefault<bool>(debug_config_obj.value(), "pinned_system_prompt", false);
-    n->debug_config.ignore_eos =
-        json::LookupOrDefault<bool>(debug_config_obj.value(), "ignore_eos", false);
+    Result<DebugConfig> debug_config_res = DebugConfig::FromJSON(debug_config_obj.value());
+    if (debug_config_res.IsErr()) {
+      return TResult::Error(debug_config_res.UnwrapErr());
+    }
+    n->debug_config = debug_config_res.Unwrap();
   }
   return Validate(GenerationConfig(n));
 }
 
 GenerationConfig GenerationConfig::GetDefaultFromModelConfig(
     const picojson::object& model_config_json) {
   ObjectPtr<GenerationConfigNode> n = make_object<GenerationConfigNode>();
+  n->max_tokens = -1;
   n->temperature = json::LookupOrDefault<double>(model_config_json, "temperature", n->temperature);
   n->top_p = json::LookupOrDefault<double>(model_config_json, "top_p", n->top_p);
   n->frequency_penalty =

@@ -165,7 +202,7 @@ GenerationConfig GenerationConfig::GetDefaultFromModelConfig(
   return GenerationConfig(n);
 }
 
-String GenerationConfigNode::AsJSONString() const {
+picojson::object GenerationConfigNode::AsJSON() const {
   picojson::object config;
   config["n"] = picojson::value(static_cast<int64_t>(this->n));
   config["temperature"] = picojson::value(this->temperature);

@@ -202,17 +239,8 @@ String GenerationConfigNode::AsJSONString() const {
       ? picojson::value(this->response_format.schema.value())
       : picojson::value();
   config["response_format"] = picojson::value(response_format);
-
-  // Params for internal usage. Not the part of OpenAI API spec.
-  {
-    picojson::object debug_config_obj;
-    debug_config_obj["pinned_system_prompt"] =
-        picojson::value(this->debug_config.pinned_system_prompt);
-    debug_config_obj["ignore_eos"] = picojson::value(this->debug_config.ignore_eos);
-    config["debug_config"] = picojson::value(debug_config_obj);
-  }
-
-  return picojson::value(config).serialize(true);
+  config["debug_config"] = picojson::value(debug_config.AsJSON());
+  return config;
 }
 
 /****************** EngineConfig ******************/

@@ -349,11 +377,9 @@ struct ModelConfigLimits {
 
 /*! \brief Convert the bytes to megabytes, keeping 3 decimals. */
 inline std::string BytesToMegabytesString(double bytes) {
-  std::string str;
-  str.resize(20);
-  std::sprintf(&str[0], "%.3f", bytes / 1024 / 1024);
-  str.resize(std::strlen(str.c_str()));
-  return str;
+  std::ostringstream os;
+  os << std::setprecision(3) << std::fixed << (bytes / 1024 / 1024);
+  return os.str();
 }
 
 /*!
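
For reference, the debug_config object that DebugConfig::FromJSON above accepts has this shape (keys and the single recognized special_request value are taken from the diff; every key is optional and defaults to false or kNone):

import json

debug_config = {
    "ignore_eos": False,
    "pinned_system_prompt": False,
    # "query_engine_metrics" is the only recognized value; anything else errors.
    "special_request": "query_engine_metrics",
}
print(json.dumps({"debug_config": debug_config}, indent=2))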

cpp/serve/config.h

Lines changed: 25 additions & 7 deletions

@@ -30,35 +30,53 @@ struct ResponseFormat {
   Optional<String> schema = NullOpt;
 };
 
+enum class SpecialRequestKind : int {
+  kNone = 0,
+  kQueryEngineMetrics = 1,
+};
+
 /*! \brief The debug configuration of a request. */
 class DebugConfig {
  public:
   bool ignore_eos = false;
   bool pinned_system_prompt = false;
+  SpecialRequestKind special_request = SpecialRequestKind::kNone;
+
+  /*!
+   * \brief Create debug config from JSON.
+   * \param config_json The json string for generation config
+   * \returns The converted result.
+   */
+  static Result<DebugConfig> FromJSON(const picojson::object& config_json);
+
+  /**
+   * \return serialized json value of the config.
+   */
+  picojson::object AsJSON() const;
 };
 
 /*! \brief The generation configuration of a request. */
 class GenerationConfigNode : public Object {
  public:
   int n = 1;
-  double temperature = 0.8;
-  double top_p = 0.95;
+  double temperature = 1.0;
+  double top_p = 1.0;
   double frequency_penalty = 0.0;
   double presence_penalty = 0.0;
   double repetition_penalty = 1.0;
   bool logprobs = false;
   int top_logprobs = 0;
   std::vector<std::pair<int, float>> logit_bias;
   int seed;
-
-  int max_tokens = 128;
+  // -1 means infinite
+  int max_tokens = -1;
   Array<String> stop_strs;
   std::vector<int> stop_token_ids;
 
   ResponseFormat response_format;
   DebugConfig debug_config;
 
-  String AsJSONString() const;
+  picojson::object AsJSON() const;
 
   static constexpr const char* _type_key = "mlc.serve.GenerationConfig";
   static constexpr const bool _type_has_method_sequal_reduce = false;

@@ -76,10 +94,10 @@ class GenerationConfig : public ObjectRef {
 
   /*!
   * \brief Create generation config from JSON.
-  * \param config_json_str The json string for generation config
+  * \param config_json The json string for generation config
   * \param default_config The default config
   */
-  static Result<GenerationConfig> FromJSON(String config_json_str,
+  static Result<GenerationConfig> FromJSON(const picojson::object& config_json,
                                            const GenerationConfig& default_config);
 
   /*! \brief Get the default generation config from the model config. */
