@@ -41,19 +41,23 @@ def __init__(  # pylint: disable=too-many-arguments
         self,
         host: str,
         port: int,
+        backend: str,
         timeout: Optional[float] = None,
         include_server_metrics: bool = False,
+        no_debug_config: bool = False,
     ) -> None:
         super().__init__(include_server_metrics=include_server_metrics)

         import aiohttp  # pylint: disable=import-outside-toplevel,import-error

+        self.backend = backend
         self.timeout = timeout
         self.client: aiohttp.ClientSession = None
         self.url = f"http://{host}:{port}/v1/chat/completions"
         self.headers = {"Content-Type": "application/json"}
         if os.getenv("MLC_LLM_API_KEY"):
             self.headers["Authorization"] = f"Bearer {os.getenv('MLC_LLM_API_KEY')}"
+        self.no_debug_config = no_debug_config

     async def __aenter__(self) -> Self:
         import aiohttp  # pylint: disable=import-outside-toplevel,import-error
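For orientation, here is a minimal sketch of driving the updated constructor. It assumes the class is importable as `mlc_llm.bench.api_endpoint.OpenAIChatEndPoint` (the bench module layout in MLC LLM); per this hunk, `backend` is now a required positional argument ahead of `timeout`.

```python
# Sketch, not part of the diff: constructing the chat endpoint with the
# new arguments. The import path is an assumption about the bench layout.
import asyncio

from mlc_llm.bench.api_endpoint import OpenAIChatEndPoint


async def main() -> None:
    # `backend` selects the payload-translation branch; `no_debug_config=True`
    # keeps the MLC-specific `debug_config` field out of the request body.
    async with OpenAIChatEndPoint(
        host="127.0.0.1",
        port=8000,
        backend="vllm",
        timeout=60.0,
        no_debug_config=True,
    ) as endpoint:
        ...  # send request records through the endpoint here


asyncio.run(main())
```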
@@ -80,13 +84,28 @@ async def __call__(  # pylint: disable=too-many-branches,too-many-statements,too
             and request_record.chat_cmpl.debug_config.ignore_eos
         ):
             payload["ignore_eos"] = True
+            if not self.no_debug_config:
+                payload["debug_config"] = {"ignore_eos": True}

-        print(payload)
-
-        if "response_format" in payload and "json_schema" in payload["response_format"]:
-            payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
-            payload["response_format"].pop("json_schema")
-
+        if self.backend == "vllm":
+            if payload["debug_config"] and "ignore_eos" in payload["debug_config"]:
+                payload["ignore_eos"] = payload["debug_config"]["ignore_eos"]
+            payload.pop("debug_config")
+            if "response_format" in payload:
+                if "json_schema" in payload["response_format"]:
+                    payload["guided_json"] = json.loads(payload["response_format"]["json_schema"])
+                    payload["guided_decoding_backend"] = "outlines"
+                payload.pop("response_format")
+        elif self.backend == "llama.cpp":
+            if "response_format" in payload and "json_schema" in payload["response_format"]:
+                payload["response_format"]["schema"] = json.loads(
+                    payload["response_format"]["json_schema"]
+                )
+                payload["response_format"].pop("json_schema")
+        else:
+            if "response_format" in payload and "json_schema" in payload["response_format"]:
+                payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
+                payload["response_format"].pop("json_schema")

         generated_text = ""
         first_chunk_output_str = ""
         time_to_first_token_s = None
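The branching above is easier to exercise in isolation. Below is a self-contained sketch of the same per-backend payload rewriting; the helper name `adapt_payload` is illustrative and not part of the diff, and it uses `dict.get`/`pop(..., None)` for the `debug_config` key where the endpoint code indexes it directly.

```python
# Standalone sketch of the per-backend payload translation, runnable
# without a server. `adapt_payload` is a hypothetical helper name.
import json
from typing import Any, Dict


def adapt_payload(payload: Dict[str, Any], backend: str) -> Dict[str, Any]:
    if backend == "vllm":
        # vLLM takes a top-level `ignore_eos` and, for structured output,
        # `guided_json` + `guided_decoding_backend` instead of `response_format`.
        if payload.get("debug_config") and "ignore_eos" in payload["debug_config"]:
            payload["ignore_eos"] = payload["debug_config"]["ignore_eos"]
        payload.pop("debug_config", None)
        if "response_format" in payload:
            if "json_schema" in payload["response_format"]:
                payload["guided_json"] = json.loads(payload["response_format"]["json_schema"])
                payload["guided_decoding_backend"] = "outlines"
            payload.pop("response_format")
    elif backend == "llama.cpp":
        # llama.cpp expects the schema parsed into a JSON object under
        # `response_format["schema"]`.
        if "response_format" in payload and "json_schema" in payload["response_format"]:
            payload["response_format"]["schema"] = json.loads(
                payload["response_format"]["json_schema"]
            )
            payload["response_format"].pop("json_schema")
    else:
        # Default (MLC-style) servers keep the schema as a string under `schema`.
        if "response_format" in payload and "json_schema" in payload["response_format"]:
            payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
            payload["response_format"].pop("json_schema")
    return payload


schema = json.dumps({"type": "object", "properties": {"x": {"type": "integer"}}})
print(adapt_payload({"response_format": {"type": "json_object", "json_schema": schema}}, "vllm"))
print(adapt_payload({"response_format": {"type": "json_object", "json_schema": schema}}, "llama.cpp"))
```

The design keeps the OpenAI-style request as the canonical form and translates it at the edge, so each new backend only adds one branch rather than a new request builder.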
@@ -447,19 +466,33 @@ async def __call__(  # pylint: disable=too-many-branches,too-many-locals,too-man
     "sglang",
     "tensorrt-llm",
     "vllm",
+    "vllm-chat",
+    "llama.cpp-chat",
 ]


 def create_api_endpoint(args: argparse.Namespace) -> APIEndPoint:
     """Create an API endpoint instance with regard to the specified endpoint kind."""
     if args.api_endpoint in ["openai", "mlc", "sglang"]:
         return OpenAIEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
-    if args.api_endpoint == "vllm":
+    if args.api_endpoint in ["vllm", "llama.cpp"]:
         return OpenAIEndPoint(
             args.host, args.port, args.timeout, include_server_metrics=False, no_debug_config=True
         )
     if args.api_endpoint == "openai-chat":
-        return OpenAIChatEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
+        return OpenAIChatEndPoint(
+            args.host, args.port, args.api_endpoint, args.timeout, args.include_server_metrics
+        )
+    if args.api_endpoint in ["vllm-chat", "llama.cpp-chat"]:
+        return OpenAIChatEndPoint(
+            args.host,
+            args.port,
+            args.api_endpoint[:-5],
+            args.timeout,
+            include_server_metrics=False,
+            no_debug_config=True,
+        )
+
     if args.api_endpoint == "tensorrt-llm":
         return TensorRTLLMEndPoint(args.host, args.port, args.timeout)
     raise ValueError(f'Unrecognized endpoint "{args.api_endpoint}"')
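A hedged usage sketch of the dispatch above, again assuming the `mlc_llm.bench.api_endpoint` import path. The `[:-5]` slice strips the trailing `-chat` suffix, so `"llama.cpp-chat"` reaches the endpoint as backend `"llama.cpp"`.

```python
# Sketch: resolving an endpoint from CLI-style arguments. The import path
# and Namespace fields mirror the diff's usage; adjust to your checkout.
import argparse

from mlc_llm.bench.api_endpoint import create_api_endpoint

args = argparse.Namespace(
    api_endpoint="llama.cpp-chat",  # "llama.cpp-chat"[:-5] -> "llama.cpp"
    host="127.0.0.1",
    port=8080,
    timeout=None,
    include_server_metrics=False,  # ignored for *-chat kinds, which force False
)
endpoint = create_api_endpoint(args)
```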