@@ -41,19 +41,23 @@ def __init__( # pylint: disable=too-many-arguments
         self,
         host: str,
         port: int,
+        backend: str,
         timeout: Optional[float] = None,
         include_server_metrics: bool = False,
+        no_debug_config: bool = False,
     ) -> None:
         super().__init__(include_server_metrics=include_server_metrics)

         import aiohttp  # pylint: disable=import-outside-toplevel,import-error

+        self.backend = backend
         self.timeout = timeout
         self.client: aiohttp.ClientSession = None
         self.url = f"http://{host}:{port}/v1/chat/completions"
         self.headers = {"Content-Type": "application/json"}
         if os.getenv("MLC_LLM_API_KEY"):
             self.headers["Authorization"] = f"Bearer {os.getenv('MLC_LLM_API_KEY')}"
+        self.no_debug_config = no_debug_config

     async def __aenter__(self) -> Self:
         import aiohttp  # pylint: disable=import-outside-toplevel,import-error
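
As a quick orientation (not part of the patch), a minimal sketch of how the extended constructor is expected to be used; the module path, host, port, and timeout values are placeholders assumed for illustration:

import asyncio

from mlc_llm.bench.api_endpoint import OpenAIChatEndPoint  # module path assumed from the file being patched

async def main() -> None:
    # "vllm" backend with debug_config suppressed, mirroring the vllm-chat wiring later in this diff.
    endpoint = OpenAIChatEndPoint(
        host="127.0.0.1",      # placeholder host
        port=8000,             # placeholder port
        backend="vllm",
        timeout=60.0,
        include_server_metrics=False,
        no_debug_config=True,
    )
    async with endpoint:       # __aenter__ opens the aiohttp.ClientSession
        ...                    # issue RequestRecord calls here

asyncio.run(main())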
@@ -67,7 +71,7 @@ async def __aexit__(self, exc_type, exc_value, tb) -> None:
     async def __call__(  # pylint: disable=too-many-branches,too-many-statements,too-many-locals
         self, request_record: RequestRecord
     ) -> RequestRecord:
-        payload = request_record.chat_cmpl.model_dump()
+        payload = request_record.chat_cmpl.model_dump(exclude_unset=True, exclude_none=True)
         if self.timeout is not None and "timeout" not in payload:
             payload["timeout"] = self.timeout
         if self.include_server_metrics:
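
To make the effect of the model_dump change concrete, a small self-contained pydantic example (ToyRequest is illustrative, not the real request model): exclude_unset drops fields the caller never set, and exclude_none drops fields whose value is None, so only explicitly provided options reach the server.

from typing import Optional
from pydantic import BaseModel

class ToyRequest(BaseModel):
    model: str
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None

req = ToyRequest(model="demo", max_tokens=None)
print(req.model_dump())
# {'model': 'demo', 'temperature': None, 'max_tokens': None}
print(req.model_dump(exclude_unset=True, exclude_none=True))
# {'model': 'demo'}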
@@ -80,7 +84,28 @@ async def __call__( # pylint: disable=too-many-branches,too-many-statements,too
             and request_record.chat_cmpl.debug_config.ignore_eos
         ):
             payload["ignore_eos"] = True
+            if not self.no_debug_config:
+                payload["debug_config"] = {"ignore_eos": True}

+        if self.backend == "vllm":
+            if payload.get("debug_config") and "ignore_eos" in payload["debug_config"]:
+                payload["ignore_eos"] = payload["debug_config"]["ignore_eos"]
+                payload.pop("debug_config")
+            if "response_format" in payload:
+                if "json_schema" in payload["response_format"]:
+                    payload["guided_json"] = json.loads(payload["response_format"]["json_schema"])
+                    payload["guided_decoding_backend"] = "outlines"
+                payload.pop("response_format")
+        elif self.backend == "llama.cpp":
+            if "response_format" in payload and "json_schema" in payload["response_format"]:
+                payload["response_format"]["schema"] = json.loads(
+                    payload["response_format"]["json_schema"]
+                )
+                payload["response_format"].pop("json_schema")
+        else:
+            if "response_format" in payload and "json_schema" in payload["response_format"]:
+                payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
+                payload["response_format"].pop("json_schema")
         generated_text = ""
         first_chunk_output_str = ""
         time_to_first_token_s = None
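
The per-backend rewriting added above can be read as a standalone transformation. A sketch follows, under a hypothetical helper name adapt_payload (the real logic stays inline in OpenAIChatEndPoint.__call__):

import json
from typing import Any, Dict

def adapt_payload(payload: Dict[str, Any], backend: str) -> Dict[str, Any]:
    if backend == "vllm":
        # vLLM's OpenAI-compatible server has no MLC-style debug_config; fold
        # ignore_eos back into a top-level flag and express structured output
        # via guided decoding instead of response_format.
        if payload.get("debug_config") and "ignore_eos" in payload["debug_config"]:
            payload["ignore_eos"] = payload["debug_config"]["ignore_eos"]
            payload.pop("debug_config")
        if "response_format" in payload:
            if "json_schema" in payload["response_format"]:
                payload["guided_json"] = json.loads(payload["response_format"]["json_schema"])
                payload["guided_decoding_backend"] = "outlines"
            payload.pop("response_format")
    elif backend == "llama.cpp":
        # llama.cpp takes the schema as a parsed object under "schema".
        if "response_format" in payload and "json_schema" in payload["response_format"]:
            payload["response_format"]["schema"] = json.loads(
                payload["response_format"]["json_schema"]
            )
            payload["response_format"].pop("json_schema")
    else:
        # Default (MLC) path: pass the schema string through under "schema".
        if "response_format" in payload and "json_schema" in payload["response_format"]:
            payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
            payload["response_format"].pop("json_schema")
    return payload

# Example: the same request is shaped differently per backend.
base = {"response_format": {"type": "json_object", "json_schema": '{"type": "object"}'}}
print(adapt_payload(dict(base), "vllm"))
# {'guided_json': {'type': 'object'}, 'guided_decoding_backend': 'outlines'}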
@@ -441,19 +466,33 @@ async def __call__( # pylint: disable=too-many-branches,too-many-locals,too-man
     "sglang",
     "tensorrt-llm",
     "vllm",
+    "vllm-chat",
+    "llama.cpp-chat",
 ]


 def create_api_endpoint(args: argparse.Namespace) -> APIEndPoint:
     """Create an API endpoint instance with regard to the specified endpoint kind."""
     if args.api_endpoint in ["openai", "mlc", "sglang"]:
         return OpenAIEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
-    if args.api_endpoint == "vllm":
+    if args.api_endpoint in ["vllm", "llama.cpp"]:
         return OpenAIEndPoint(
             args.host, args.port, args.timeout, include_server_metrics=False, no_debug_config=True
         )
     if args.api_endpoint == "openai-chat":
-        return OpenAIChatEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
+        return OpenAIChatEndPoint(
+            args.host, args.port, args.api_endpoint, args.timeout, args.include_server_metrics
+        )
+    if args.api_endpoint in ["vllm-chat", "llama.cpp-chat"]:
+        return OpenAIChatEndPoint(
+            args.host,
+            args.port,
+            args.api_endpoint[:-5],
+            args.timeout,
+            include_server_metrics=False,
+            no_debug_config=True,
+        )
+
     if args.api_endpoint == "tensorrt-llm":
         return TensorRTLLMEndPoint(args.host, args.port, args.timeout)
     raise ValueError(f'Unrecognized endpoint "{args.api_endpoint}"')
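
A quick sanity check (illustrative only) on the suffix stripping used for the new -chat endpoints: "-chat" is five characters, so slicing it off recovers the backend name passed to OpenAIChatEndPoint.

for kind in ["vllm-chat", "llama.cpp-chat"]:
    print(kind, "->", kind[:-5])
# vllm-chat -> vllm
# llama.cpp-chat -> llama.cpp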