Skip to content

Commit 8e9b2ac

Browse files
authored
fix: properly unmarshal JSON schema in ChatCompletionResponseFormatJSONSchema.schema (#1028)
* feat: #1027 * add tests * feat: #1027 * feat: #1027 * feat: #1027 * update chat_test.go * feat: #1027 * add test cases
1 parent c125ae2 commit 8e9b2ac

File tree

2 files changed

+166
-0
lines changed

2 files changed

+166
-0
lines changed

chat.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import (
55
"encoding/json"
66
"errors"
77
"net/http"
8+
9+
"github.com/sashabaranov/go-openai/jsonschema"
810
)
911

1012
// Chat message role defined by the OpenAI API.
@@ -221,6 +223,31 @@ type ChatCompletionResponseFormatJSONSchema struct {
221223
Strict bool `json:"strict"`
222224
}
223225

226+
func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) error {
227+
type rawJSONSchema struct {
228+
Name string `json:"name"`
229+
Description string `json:"description,omitempty"`
230+
Schema json.RawMessage `json:"schema"`
231+
Strict bool `json:"strict"`
232+
}
233+
var raw rawJSONSchema
234+
if err := json.Unmarshal(data, &raw); err != nil {
235+
return err
236+
}
237+
r.Name = raw.Name
238+
r.Description = raw.Description
239+
r.Strict = raw.Strict
240+
if len(raw.Schema) > 0 && string(raw.Schema) != "null" {
241+
var d jsonschema.Definition
242+
err := json.Unmarshal(raw.Schema, &d)
243+
if err != nil {
244+
return err
245+
}
246+
r.Schema = &d
247+
}
248+
return nil
249+
}
250+
224251
// ChatCompletionRequest represents a request structure for chat completion API.
225252
type ChatCompletionRequest struct {
226253
Model string `json:"model"`

chat_test.go

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -946,3 +946,142 @@ func TestFinishReason(t *testing.T) {
946946
}
947947
}
948948
}
949+
950+
func TestChatCompletionResponseFormatJSONSchema_UnmarshalJSON(t *testing.T) {
951+
type args struct {
952+
data []byte
953+
}
954+
tests := []struct {
955+
name string
956+
args args
957+
wantErr bool
958+
}{
959+
{
960+
"",
961+
args{
962+
data: []byte(`{
963+
"name": "math_response",
964+
"strict": true,
965+
"schema": {
966+
"type": "object",
967+
"properties": {
968+
"steps": {
969+
"type": "array",
970+
"items": {
971+
"type": "object",
972+
"properties": {
973+
"explanation": { "type": "string" },
974+
"output": { "type": "string" }
975+
},
976+
"required": ["explanation","output"],
977+
"additionalProperties": false
978+
}
979+
},
980+
"final_answer": { "type": "string" }
981+
},
982+
"required": ["steps","final_answer"],
983+
"additionalProperties": false
984+
}
985+
}`),
986+
},
987+
false,
988+
},
989+
{
990+
"",
991+
args{
992+
data: []byte(`{
993+
"name": "math_response",
994+
"strict": true,
995+
"schema": null
996+
}`),
997+
},
998+
false,
999+
},
1000+
{
1001+
"",
1002+
args{
1003+
data: []byte(`[123,456]`),
1004+
},
1005+
true,
1006+
},
1007+
{
1008+
"",
1009+
args{
1010+
data: []byte(`{
1011+
"name": "math_response",
1012+
"strict": true,
1013+
"schema": 123456
1014+
}`),
1015+
},
1016+
true,
1017+
},
1018+
}
1019+
for _, tt := range tests {
1020+
t.Run(tt.name, func(t *testing.T) {
1021+
var r openai.ChatCompletionResponseFormatJSONSchema
1022+
err := r.UnmarshalJSON(tt.args.data)
1023+
if (err != nil) != tt.wantErr {
1024+
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
1025+
}
1026+
})
1027+
}
1028+
}
1029+
1030+
func TestChatCompletionRequest_UnmarshalJSON(t *testing.T) {
1031+
type args struct {
1032+
bs []byte
1033+
}
1034+
tests := []struct {
1035+
name string
1036+
args args
1037+
wantErr bool
1038+
}{
1039+
{
1040+
"",
1041+
args{bs: []byte(`{
1042+
"model": "llama3-1b",
1043+
"messages": [
1044+
{ "role": "system", "content": "You are a helpful math tutor." },
1045+
{ "role": "user", "content": "solve 8x + 31 = 2" }
1046+
],
1047+
"response_format": {
1048+
"type": "json_schema",
1049+
"json_schema": {
1050+
"name": "math_response",
1051+
"strict": true,
1052+
"schema": {
1053+
"type": "object",
1054+
"properties": {
1055+
"steps": {
1056+
"type": "array",
1057+
"items": {
1058+
"type": "object",
1059+
"properties": {
1060+
"explanation": { "type": "string" },
1061+
"output": { "type": "string" }
1062+
},
1063+
"required": ["explanation","output"],
1064+
"additionalProperties": false
1065+
}
1066+
},
1067+
"final_answer": { "type": "string" }
1068+
},
1069+
"required": ["steps","final_answer"],
1070+
"additionalProperties": false
1071+
}
1072+
}
1073+
}
1074+
}`)},
1075+
false,
1076+
},
1077+
}
1078+
for _, tt := range tests {
1079+
t.Run(tt.name, func(t *testing.T) {
1080+
var m openai.ChatCompletionRequest
1081+
err := json.Unmarshal(tt.args.bs, &m)
1082+
if err != nil {
1083+
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
1084+
}
1085+
})
1086+
}
1087+
}

0 commit comments

Comments
 (0)