Commit 10d5088

Add P/D support, respond accordingly to doRemotePrefill and doRemoteDecode fields
Signed-off-by: Maya Barnea <mayab@il.ibm.com>
1 parent 9f3d093 commit 10d5088

File tree

5 files changed: +89, -18 lines changed


pkg/llm-d-inference-sim/config.go

Lines changed: 6 additions & 0 deletions
@@ -53,6 +53,9 @@ type configuration struct {
     TimeToFirstToken int `yaml:"time-to-first-token"`
     // InterTokenLatency time between generated tokens, in milliseconds
     InterTokenLatency int `yaml:"inter-token-latency"`
+    // KVCacheTransferLatency time to "transfer" kv-cache from another vLLM instance in case P/D is activated, in milliseconds
+    KVCacheTransferLatency int `yaml:"kv_cache_transfer_latency"`
+
     // Mode defines the simulator response generation mode, valid values: echo, random
     Mode string `yaml:"mode"`
     // Seed defines random seed for operations
@@ -145,6 +148,9 @@ func (c *configuration) validate() error {
     if c.TimeToFirstToken < 0 {
         return errors.New("time to first token cannot be negative")
     }
+    if c.KVCacheTransferLatency < 0 {
+        return errors.New("kv-cache transfer time cannot be negative")
+    }
     if c.MaxLoras < 1 {
         return errors.New("max LoRAs cannot be less than 1")
     }
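
For illustration only (not part of the commit): a minimal sketch of how the new knob surfaces in a YAML config and in validation. It assumes the config is parsed with a standard YAML library such as gopkg.in/yaml.v3, that the snippet lives in the same package as the unexported configuration type, and the package name is invented for the example.

package llmdinferencesim // package name assumed for the example

import (
    "fmt"

    "gopkg.in/yaml.v3"
)

// exampleKVCacheTransferLatency reads the new kv_cache_transfer_latency key and
// shows that validate() now rejects negative values.
func exampleKVCacheTransferLatency() {
    raw := []byte("time-to-first-token: 2000\ninter-token-latency: 100\nkv_cache_transfer_latency: 100\n")

    var c configuration
    if err := yaml.Unmarshal(raw, &c); err != nil {
        fmt.Println("unmarshal error:", err)
        return
    }
    fmt.Println(c.KVCacheTransferLatency) // 100

    c.KVCacheTransferLatency = -1
    fmt.Println(c.validate() != nil) // true: among other checks, negative kv-cache transfer time is rejected
}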

pkg/llm-d-inference-sim/request.go

Lines changed: 19 additions & 1 deletion
@@ -44,6 +44,10 @@ type completionRequest interface {
     getToolChoice() string
     // getMaxCompletionTokens returns the maximum completion tokens requested
     getMaxCompletionTokens() *int64
+    // doRemoteDecode() returns true if do_remote_decode is true in the request, meaning this is a prefill request
+    doRemoteDecode() bool
+    // doRemotePrefill() returns true if do_remote_prefill is true in the request, meaning this is a decode request
+    doRemotePrefill() bool
 }

 // baseCompletionRequest contains base completion request related information
@@ -53,7 +57,13 @@ type baseCompletionRequest struct {
     // StreamOptions defines streaming options in case Stream is set to true
     StreamOptions streamOptions `json:"stream_options"`
     // Model defines Model name to use for "inference", could be base Model name or one of available LoRA adapters
-    Model string `json:"model"`
+    Model           string   `json:"model"`
+    DoRemoteDecode  bool     `json:"do_remote_decode"`
+    DoRemotePrefill bool     `json:"do_remote_prefill"`
+    RemoteBlockIds  []string `json:"remote_block_ids"`
+    RemoteEngineId  string   `json:"remote_engine_id"`
+    RemoteHost      string   `json:"remote_host"`
+    RemotePort      int      `json:"remote_port"`
 }

 // StreamOptions defines streaming options for streaming requests
@@ -74,6 +84,14 @@ func (b *baseCompletionRequest) includeUsage() bool {
     return !b.Stream || b.StreamOptions.IncludeUsage
 }

+func (b *baseCompletionRequest) doRemoteDecode() bool {
+    return b.DoRemoteDecode
+}
+
+func (b *baseCompletionRequest) doRemotePrefill() bool {
+    return b.DoRemotePrefill
+}
+
 // completionReqCtx is a context passed in the simulator's flow, it contains the request data needed
 // to generate the simulator's response
 type completionReqCtx struct {
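
For illustration only (not part of the commit): a sketch of how the new request fields are populated from a JSON body and what the two accessors report. The model name is a placeholder, and the snippet assumes it sits in the same (assumed) package as the unexported types.

package llmdinferencesim // package name assumed for the example

import (
    "encoding/json"
    "fmt"
)

// exampleRemotePrefillDecodeFlags unmarshals a request that marks itself as a prefill request
// (do_remote_decode=true) and checks the accessors added above.
func exampleRemotePrefillDecodeFlags() {
    body := []byte(`{"model": "my-model", "do_remote_decode": true, "do_remote_prefill": false}`)

    var req baseCompletionRequest
    if err := json.Unmarshal(body, &req); err != nil {
        fmt.Println("unmarshal error:", err)
        return
    }
    fmt.Println(req.doRemoteDecode())  // true  -> handled as a prefill request
    fmt.Println(req.doRemotePrefill()) // false -> no kv-cache transfer to simulate
}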

pkg/llm-d-inference-sim/response.go

Lines changed: 7 additions & 1 deletion
@@ -37,7 +37,13 @@ type baseCompletionResponse struct {
     // Usage contains the token usage statistics for the request
     Usage *usage `json:"usage"`
     // Object is the Object type, "text_completion", "chat.completion", or "chat.completion.chunk"
-    Object string `json:"object"`
+    Object          string   `json:"object"`
+    DoRemoteDecode  bool     `json:"do_remote_decode"`
+    DoRemotePrefill bool     `json:"do_remote_prefill"`
+    RemoteBlockIds  []string `json:"remote_block_ids"`
+    RemoteEngineId  string   `json:"remote_engine_id"`
+    RemoteHost      string   `json:"remote_host"`
+    RemotePort      int      `json:"remote_port"`
 }

 // usage contains token usage statistics
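
For illustration only (not part of the commit): roughly what the extra fields look like on the wire for a prefill response, reusing the hard-coded DUMMY values that createCompletionResponse (in simulator.go below) assigns. The ID value is a placeholder and the package name is assumed.

package llmdinferencesim // package name assumed for the example

import (
    "encoding/json"
    "fmt"
)

// exampleRemoteDecodeResponseJSON marshals a response carrying the new remote-transfer fields.
func exampleRemoteDecodeResponseJSON() {
    resp := baseCompletionResponse{
        ID:              "cmpl-example",
        Object:          textCompletionObject,
        DoRemoteDecode:  true,
        DoRemotePrefill: false,
        RemoteBlockIds:  []string{"DUMMY_ID"},
        RemoteEngineId:  "DUMMY_ID",
        RemoteHost:      "DUMMY",
        RemotePort:      1234,
    }
    data, _ := json.Marshal(resp)
    // output includes "do_remote_decode":true plus the remote_block_ids/engine/host/port fields
    fmt.Println(string(data))
}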

pkg/llm-d-inference-sim/simulator.go

Lines changed: 52 additions & 11 deletions
@@ -49,6 +49,7 @@ const (
     stopFinishReason = "stop"
     lengthFinishReason = "length"
     toolsFinishReason = "tool_calls"
+    remoteDecodeFinishReason = "remote_decode"
     roleAssistant = "assistant"
     roleUser = "user"
     textCompletionObject = "text_completion"
@@ -155,6 +156,7 @@ func (s *VllmSimulator) parseCommandParamsAndLoadConfig() error {
     f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences")
     f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)")
     f.IntVar(&config.TimeToFirstToken, "time-to-first-token", config.TimeToFirstToken, "Time to first token (in milliseconds)")
+    f.IntVar(&config.KVCacheTransferLatency, "kv_cache_transfer_latency", config.KVCacheTransferLatency, "Time for KV-cache transfer from a remote vLLM (in milliseconds)")
     f.Int64Var(&config.Seed, "seed", config.Seed, "Random seed for operations (if not set, current Unix time in nanoseconds is used)")

     // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help
@@ -304,6 +306,8 @@ func (s *VllmSimulator) readRequest(ctx *fasthttp.RequestCtx, isChatCompletion b
     var req textCompletionRequest

     err := json.Unmarshal(ctx.Request.Body(), &req)
+
+    fmt.Printf("Unmarshaled text request: %#v\n", req)
     return &req, err
 }


@@ -329,6 +333,18 @@ func (s *VllmSimulator) HandleUnloadLora(ctx *fasthttp.RequestCtx) {
     s.unloadLora(ctx)
 }

+func (s *VllmSimulator) validateRequest(req completionRequest) (string, string, int) {
+    if !s.isValidModel(req.getModel()) {
+        return fmt.Sprintf("The model `%s` does not exist.", req.getModel()), "NotFoundError", fasthttp.StatusNotFound
+    }
+
+    if req.doRemoteDecode() && req.isStream() {
+        return "Prefill does not support streaming", "Invalid request", fasthttp.StatusBadRequest
+    }
+
+    return "", "", fasthttp.StatusOK
+}
+
 // isValidModel checks if the given model is the base model or one of "loaded" LoRAs
 func (s *VllmSimulator) isValidModel(model string) bool {
     for _, name := range s.config.ServedModelNames {
@@ -365,11 +381,9 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
         return
     }

-    model := vllmReq.getModel()
-
-    if !s.isValidModel(model) {
-        s.sendCompletionError(ctx, fmt.Sprintf("The model `%s` does not exist.", vllmReq.getModel()),
-            "NotFoundError", fasthttp.StatusNotFound)
+    errMsg, errType, errCode := s.validateRequest(vllmReq)
+    if errMsg != "" {
+        s.sendCompletionError(ctx, errMsg, errType, errCode)
         return
     }

@@ -477,16 +491,23 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
                 isChatCompletion: reqCtx.isChatCompletion,
                 model:            displayModel,
             },
-            responseTokens, toolCalls, finishReason, usageDataToSend,
+            responseTokens, toolCalls, finishReason, usageDataToSend, req.doRemotePrefill(),
         )
     } else {
+        if req.doRemoteDecode() {
+            // in case this is prefill pod processing, return special finish reason
+            finishReason = remoteDecodeFinishReason
+        }
+
         s.sendResponse(reqCtx.isChatCompletion,
             reqCtx.httpReqCtx,
             responseTokens,
             toolCalls,
             displayModel,
             finishReason,
-            &usageData)
+            &usageData,
+            req.doRemoteDecode(),
+            req.doRemotePrefill())
     }
 }
 reqCtx.wg.Done()
@@ -575,13 +596,25 @@ func (s *VllmSimulator) HandleError(_ *fasthttp.RequestCtx, err error) {
 // modelName - display name returned to the client and used in metrics. It is either the first alias
 // from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
 func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respTokens []string, toolCalls []toolCall,
-    finishReason *string, usageData *usage, modelName string) completionResponse {
+    finishReason *string, usageData *usage, modelName string, doRemoteDecode bool) completionResponse {
     baseResp := baseCompletionResponse{
         ID:      chatComplIDPrefix + uuid.NewString(),
         Created: time.Now().Unix(),
         Model:   modelName,
         Usage:   usageData,
     }
+
+    if doRemoteDecode {
+        // add special fields related to the prefill pod special behavior
+        baseResp.DoRemoteDecode = true
+        baseResp.DoRemotePrefill = false
+        // currently remote prefill information is hard-coded
+        baseResp.RemoteBlockIds = []string{"DUMMY_ID"}
+        baseResp.RemoteEngineId = "DUMMY_ID"
+        baseResp.RemoteHost = "DUMMY"
+        baseResp.RemotePort = 1234
+    }
+
     baseChoice := baseResponseChoice{Index: 0, FinishReason: finishReason}

     respText := strings.Join(respTokens, "")
@@ -616,8 +649,8 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke
 // finishReason - a pointer to string that represents finish reason, can be nil, stop, length, or tools
 // usageData - usage (tokens statistics) for this response
 func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.RequestCtx, respTokens []string, toolCalls []toolCall,
-    modelName string, finishReason string, usageData *usage) {
-    resp := s.createCompletionResponse(isChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName)
+    modelName string, finishReason string, usageData *usage, doRemoteDecode bool, doRemotePrefill bool) {
+    resp := s.createCompletionResponse(isChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName, doRemoteDecode)

     data, err := json.Marshal(resp)
     if err != nil {
@@ -627,7 +660,7 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques

     // calculate how long to wait before returning the response, time is based on number of tokens
     numOfTokens := usageData.CompletionTokens
-    totalMillisToWait := s.config.TimeToFirstToken + (numOfTokens-1)*s.config.InterTokenLatency
+    totalMillisToWait := s.getTimeToFirstToken(doRemotePrefill) + (numOfTokens-1)*s.config.InterTokenLatency
     time.Sleep(time.Duration(totalMillisToWait) * time.Millisecond)

     // TODO - maybe add pod id to response header for testing
@@ -638,6 +671,14 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques
     s.responseSentCallback(modelName)
 }

+// getTimeToFirstToken returns the time to first token, depending on whether this request performs remote prefill
+func (s *VllmSimulator) getTimeToFirstToken(doRemotePrefill bool) int {
+    if doRemotePrefill {
+        return s.config.KVCacheTransferLatency
+    }
+    return s.config.TimeToFirstToken
+}
+
 // createModelsResponse creates and returns ModelResponse for the current state, returned array of models contains the base model + LoRA adapters if exist
 func (s *VllmSimulator) createModelsResponse() *vllmapi.ModelsResponse {
     modelsResp := vllmapi.ModelsResponse{Object: "list", Data: []vllmapi.ModelsResponseModelInfo{}}
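
For illustration only (not part of the commit): the response delay now depends on whether the request carries do_remote_prefill, since getTimeToFirstToken swaps TimeToFirstToken for KVCacheTransferLatency. A standalone sketch of the same arithmetic with made-up latencies:

package main

import "fmt"

func main() {
    const (
        timeToFirstToken       = 2000 // ms, --time-to-first-token
        interTokenLatency      = 100  // ms, --inter-token-latency
        kvCacheTransferLatency = 150  // ms, --kv_cache_transfer_latency
    )
    numOfTokens := 5

    // local prefill: wait for the first token, then one inter-token gap per remaining token
    local := timeToFirstToken + (numOfTokens-1)*interTokenLatency
    // remote prefill (do_remote_prefill=true): first-token time is replaced by the kv-cache transfer time
    remote := kvCacheTransferLatency + (numOfTokens-1)*interTokenLatency

    fmt.Println(local, remote) // 2400 550
}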

pkg/llm-d-inference-sim/streaming.go

Lines changed: 5 additions & 5 deletions
@@ -38,7 +38,7 @@ type streamingContext struct {
 // response content is wrapped according SSE format
 // First token is send after timeToFirstToken milliseconds, every other token is sent after interTokenLatency milliseconds
 func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, responseTokens []string, toolCalls []toolCall,
-    finishReason string, usageData *usage) {
+    finishReason string, usageData *usage, doRemotePrefill bool) {
     context.ctx.SetContentType("text/event-stream")
     context.ctx.SetStatusCode(fasthttp.StatusOK)

@@ -57,11 +57,11 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
     if len(toolCalls) > 0 {
         s.logger.Info("Going to send tools calls")
         for _, tc := range toolCalls {
-            s.sendTokenChunks(context, w, tc.Function.tokenizedArguments, &tc, finishReason)
+            s.sendTokenChunks(context, w, tc.Function.tokenizedArguments, &tc, finishReason, doRemotePrefill)
         }
     } else {
         s.logger.Info("Going to send text", "number of tokens", usageData.CompletionTokens)
-        s.sendTokenChunks(context, w, responseTokens, nil, finishReason)
+        s.sendTokenChunks(context, w, responseTokens, nil, finishReason, doRemotePrefill)
     }
 }

@@ -84,9 +84,9 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
 }

 // sendTokenChunks creates and sends response chunks
-func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writer, tokens []string, tc *toolCall, finishReason string) {
+func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writer, tokens []string, tc *toolCall, finishReason string, doRemotePrefill bool) {
     // time to first token delay
-    time.Sleep(time.Duration(s.config.TimeToFirstToken) * time.Millisecond)
+    time.Sleep(time.Duration(s.getTimeToFirstToken(doRemotePrefill)) * time.Millisecond)

     for i, token := range tokens {
         if i != 0 {
