Commit 2ccfb56

add request max_tokens/max_completion_tokens
1 parent 9051109 commit 2ccfb56

9 files changed (+557, -43 lines)


go.mod

Lines changed: 5 additions & 0 deletions
@@ -10,6 +10,7 @@ require (
 	github.com/google/uuid v1.6.0
 	github.com/onsi/ginkgo/v2 v2.23.4
 	github.com/onsi/gomega v1.37.0
+	github.com/openai/openai-go v0.1.0-beta.10
 	github.com/prometheus/client_golang v1.21.1
 	github.com/spf13/pflag v1.0.6
 	github.com/valyala/fasthttp v1.59.0
@@ -30,6 +31,10 @@ require (
 	github.com/prometheus/client_model v0.6.1 // indirect
 	github.com/prometheus/common v0.62.0 // indirect
 	github.com/prometheus/procfs v0.15.1 // indirect
+	github.com/tidwall/gjson v1.18.0 // indirect
+	github.com/tidwall/match v1.1.1 // indirect
+	github.com/tidwall/pretty v1.2.1 // indirect
+	github.com/tidwall/sjson v1.2.5 // indirect
 	github.com/valyala/bytebufferpool v1.0.0 // indirect
 	go.uber.org/automaxprocs v1.6.0 // indirect
 	golang.org/x/net v0.38.0 // indirect

go.sum

Lines changed: 12 additions & 0 deletions
@@ -32,6 +32,8 @@ github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus
 github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8=
 github.com/onsi/gomega v1.37.0 h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y=
 github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0=
+github.com/openai/openai-go v0.1.0-beta.10 h1:CknhGXe8aXQMRuqg255PFnWzgRY9nEryMxoNIBBM9tU=
+github.com/openai/openai-go v0.1.0-beta.10/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
 github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
 github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
@@ -50,6 +52,16 @@ github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
 github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
 github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
 github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
+github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
+github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
+github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
+github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
+github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
 github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
 github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
 github.com/valyala/fasthttp v1.59.0 h1:Qu0qYHfXvPk1mSLNqcFtEk6DpxgA26hy6bmydotDpRI=

pkg/vllm-sim/defs.go

Lines changed: 56 additions & 12 deletions
@@ -19,6 +19,7 @@ limitations under the License.
 package vllmsim
 
 import (
+	"fmt"
 	"sync"
 
 	"github.com/go-logr/logr"
@@ -45,7 +46,7 @@ type VllmSimulator struct {
 	interTokenLatency int
 	// port defines on which port the simulator runs
 	port int
-	// mode defenes the simulator response generation mode, valid values: echo, random
+	// mode defines the simulator response generation mode, valid values: echo, random
 	mode string
 	// model defines the current base model name
 	model string
@@ -118,7 +119,7 @@ func (b *baseCompletionRequest) getModel() string {
 // completionRequest interface representing both completion request types (text and chat)
 type completionRequest interface {
 	// createResponseText creates and returns response payload based on this request
-	createResponseText(mode string) string
+	createResponseText(mode string) (string, error)
 	// isStream returns boolean that defines is response should be streamed
 	isStream() bool
 	// getModel returns model name as defined in the request
@@ -146,6 +147,18 @@ type chatCompletionRequest struct {
 	baseCompletionRequest
 	// Messages list of request's Messages
 	Messages []message `json:"messages"`
+
+	// The maximum number of tokens that can be generated in the chat
+	// completion. This value can be used to control costs for text
+	// generated via API.
+	// This value is now deprecated in favor of max_completion_tokens
+	// and is not compatible with o1 series models.
+	MaxTokens *int64 `json:"max_tokens"`
+
+	// An upper bound for the number of tokens that can be
+	// generated for a completion, including visible output
+	// tokens and reasoning tokens.
+	MaxCompletionTokens *int64 `json:"max_completion_tokens"`
 }
 
 // chatCompletionResponse defines structure of /chat/completion response
@@ -168,9 +181,13 @@ type textCompletionRequest struct {
 	baseCompletionRequest
 	// Prompt defines request's content
 	Prompt string `json:"prompt"`
-	// TODO - do we want to support max tokens?
-	// MaxTokens is a maximum number of tokens in response
-	MaxTokens int `json:"max_tokens"`
+
+	// The maximum number of [tokens](/tokenizer) that can be generated in the
+	// completion.
+	//
+	// The token count of your prompt plus `max_tokens` cannot exceed the model's
+	// context length.
+	MaxTokens *int64 `json:"max_tokens"`
 }
 
 // textCompletionResponse defines structure of /completion response
@@ -204,21 +221,48 @@ type chatRespChunkChoice struct {
 	Delta message `json:"delta"`
 }
 
+// returns the max. tokens or error if incorrect
+func getMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) {
+	var typeToken string
+	var tokens *int64
+	// if both arguments are passed,
+	// use maxCompletionTokens
+	// as in the real vllm
+	if maxCompletionTokens != nil {
+		tokens = maxCompletionTokens
+		typeToken = "max_completion_tokens"
+	} else if maxTokens != nil {
+		tokens = maxTokens
+		typeToken = "max_tokens"
+	}
+	if tokens != nil && *tokens < 1 {
+		return nil, fmt.Errorf("%s must be at least 1, got %d", typeToken, *tokens)
+	}
+	return tokens, nil
+}
+
 // createResponseText creates response text for the given chat completion request and mode
-func (req chatCompletionRequest) createResponseText(mode string) string {
+func (req chatCompletionRequest) createResponseText(mode string) (string, error) {
+	maxTokens, err := getMaxTokens(req.MaxCompletionTokens, req.MaxTokens)
+	if err != nil {
+		return "", err
+	}
 	if mode == modeEcho {
-		return req.getLastUserMsg()
+		return getResponseText(maxTokens, req.getLastUserMsg()), nil
 	}
-	return getRandomResponseText()
+	return getRandomResponseText(maxTokens), nil
 }
 
 // createResponseText creates response text for the given text completion request and mode
-func (req textCompletionRequest) createResponseText(mode string) string {
+func (req textCompletionRequest) createResponseText(mode string) (string, error) {
+	maxTokens, err := getMaxTokens(nil, req.MaxTokens)
+	if err != nil {
+		return "", err
+	}
 	if mode == modeEcho {
-		return req.Prompt
-	} else {
-		return getRandomResponseText()
+		return getResponseText(maxTokens, req.Prompt), nil
 	}
+	return getRandomResponseText(maxTokens), nil
 }
 
 // getLastUserMsg returns last message from this request's messages with user role,
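
Example (not part of this commit): a hypothetical unit test exercising the getMaxTokens helper added above, assuming it sits next to defs.go in package vllmsim. When both fields are set, max_completion_tokens takes precedence (as in real vLLM), and any value below 1 is rejected.

package vllmsim

import "testing"

// TestGetMaxTokensPrecedence is an illustrative sketch, not shipped with this change.
func TestGetMaxTokensPrecedence(t *testing.T) {
	five, zero := int64(5), int64(0)

	// max_completion_tokens wins when both fields are present.
	tokens, err := getMaxTokens(&five, &zero)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if tokens == nil || *tokens != 5 {
		t.Fatal("expected a limit of 5")
	}

	// Values below 1 are rejected with a descriptive error.
	if _, err := getMaxTokens(nil, &zero); err == nil {
		t.Fatal("expected an error for max_tokens = 0")
	}

	// When neither field is set, no limit is applied and no error is returned.
	if tokens, err := getMaxTokens(nil, nil); err != nil || tokens != nil {
		t.Fatalf("expected no limit, got %v (err: %v)", tokens, err)
	}
}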

pkg/vllm-sim/metrics.go

Lines changed: 10 additions & 6 deletions
@@ -132,14 +132,18 @@ func (s *VllmSimulator) reportLoras() {
 
 // reportRunningRequests sets information about running completion requests
 func (s *VllmSimulator) reportRunningRequests() {
-	nRunningReqs := atomic.LoadInt64(&(s.nRunningReqs))
-	s.runningRequests.WithLabelValues(
-		s.model).Set(float64(nRunningReqs))
+	if s.runningRequests != nil {
+		nRunningReqs := atomic.LoadInt64(&(s.nRunningReqs))
+		s.runningRequests.WithLabelValues(
+			s.model).Set(float64(nRunningReqs))
+	}
 }
 
 // reportWaitingRequests sets information about waiting completion requests
 func (s *VllmSimulator) reportWaitingRequests() {
-	nWaitingReqs := atomic.LoadInt64(&(s.nWaitingReqs))
-	s.waitingRequests.WithLabelValues(
-		s.model).Set(float64(nWaitingReqs))
+	if s.waitingRequests != nil {
+		nWaitingReqs := atomic.LoadInt64(&(s.nWaitingReqs))
+		s.waitingRequests.WithLabelValues(
+			s.model).Set(float64(nWaitingReqs))
+	}
 }
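
Example (not part of this commit): a hypothetical regression test for the nil guards above; calling the reporters on a simulator whose Prometheus gauges were never registered should now be a no-op instead of a nil-pointer panic. It assumes the test lives in package vllmsim next to metrics.go.

package vllmsim

import "testing"

// TestReportersWithoutGauges is an illustrative sketch, not shipped with this change.
func TestReportersWithoutGauges(t *testing.T) {
	s := &VllmSimulator{}
	// Both gauges are nil here; with the guards these calls simply do nothing.
	s.reportRunningRequests()
	s.reportWaitingRequests()
}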

pkg/vllm-sim/simulator.go

Lines changed: 47 additions & 22 deletions
@@ -22,6 +22,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"net"
+	"os"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -37,6 +38,8 @@ import (
 	"github.com/valyala/fasthttp/fasthttpadaptor"
 )
 
+const vLLMDefaultPort = 8000
+
 // New creates a new VllmSimulator instance with the given logger
 func New(logger logr.Logger) *VllmSimulator {
 	return &VllmSimulator{
@@ -63,24 +66,32 @@ func (s *VllmSimulator) Start(ctx context.Context) error {
 	for i := 1; i <= int(s.maxRunningReqs); i++ {
 		go s.reqProcessingWorker(ctx, i)
 	}
+	listener, err := s.newListener()
+	if err != nil {
+		return err
+	}
+
 	// start the http server
-	return s.startServer()
+	return s.startServer(listener)
 }
 
 // parseCommandParams parses and validates command line parameters
 func (s *VllmSimulator) parseCommandParams() error {
-	pflag.StringVar(&s.mode, "mode", "random", "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences")
-	pflag.IntVar(&s.port, "port", 0, "Port")
-	pflag.IntVar(&s.interTokenLatency, "inter-token-latency", 0, "Time to generate one token (in milliseconds)")
-	pflag.IntVar(&s.timeToFirstToken, "time-to-first-token", 0, "Time to first token (in milliseconds)")
-	pflag.StringVar(&s.model, "model", "", "Currently 'loaded' model")
+	f := pflag.NewFlagSet("vllm-sim flags", pflag.ExitOnError)
+	f.StringVar(&s.mode, "mode", "random", "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences")
+	f.IntVar(&s.port, "port", vLLMDefaultPort, "Port")
+	f.IntVar(&s.interTokenLatency, "inter-token-latency", 0, "Time to generate one token (in milliseconds)")
+	f.IntVar(&s.timeToFirstToken, "time-to-first-token", 0, "Time to first token (in milliseconds)")
+	f.StringVar(&s.model, "model", "", "Currently 'loaded' model")
 	var lorasStr string
-	pflag.StringVar(&lorasStr, "lora", "", "List of LoRA adapters, separated by comma")
-	pflag.IntVar(&s.maxLoras, "max-loras", 1, "Maximum number of LoRAs in a single batch")
-	pflag.IntVar(&s.maxCpuLoras, "max-cpu-loras", 0, "Maximum number of LoRAs to store in CPU memory")
-	pflag.Int64Var(&s.maxRunningReqs, "max-running-requests", 5, "Maximum number of inference requests that could be processed at the same time (parameter to simulate requests waiting queue)")
+	f.StringVar(&lorasStr, "lora", "", "List of LoRA adapters, separated by comma")
+	f.IntVar(&s.maxLoras, "max-loras", 1, "Maximum number of LoRAs in a single batch")
+	f.IntVar(&s.maxCpuLoras, "max-cpu-loras", 0, "Maximum number of LoRAs to store in CPU memory")
+	f.Int64Var(&s.maxRunningReqs, "max-running-requests", 5, "Maximum number of inference requests that could be processed at the same time (parameter to simulate requests waiting queue)")
 
-	pflag.Parse()
+	if err := f.Parse(os.Args[1:]); err != nil {
+		return err
+	}
 
 	loras := strings.Split(lorasStr, ",")
 	for _, lora := range loras {
@@ -120,8 +131,17 @@ func (s *VllmSimulator) parseCommandParams() error {
 	return nil
 }
 
+func (s *VllmSimulator) newListener() (net.Listener, error) {
+	s.logger.Info("Server starting", "port", s.port)
+	listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", s.port))
+	if err != nil {
+		return nil, err
+	}
+	return listener, nil
+}
+
 // startServer starts http server on port defined in command line
-func (s *VllmSimulator) startServer() error {
+func (s *VllmSimulator) startServer(listener net.Listener) error {
 	r := fasthttprouter.New()
 
 	// support completion APIs
@@ -141,11 +161,6 @@ func (s *VllmSimulator) startServer() error {
 		Logger: s,
 	}
 
-	s.logger.Info("Server starting", "port", s.port)
-	listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", s.port))
-	if err != nil {
-		return err
-	}
 	defer func() {
 		if err := listener.Close(); err != nil {
 			s.logger.Error(err, "server listener close failed")
@@ -308,12 +323,22 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
 			atomic.AddInt64(&(s.nRunningReqs), 1)
 			s.reportRunningRequests()
 
-			responseTxt := req.createResponseText(s.mode)
-
-			if req.isStream() {
-				s.sendStreamingResponse(reqCtx.isChatCompletion, reqCtx.httpReqCtx, responseTxt, model)
+			responseTxt, err := req.createResponseText(s.mode)
+			if err != nil {
+				prefix := ""
+				if reqCtx.isChatCompletion {
+					prefix = "failed to create chat response"
+				} else {
+					prefix = "failed to create text response"
+				}
+				s.logger.Error(err, prefix)
+				reqCtx.httpReqCtx.Error(prefix+err.Error(), fasthttp.StatusBadRequest)
 			} else {
-				s.sendResponse(reqCtx.isChatCompletion, reqCtx.httpReqCtx, responseTxt, model)
+				if req.isStream() {
+					s.sendStreamingResponse(reqCtx.isChatCompletion, reqCtx.httpReqCtx, responseTxt, model)
+				} else {
+					s.sendResponse(reqCtx.isChatCompletion, reqCtx.httpReqCtx, responseTxt, model)
+				}
 			}
 			reqCtx.wg.Done()
 		}
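
Example (not part of this commit): a hypothetical end-to-end check of the new validation path. A chat completion whose max_tokens is below 1 should now be rejected with 400 Bad Request. It assumes the simulator is running locally on the vLLM default port 8000 and serves the OpenAI-compatible /v1/chat/completions path; the model name is a placeholder.

package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// max_tokens of 0 fails getMaxTokens and should surface as 400 Bad Request.
	body := []byte(`{"model": "my-model", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 0}`)
	resp, err := http.Post("http://localhost:8000/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status) // expected: 400 Bad Request
}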
