Skip to content

Commit eca9b51

Browse files
authored
Support minItems and maxItems for array parameters in tools (#67)
* Support minItems and maxItems for array parameters in tools Signed-off-by: Ira <IRAR@il.ibm.com> * Check that minItems is not grater than maxItems. Check the errors properly Signed-off-by: Ira <IRAR@il.ibm.com> --------- Signed-off-by: Ira <IRAR@il.ibm.com>
1 parent a1894f6 commit eca9b51

File tree

4 files changed

+100
-21
lines changed

4 files changed

+100
-21
lines changed

pkg/llm-d-inference-sim/simulator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
400400
if reqCtx.isChatCompletion && req.getToolChoice() != toolChoiceNone && req.getTools() != nil {
401401
toolCalls, finishReason, completionTokens, err = createToolCalls(req.getTools(), req.getToolChoice())
402402
}
403-
if toolCalls == nil {
403+
if toolCalls == nil && err == nil {
404404
// Either no tool calls were defined, or we randomly chose not to create tool calls,
405405
// so we generate a response text.
406406
responseTokens, finishReason, completionTokens, err = req.createResponseText(s.mode)

pkg/llm-d-inference-sim/tools_test.go

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,16 @@ var toolWith3DArray = []openai.ChatCompletionToolParam{
160160
"type": "object",
161161
"properties": map[string]interface{}{
162162
"tensor": map[string]interface{}{
163-
"type": "array",
163+
"type": "array",
164+
"minItems": 2,
164165
"items": map[string]any{
165-
"type": "array",
166+
"type": "array",
167+
"minItems": 0,
168+
"maxItems": 1,
166169
"items": map[string]any{
167-
"type": "array",
168-
"items": map[string]string{"type": "string"},
170+
"type": "array",
171+
"items": map[string]string{"type": "string"},
172+
"maxItems": 3,
169173
},
170174
},
171175
"description": "List of strings",
@@ -177,6 +181,28 @@ var toolWith3DArray = []openai.ChatCompletionToolParam{
177181
},
178182
}
179183

184+
var toolWithWrongMinMax = []openai.ChatCompletionToolParam{
185+
{
186+
Function: openai.FunctionDefinitionParam{
187+
Name: "multiply_numbers",
188+
Description: openai.String("Multiply an array of numbers"),
189+
Parameters: openai.FunctionParameters{
190+
"type": "object",
191+
"properties": map[string]interface{}{
192+
"numbers": map[string]interface{}{
193+
"type": "array",
194+
"items": map[string]string{"type": "number"},
195+
"description": "List of numbers to multiply",
196+
"minItems": 3,
197+
"maxItems": 1,
198+
},
199+
},
200+
"required": []string{"numbers"},
201+
},
202+
},
203+
},
204+
}
205+
180206
var toolWithObjects = []openai.ChatCompletionToolParam{
181207
{
182208
Function: openai.FunctionDefinitionParam{
@@ -525,6 +551,14 @@ var _ = Describe("Simulator for request with tools", func() {
525551
err = json.Unmarshal([]byte(tc.Function.Arguments), &args)
526552
Expect(err).NotTo(HaveOccurred())
527553
Expect(args["tensor"]).ToNot(BeEmpty())
554+
tensor := args["tensor"]
555+
Expect(len(tensor)).To(BeNumerically(">=", 2))
556+
for _, elem := range tensor {
557+
Expect(len(elem)).To(Or(Equal(0), Equal(1)))
558+
for _, inner := range elem {
559+
Expect(len(inner)).To(Or(Equal(1), Equal(2), Equal(3)))
560+
}
561+
}
528562
},
529563
func(mode string) string {
530564
return "mode: " + mode
@@ -536,6 +570,32 @@ var _ = Describe("Simulator for request with tools", func() {
536570
Entry(nil, modeRandom),
537571
)
538572

573+
DescribeTable("array parameter with wrong min and max items, no streaming",
574+
func(mode string) {
575+
ctx := context.TODO()
576+
client, err := startServer(ctx, mode)
577+
Expect(err).NotTo(HaveOccurred())
578+
579+
openaiclient := openai.NewClient(
580+
option.WithBaseURL(baseURL),
581+
option.WithHTTPClient(client))
582+
583+
params := openai.ChatCompletionNewParams{
584+
Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)},
585+
Model: model,
586+
ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
587+
Tools: toolWithWrongMinMax,
588+
}
589+
590+
_, err = openaiclient.Chat.Completions.New(ctx, params)
591+
Expect(err).To(HaveOccurred())
592+
},
593+
func(mode string) string {
594+
return "mode: " + mode
595+
},
596+
Entry(nil, modeRandom),
597+
)
598+
539599
DescribeTable("objects, no streaming",
540600
func(mode string) {
541601
ctx := context.TODO()

pkg/llm-d-inference-sim/tools_utils.go

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,19 @@ func createToolCalls(tools []tool, toolChoice string) ([]toolCall, string, int,
5353
// In case of 'required' at least one tool call has to be created, and we randomly choose
5454
// the number of calls starting from one. Otherwise, we start from 0, and in case we randomly
5555
// choose the number of calls to be 0, response text will be generated instead of a tool call.
56-
numberOfCalls := randomInt(len(tools), toolChoice == toolChoiceRequired)
56+
min := 0
57+
if toolChoice == toolChoiceRequired {
58+
min = 1
59+
}
60+
numberOfCalls := randomInt(min, len(tools))
5761
if numberOfCalls == 0 {
5862
return nil, "", 0, nil
5963
}
6064

6165
calls := make([]toolCall, 0)
6266
for i := range numberOfCalls {
6367
// Randomly choose which tools to call. We may call the same tool more than once.
64-
index := randomInt(len(tools)-1, false)
68+
index := randomInt(0, len(tools)-1)
6569
args, err := generateToolArguments(tools[index])
6670
if err != nil {
6771
return nil, "", 0, err
@@ -130,7 +134,7 @@ func createArgument(property any) (any, error) {
130134
if ok {
131135
enumArray, ok := enum.([]any)
132136
if ok && len(enumArray) > 0 {
133-
index := randomInt(len(enumArray)-1, false)
137+
index := randomInt(0, len(enumArray)-1)
134138
return enumArray[index], nil
135139
}
136140
}
@@ -139,13 +143,24 @@ func createArgument(property any) (any, error) {
139143
case "string":
140144
return getStringArgument(), nil
141145
case "number":
142-
return randomInt(100, false), nil
146+
return randomInt(0, 100), nil
143147
case "boolean":
144148
return flipCoin(), nil
145149
case "array":
146150
items := propertyMap["items"]
147151
itemsMap := items.(map[string]any)
148-
numberOfElements := randomInt(5, true)
152+
minItems := 1
153+
maxItems := 5
154+
if value, ok := propertyMap["minItems"]; ok {
155+
minItems = int(value.(float64))
156+
}
157+
if value, ok := propertyMap["maxItems"]; ok {
158+
maxItems = int(value.(float64))
159+
}
160+
if minItems > maxItems {
161+
return nil, fmt.Errorf("minItems (%d) is greater than maxItems(%d)", minItems, maxItems)
162+
}
163+
numberOfElements := randomInt(minItems, maxItems)
149164
array := make([]any, numberOfElements)
150165
for i := range numberOfElements {
151166
elem, err := createArgument(itemsMap)
@@ -177,7 +192,7 @@ func createArgument(property any) (any, error) {
177192
}
178193

179194
func getStringArgument() string {
180-
index := randomInt(len(fakeStringArguments)-1, false)
195+
index := randomInt(0, len(fakeStringArguments)-1)
181196
return fakeStringArguments[index]
182197
}
183198

@@ -336,6 +351,14 @@ const schema = `{
336351
"items": {
337352
"type": "string"
338353
}
354+
},
355+
"minItems": {
356+
"type": "integer",
357+
"minimum": 0
358+
},
359+
"maxItems": {
360+
"type": "integer",
361+
"minimum": 0
339362
}
340363
},
341364
"required": [

pkg/llm-d-inference-sim/utils.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func getMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error)
6262
// getRandomResponseText returns random response text from the pre-defined list of responses
6363
// considering max completion tokens if it is not nil, and a finish reason (stop or length)
6464
func getRandomResponseText(maxCompletionTokens *int64) (string, string) {
65-
index := randomInt(len(chatCompletionFakeResponses)-1, false)
65+
index := randomInt(0, len(chatCompletionFakeResponses)-1)
6666
text := chatCompletionFakeResponses[index]
6767

6868
return getResponseText(maxCompletionTokens, text)
@@ -105,26 +105,22 @@ func randomNumericString(length int) string {
105105
digits := "0123456789"
106106
result := make([]byte, length)
107107
for i := 0; i < length; i++ {
108-
num := randomInt(9, false)
108+
num := randomInt(0, 9)
109109
result[i] = digits[num]
110110
}
111111
return string(result)
112112
}
113113

114-
// Returns an integer between 0 and max (included), unless startFromeOne is true,
115-
// in which case returns an integer between 1 and max (included)
116-
func randomInt(max int, startFromOne bool) int {
114+
// Returns an integer between min and max (included)
115+
func randomInt(min int, max int) int {
117116
src := rand.NewSource(time.Now().UnixNano())
118117
r := rand.New(src)
119-
if startFromOne {
120-
return r.Intn(max) + 1 // [1, max]
121-
}
122-
return r.Intn(max + 1) // [0, max]
118+
return r.Intn(max-min+1) + min
123119
}
124120

125121
// Returns true or false randomly
126122
func flipCoin() bool {
127-
return randomInt(1, false) != 0
123+
return randomInt(0, 1) != 0
128124
}
129125

130126
// Regular expression for the response tokenization

0 commit comments

Comments
 (0)