
Commit 5a9ba0c

Add the --served-model-name flag (#69)
Add support for the --served-model-name flag to expose custom model names via the API. The flag allows specifying model aliases that appear in API responses, Prometheus metrics, and the /v1/models endpoint, while leaving LoRA adapter names unchanged. Signed-off-by: Brent Salisbury <bsalisbu@redhat.com>
1 parent b590f13 commit 5a9ba0c
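
For context, here is a minimal standalone sketch (not the simulator's actual code) of the flag-and-fallback behavior this commit describes, using spf13/pflag as the diff below does; the package layout and output are illustrative assumptions:

```go
package main

import (
	"fmt"

	flag "github.com/spf13/pflag"
)

func main() {
	// Simplified stand-in for the simulator's flag handling (see the
	// parseCommandParams diff below): --served-model-name accepts one or
	// more aliases; when it is omitted, the value of --model is used as
	// the single public name, matching upstream vLLM behaviour.
	model := flag.String("model", "", "Currently 'loaded' model")
	served := flag.StringSlice("served-model-name", nil, "Model names exposed by the API")
	flag.Parse()

	if len(*served) == 0 {
		*served = []string{*model}
	}
	fmt.Println("public model names:", *served)
}
```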

File tree

2 files changed (+76, -33 lines)


pkg/llm-d-inference-sim/metrics.go

Lines changed: 6 additions & 5 deletions
@@ -97,18 +97,19 @@ func (s *VllmSimulator) createAndRegisterPrometheus() error {
 
 // setInitialPrometheusMetrics send default values to prometheus
 func (s *VllmSimulator) setInitialPrometheusMetrics() {
+	modelName := s.getDisplayedModelName(s.model)
 	s.loraInfo.WithLabelValues(
 		strconv.Itoa(s.maxLoras),
 		"",
 		"").Set(float64(time.Now().Unix()))
 
 	s.nRunningReqs = 0
 	s.runningRequests.WithLabelValues(
-		s.model).Set(float64(s.nRunningReqs))
+		modelName).Set(float64(s.nRunningReqs))
 	s.waitingRequests.WithLabelValues(
-		s.model).Set(float64(0))
+		modelName).Set(float64(0))
 	s.kvCacheUsagePercentage.WithLabelValues(
-		s.model).Set(float64(0))
+		modelName).Set(float64(0))
 }
 
 // reportLoras sets information about loaded LoRA adapters
@@ -135,7 +136,7 @@ func (s *VllmSimulator) reportRunningRequests() {
 	if s.runningRequests != nil {
 		nRunningReqs := atomic.LoadInt64(&(s.nRunningReqs))
 		s.runningRequests.WithLabelValues(
-			s.model).Set(float64(nRunningReqs))
+			s.getDisplayedModelName(s.model)).Set(float64(nRunningReqs))
 	}
 }
 
@@ -144,6 +145,6 @@ func (s *VllmSimulator) reportWaitingRequests() {
 	if s.waitingRequests != nil {
 		nWaitingReqs := atomic.LoadInt64(&(s.nWaitingReqs))
 		s.waitingRequests.WithLabelValues(
-			s.model).Set(float64(nWaitingReqs))
+			s.getDisplayedModelName(s.model)).Set(float64(nWaitingReqs))
 	}
 }
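
For illustration, a hedged sketch of the labeling change in Prometheus terms: a gauge vector keyed by model name now receives the displayed alias rather than the raw --model value. The metric and label names below are assumptions for the example, not the simulator's actual definitions:

```go
package main

import (
	"fmt"

	"github.com/prometheus/client_golang/prometheus"
)

func main() {
	// A gauge vector keyed by a model-name label, similar in spirit to the
	// simulator's runningRequests metric. Metric/label names are illustrative.
	running := prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "num_requests_running",
			Help: "Number of running requests per served model name.",
		},
		[]string{"model_name"},
	)
	prometheus.MustRegister(running)

	// Before this commit the label carried the base model (--model); after it,
	// the displayed name, e.g. the first --served-model-name alias, is used.
	displayed := "llama-8b" // stand-in for s.getDisplayedModelName(s.model)
	running.WithLabelValues(displayed).Set(3)

	fmt.Println("gauge labeled with model_name =", displayed)
}
```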

pkg/llm-d-inference-sim/simulator.go

Lines changed: 70 additions & 28 deletions
@@ -72,6 +72,8 @@ type VllmSimulator struct {
 	mode string
 	// model defines the current base model name
 	model string
+	// one or many names exposed by the API
+	servedModelNames []string
 	// loraAdaptors contains list of LoRA available adaptors
 	loraAdaptors sync.Map
 	// maxLoras defines maximum number of loaded loras
@@ -150,6 +152,7 @@ func (s *VllmSimulator) parseCommandParams() error {
 	f.IntVar(&s.interTokenLatency, "inter-token-latency", 0, "Time to generate one token (in milliseconds)")
 	f.IntVar(&s.timeToFirstToken, "time-to-first-token", 0, "Time to first token (in milliseconds)")
 	f.StringVar(&s.model, "model", "", "Currently 'loaded' model")
+	f.StringSliceVar(&s.servedModelNames, "served-model-name", nil, "Model names exposed by the API (comma or space-separated)")
 	var lorasStr string
 	f.StringVar(&lorasStr, "lora", "", "List of LoRA adapters, separated by comma")
 	f.IntVar(&s.maxLoras, "max-loras", 1, "Maximum number of LoRAs in a single batch")
@@ -169,6 +172,14 @@ func (s *VllmSimulator) parseCommandParams() error {
 	if s.model == "" {
 		return errors.New("model parameter is empty")
 	}
+
+	// Upstream vLLM behaviour: when --served-model-name is not provided,
+	// it falls back to using the value of --model as the single public name
+	// returned by the API and exposed in Prometheus metrics.
+	if len(s.servedModelNames) == 0 {
+		s.servedModelNames = []string{s.model}
+	}
+
 	if s.mode != modeEcho && s.mode != modeRandom {
 		return fmt.Errorf("invalid mode '%s', valid values are 'random' and 'echo'", s.mode)
 	}
@@ -301,10 +312,11 @@ func (s *VllmSimulator) HandleUnloadLora(ctx *fasthttp.RequestCtx) {
 
 // isValidModel checks if the given model is the base model or one of "loaded" LoRAs
 func (s *VllmSimulator) isValidModel(model string) bool {
-	if model == s.model {
-		return true
+	for _, name := range s.servedModelNames {
+		if model == name {
+			return true
+		}
 	}
-
 	for _, lora := range s.getLoras() {
 		if model == lora {
 			return true
@@ -372,6 +384,8 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
 
 			req := reqCtx.completionReq
 			model := req.getModel()
+			displayModel := s.getDisplayedModelName(model)
+
 			if s.isLora(model) {
 				// if current request's model is LoRA, add it to the list of running loras
 				value, ok := s.runningLoras.Load(model)
@@ -397,8 +411,11 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
 			var err error
 			var toolCalls []toolCall
 			var completionTokens int
-			if reqCtx.isChatCompletion && req.getToolChoice() != toolChoiceNone && req.getTools() != nil {
-				toolCalls, finishReason, completionTokens, err = createToolCalls(req.getTools(), req.getToolChoice())
+			if reqCtx.isChatCompletion &&
+				req.getToolChoice() != toolChoiceNone &&
+				req.getTools() != nil {
+				toolCalls, finishReason, completionTokens, err =
+					createToolCalls(req.getTools(), req.getToolChoice())
 			}
 			if toolCalls == nil && err == nil {
 				// Either no tool calls were defined, or we randomly chose not to create tool calls,
@@ -426,10 +443,20 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
 					usageDataToSend = &usageData
 				}
 				s.sendStreamingResponse(
-					&streamingContext{ctx: reqCtx.httpReqCtx, isChatCompletion: reqCtx.isChatCompletion, model: model},
-					responseTokens, toolCalls, finishReason, usageDataToSend)
+					&streamingContext{
+						ctx:              reqCtx.httpReqCtx,
+						isChatCompletion: reqCtx.isChatCompletion,
+						model:            displayModel,
+					},
+					responseTokens, toolCalls, finishReason, usageDataToSend,
+				)
 			} else {
-				s.sendResponse(reqCtx.isChatCompletion, reqCtx.httpReqCtx, responseTokens, toolCalls, model, finishReason,
+				s.sendResponse(reqCtx.isChatCompletion,
+					reqCtx.httpReqCtx,
+					responseTokens,
+					toolCalls,
+					displayModel,
+					finishReason,
 					&usageData)
 			}
 		}
@@ -444,8 +471,8 @@ func (s *VllmSimulator) responseSentCallback(model string) {
 	atomic.AddInt64(&(s.nRunningReqs), -1)
 	s.reportRunningRequests()
 
-	if model == s.model {
-		// this is the base model - do not continue
+	// Only LoRA models require reference-count handling.
+	if !s.isLora(model) {
 		return
 	}
 
@@ -515,15 +542,16 @@ func (s *VllmSimulator) HandleError(_ *fasthttp.RequestCtx, err error) {
 // as defined by isChatCompletion
 // respTokens - tokenized content to be sent in the response
 // toolCalls - tool calls to be sent in the response
-// model - model name
 // finishReason - a pointer to string that represents finish reason, can be nil or stop or length, ...
 // usageData - usage (tokens statistics) for this response
-func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respTokens []string, toolCalls []toolCall, model string,
-	finishReason *string, usageData *usage) completionResponse {
+// modelName - display name returned to the client and used in metrics. It is either the first alias
+// from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
+func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respTokens []string, toolCalls []toolCall,
+	finishReason *string, usageData *usage, modelName string) completionResponse {
 	baseResp := baseCompletionResponse{
 		ID:      chatComplIDPrefix + uuid.NewString(),
 		Created: time.Now().Unix(),
-		Model:   model,
+		Model:   modelName,
 		Usage:   usageData,
 	}
 	baseChoice := baseResponseChoice{Index: 0, FinishReason: finishReason}
@@ -555,12 +583,13 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke
 // according the value of isChatCompletion
 // respTokens - tokenized content to be sent in the response
 // toolCalls - tool calls to be sent in the response
-// model - model name
+// modelName - display name returned to the client and used in metrics. It is either the first alias
+// from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
 // finishReason - a pointer to string that represents finish reason, can be nil, stop, length, or tools
 // usageData - usage (tokens statistics) for this response
 func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.RequestCtx, respTokens []string, toolCalls []toolCall,
-	model string, finishReason string, usageData *usage) {
-	resp := s.createCompletionResponse(isChatCompletion, respTokens, toolCalls, model, &finishReason, usageData)
+	modelName string, finishReason string, usageData *usage) {
+	resp := s.createCompletionResponse(isChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName)
 
 	data, err := json.Marshal(resp)
 	if err != nil {
@@ -578,32 +607,35 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques
 	ctx.Response.Header.SetStatusCode(fasthttp.StatusOK)
 	ctx.Response.SetBody(data)
 
-	s.responseSentCallback(model)
+	s.responseSentCallback(modelName)
 }
 
 // createModelsResponse creates and returns ModelResponse for the current state, returned array of models contains the base model + LoRA adapters if exist
 func (s *VllmSimulator) createModelsResponse() *vllmapi.ModelsResponse {
 	modelsResp := vllmapi.ModelsResponse{Object: "list", Data: []vllmapi.ModelsResponseModelInfo{}}
 
-	// add base model's info
-	modelsResp.Data = append(modelsResp.Data, vllmapi.ModelsResponseModelInfo{
-		ID:      s.model,
-		Object:  vllmapi.ObjectModel,
-		Created: time.Now().Unix(),
-		OwnedBy: "vllm",
-		Root:    s.model,
-		Parent:  nil,
-	})
+	// Advertise every public model alias
+	for _, alias := range s.servedModelNames {
+		modelsResp.Data = append(modelsResp.Data, vllmapi.ModelsResponseModelInfo{
+			ID:      alias,
+			Object:  vllmapi.ObjectModel,
+			Created: time.Now().Unix(),
+			OwnedBy: "vllm",
+			Root:    alias,
+			Parent:  nil,
		})
+	}
 
 	// add LoRA adapter's info
+	parent := s.servedModelNames[0]
 	for _, lora := range s.getLoras() {
 		modelsResp.Data = append(modelsResp.Data, vllmapi.ModelsResponseModelInfo{
 			ID:      lora,
 			Object:  vllmapi.ObjectModel,
 			Created: time.Now().Unix(),
 			OwnedBy: "vllm",
 			Root:    lora,
-			Parent:  &s.model,
+			Parent:  &parent,
 		})
 	}
 
@@ -625,3 +657,13 @@ func (s *VllmSimulator) HandleReady(ctx *fasthttp.RequestCtx) {
 	ctx.Response.Header.SetStatusCode(fasthttp.StatusOK)
 	ctx.Response.SetBody([]byte("{}"))
 }
+
+// getDisplayedModelName returns the model name that must appear in API
+// responses. LoRA adapters keep their explicit name, while all base-model
+// requests are surfaced as the first alias from --served-model-name.
+func (s *VllmSimulator) getDisplayedModelName(reqModel string) string {
+	if s.isLora(reqModel) {
+		return reqModel
+	}
+	return s.servedModelNames[0]
+}
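
Taken together, the name-resolution rule introduced by this commit reads roughly as below; this is a standalone sketch with simplified stand-in types, not the simulator's actual API:

```go
package main

import "fmt"

// sim is a simplified stand-in for VllmSimulator, carrying only the fields
// that matter for display-name resolution.
type sim struct {
	servedModelNames []string
	loras            map[string]bool
}

// displayedModelName mirrors getDisplayedModelName from the diff above:
// LoRA adapters keep their explicit name, everything else is surfaced as
// the first --served-model-name alias (which falls back to --model when
// the flag is unset).
func (s *sim) displayedModelName(reqModel string) string {
	if s.loras[reqModel] {
		return reqModel
	}
	return s.servedModelNames[0]
}

func main() {
	s := &sim{
		servedModelNames: []string{"llama-8b"}, // first alias wins
		loras:            map[string]bool{"my-lora": true},
	}
	fmt.Println(s.displayedModelName("some/base-model-path")) // llama-8b
	fmt.Println(s.displayedModelName("my-lora"))              // my-lora
}
```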
