
Commit cdb4504

Merge remote-tracking branch 'upstream/master' into refactor/flip-completions-mapping

2 parents: 5b57388 + 93a611c

13 files changed: 267 additions & 46 deletions

.golangci.yml

Lines changed: 0 additions & 1 deletion
```diff
@@ -149,7 +149,6 @@ linters:
     - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
     - ineffassign # Detects when assignments to existing variables are not used
     - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
-    - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code
     - unused # Checks Go code for unused constants, variables, functions and types
     ## disabled by default
     # - asasalint # Check for pass []any as any in variadic func(...any)
```

chat.go

Lines changed: 49 additions & 32 deletions
```diff
@@ -104,6 +104,12 @@ type ChatCompletionMessage struct {
 	// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 	Name string `json:"name,omitempty"`
 
+	// ReasoningContent is used for the "reasoning" feature supported by deepseek-reasoner,
+	// which is not in the official OpenAI documentation.
+	// The doc from deepseek:
+	// - https://api-docs.deepseek.com/api/create-chat-completion#responses
+	ReasoningContent string `json:"reasoning_content,omitempty"`
+
 	FunctionCall *FunctionCall `json:"function_call,omitempty"`
 
 	// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
@@ -119,56 +125,60 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
 	}
 	if len(m.MultiContent) > 0 {
 		msg := struct {
-			Role         string            `json:"role"`
-			Content      string            `json:"-"`
-			Refusal      string            `json:"refusal,omitempty"`
-			MultiContent []ChatMessagePart `json:"content,omitempty"`
-			Name         string            `json:"name,omitempty"`
-			FunctionCall *FunctionCall     `json:"function_call,omitempty"`
-			ToolCalls    []ToolCall        `json:"tool_calls,omitempty"`
-			ToolCallID   string            `json:"tool_call_id,omitempty"`
+			Role             string            `json:"role"`
+			Content          string            `json:"-"`
+			Refusal          string            `json:"refusal,omitempty"`
+			MultiContent     []ChatMessagePart `json:"content,omitempty"`
+			Name             string            `json:"name,omitempty"`
+			ReasoningContent string            `json:"reasoning_content,omitempty"`
+			FunctionCall     *FunctionCall     `json:"function_call,omitempty"`
+			ToolCalls        []ToolCall        `json:"tool_calls,omitempty"`
+			ToolCallID       string            `json:"tool_call_id,omitempty"`
 		}(m)
 		return json.Marshal(msg)
 	}
 
 	msg := struct {
-		Role         string            `json:"role"`
-		Content      string            `json:"content,omitempty"`
-		Refusal      string            `json:"refusal,omitempty"`
-		MultiContent []ChatMessagePart `json:"-"`
-		Name         string            `json:"name,omitempty"`
-		FunctionCall *FunctionCall     `json:"function_call,omitempty"`
-		ToolCalls    []ToolCall        `json:"tool_calls,omitempty"`
-		ToolCallID   string            `json:"tool_call_id,omitempty"`
+		Role             string            `json:"role"`
+		Content          string            `json:"content,omitempty"`
+		Refusal          string            `json:"refusal,omitempty"`
+		MultiContent     []ChatMessagePart `json:"-"`
+		Name             string            `json:"name,omitempty"`
+		ReasoningContent string            `json:"reasoning_content,omitempty"`
+		FunctionCall     *FunctionCall     `json:"function_call,omitempty"`
+		ToolCalls        []ToolCall        `json:"tool_calls,omitempty"`
+		ToolCallID       string            `json:"tool_call_id,omitempty"`
 	}(m)
 	return json.Marshal(msg)
 }
 
 func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
 	msg := struct {
-		Role         string `json:"role"`
-		Content      string `json:"content,omitempty"`
-		Refusal      string `json:"refusal,omitempty"`
-		MultiContent []ChatMessagePart
-		Name         string        `json:"name,omitempty"`
-		FunctionCall *FunctionCall `json:"function_call,omitempty"`
-		ToolCalls    []ToolCall    `json:"tool_calls,omitempty"`
-		ToolCallID   string        `json:"tool_call_id,omitempty"`
+		Role             string `json:"role"`
+		Content          string `json:"content"`
+		Refusal          string `json:"refusal,omitempty"`
+		MultiContent     []ChatMessagePart
+		Name             string        `json:"name,omitempty"`
+		ReasoningContent string        `json:"reasoning_content,omitempty"`
+		FunctionCall     *FunctionCall `json:"function_call,omitempty"`
+		ToolCalls        []ToolCall    `json:"tool_calls,omitempty"`
+		ToolCallID       string        `json:"tool_call_id,omitempty"`
 	}{}
 
 	if err := json.Unmarshal(bs, &msg); err == nil {
 		*m = ChatCompletionMessage(msg)
 		return nil
 	}
 	multiMsg := struct {
-		Role         string `json:"role"`
-		Content      string
-		Refusal      string            `json:"refusal,omitempty"`
-		MultiContent []ChatMessagePart `json:"content"`
-		Name         string            `json:"name,omitempty"`
-		FunctionCall *FunctionCall     `json:"function_call,omitempty"`
-		ToolCalls    []ToolCall        `json:"tool_calls,omitempty"`
-		ToolCallID   string            `json:"tool_call_id,omitempty"`
+		Role             string `json:"role"`
+		Content          string
+		Refusal          string            `json:"refusal,omitempty"`
+		MultiContent     []ChatMessagePart `json:"content"`
+		Name             string            `json:"name,omitempty"`
+		ReasoningContent string            `json:"reasoning_content,omitempty"`
+		FunctionCall     *FunctionCall     `json:"function_call,omitempty"`
+		ToolCalls        []ToolCall        `json:"tool_calls,omitempty"`
+		ToolCallID       string            `json:"tool_call_id,omitempty"`
 	}{}
 	if err := json.Unmarshal(bs, &multiMsg); err != nil {
 		return err
@@ -263,6 +273,8 @@ type ChatCompletionRequest struct {
 	ReasoningEffort string `json:"reasoning_effort,omitempty"`
 	// Metadata to store with the completion.
 	Metadata map[string]string `json:"metadata,omitempty"`
+	// Configuration for a predicted output.
+	Prediction *Prediction `json:"prediction,omitempty"`
 }
 
 type StreamOptions struct {
@@ -330,6 +342,11 @@ type LogProbs struct {
 	Content []LogProb `json:"content"`
 }
 
+type Prediction struct {
+	Content string `json:"content"`
+	Type    string `json:"type"`
+}
+
 type FinishReason string
 
 const (
```
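For context, a minimal sketch of how the new field might be consumed from a non-streaming response. It assumes these files belong to the go-openai client library and that the request is pointed at deepseek's OpenAI-compatible endpoint; the base URL and API key below are placeholders, while the "deepseek-reasoner" model string is taken from the tests in this commit:

```go
package main

import (
	"context"
	"fmt"
	"log"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// Assumed deepseek-compatible endpoint; both values are placeholders.
	cfg := openai.DefaultConfig("your-api-key")
	cfg.BaseURL = "https://api.deepseek.com/v1"
	client := openai.NewClientWithConfig(cfg)

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: "deepseek-reasoner",
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		log.Fatal(err)
	}
	// ReasoningContent carries the chain-of-thought text; Content the final answer.
	fmt.Println("reasoning:", resp.Choices[0].Message.ReasoningContent)
	fmt.Println("answer:", resp.Choices[0].Message.Content)
}
```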

chat_stream.go

Lines changed: 6 additions & 0 deletions
```diff
@@ -11,6 +11,12 @@ type ChatCompletionStreamChoiceDelta struct {
 	FunctionCall *FunctionCall `json:"function_call,omitempty"`
 	ToolCalls    []ToolCall    `json:"tool_calls,omitempty"`
 	Refusal      string        `json:"refusal,omitempty"`
+
+	// ReasoningContent is used for the "reasoning" feature supported by deepseek-reasoner,
+	// which is not in the official OpenAI documentation.
+	// The doc from deepseek:
+	// - https://api-docs.deepseek.com/api/create-chat-completion#responses
+	ReasoningContent string `json:"reasoning_content,omitempty"`
 }
 
 type ChatCompletionStreamChoiceLogprobs struct {
```
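Continuing the sketch above with the same assumed client (`errors` and `io` imports elided), the streaming variant surfaces the field on each delta, so reasoning tokens and answer tokens arrive separately:

```go
stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
	Model:  "deepseek-reasoner",
	Stream: true,
	Messages: []openai.ChatCompletionMessage{
		{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
	},
})
if err != nil {
	log.Fatal(err)
}
defer stream.Close()

for {
	resp, err := stream.Recv()
	if errors.Is(err, io.EOF) {
		break
	}
	if err != nil {
		log.Fatal(err)
	}
	if len(resp.Choices) > 0 {
		// Reasoning streams first, then the answer tokens.
		fmt.Print(resp.Choices[0].Delta.ReasoningContent)
		fmt.Print(resp.Choices[0].Delta.Content)
	}
}
```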

chat_stream_test.go

Lines changed: 50 additions & 0 deletions
```diff
@@ -959,6 +959,56 @@ func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
 	}
 }
 
+func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) {
+	client, _, _ := setupOpenAITestServer()
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
+		MaxTokens: 100, // This will trigger the validator to fail
+		Model:     openai.O3,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+
+	if stream != nil {
+		t.Error("Expected nil stream when validation fails")
+		stream.Close()
+	}
+
+	if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
+		t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err)
+	}
+}
+
+func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) {
+	client, _, _ := setupOpenAITestServer()
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
+		MaxTokens: 100, // This will trigger the validator to fail
+		Model:     openai.O4Mini,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+
+	if stream != nil {
+		t.Error("Expected nil stream when validation fails")
+		stream.Close()
+	}
+
+	if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
+		t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err)
+	}
+}
+
 func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
 	if c1.Index != c2.Index {
 		return false
```
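The two tests above pin down the validator's contract: reasoning-family models reject the deprecated MaxTokens field with ErrReasoningModelMaxTokensDeprecated. A sketch of the passing form of the same request, using MaxCompletionTokens as the deepseek test in chat_test.go below does:

```go
req := openai.ChatCompletionRequest{
	Model:               openai.O3,
	MaxCompletionTokens: 100, // MaxTokens: 100 would fail validation for o3 / o4-mini
	Messages: []openai.ChatCompletionMessage{
		{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
	},
	Stream: true,
}
```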

chat_test.go

Lines changed: 79 additions & 0 deletions
```diff
@@ -411,6 +411,23 @@ func TestO3ModelChatCompletions(t *testing.T) {
 	checks.NoError(t, err, "CreateChatCompletion error")
 }
 
+func TestDeepseekR1ModelChatCompletions(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint)
+	_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
+		Model:               "deepseek-reasoner",
+		MaxCompletionTokens: 100,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+	})
+	checks.NoError(t, err, "CreateChatCompletion error")
+}
+
 // TestCompletions Tests the completions endpoint of the API using the mocked server.
 func TestChatCompletionsWithHeaders(t *testing.T) {
 	client, server, teardown := setupOpenAITestServer()
@@ -822,6 +839,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	fmt.Fprintln(w, string(resBytes))
 }
 
+func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
+	var err error
+	var resBytes []byte
+
+	// completions only accepts POST requests
+	if r.Method != "POST" {
+		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+		return
+	}
+	var completionReq openai.ChatCompletionRequest
+	if completionReq, err = getChatCompletionBody(r); err != nil {
+		http.Error(w, "could not read request", http.StatusInternalServerError)
+		return
+	}
+	res := openai.ChatCompletionResponse{
+		ID:      strconv.Itoa(int(time.Now().Unix())),
+		Object:  "test-object",
+		Created: time.Now().Unix(),
+		// would be nice to validate Model during testing, but
+		// this may not be possible with how much upkeep
+		// would be required / wouldn't make much sense
+		Model: completionReq.Model,
+	}
+	// create completions
+	n := completionReq.N
+	if n == 0 {
+		n = 1
+	}
+	if completionReq.MaxCompletionTokens == 0 {
+		completionReq.MaxCompletionTokens = 1000
+	}
+	for i := 0; i < n; i++ {
+		reasoningContent := "User says hello! And I need to reply"
+		completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent))
+		res.Choices = append(res.Choices, openai.ChatCompletionChoice{
+			Message: openai.ChatCompletionMessage{
+				Role:             openai.ChatMessageRoleAssistant,
+				ReasoningContent: reasoningContent,
+				Content:          completionStr,
+			},
+			Index: i,
+		})
+	}
+	inputTokens := numTokens(completionReq.Messages[0].Content) * n
+	completionTokens := completionReq.MaxCompletionTokens * n
+	res.Usage = openai.Usage{
+		PromptTokens:     inputTokens,
+		CompletionTokens: completionTokens,
+		TotalTokens:      inputTokens + completionTokens,
+	}
+	resBytes, _ = json.Marshal(res)
+	w.Header().Set(xCustomHeader, xCustomHeaderValue)
+	for k, v := range rateLimitHeaders {
+		switch val := v.(type) {
+		case int:
+			w.Header().Set(k, strconv.Itoa(val))
+		default:
+			w.Header().Set(k, fmt.Sprintf("%s", v))
+		}
+	}
+	fmt.Fprintln(w, string(resBytes))
+}
+
 // getChatCompletionBody Returns the body of the request to create a completion.
 func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) {
 	completion := openai.ChatCompletionRequest{}
```

common.go

Lines changed: 4 additions & 2 deletions
```diff
@@ -13,8 +13,10 @@ type Usage struct {
 
 // CompletionTokensDetails Breakdown of tokens used in a completion.
 type CompletionTokensDetails struct {
-	AudioTokens     int `json:"audio_tokens"`
-	ReasoningTokens int `json:"reasoning_tokens"`
+	AudioTokens              int `json:"audio_tokens"`
+	ReasoningTokens          int `json:"reasoning_tokens"`
+	AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
+	RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
 }
 
 // PromptTokensDetails Breakdown of tokens used in the prompt.
```
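A hedged sketch tying the new usage counters to the Prediction field added in chat.go. The "content" type string and the gpt-4o model choice are assumptions based on OpenAI's predicted-outputs feature, not on this diff; `client` is assumed to be an already-configured client:

```go
// A draft of the output you expect back mostly unchanged.
prediction := "func main() { count := 0 }"

resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
	Model: openai.GPT4o,
	Messages: []openai.ChatCompletionMessage{
		{Role: openai.ChatMessageRoleUser, Content: "Rename x to count:\n\nfunc main() { x := 0 }"},
	},
	Prediction: &openai.Prediction{Type: "content", Content: prediction},
})
if err != nil {
	log.Fatal(err)
}
// The new counters report how much of the prediction was usable.
if d := resp.Usage.CompletionTokensDetails; d != nil {
	fmt.Println("accepted prediction tokens:", d.AcceptedPredictionTokens)
	fmt.Println("rejected prediction tokens:", d.RejectedPredictionTokens)
}
```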

completion.go

Lines changed: 6 additions & 0 deletions
```diff
@@ -21,8 +21,12 @@ const (
 	O1Preview20240912 = "o1-preview-2024-09-12"
 	O1                = "o1"
 	O120241217        = "o1-2024-12-17"
+	O3                = "o3"
+	O320250416        = "o3-2025-04-16"
 	O3Mini            = "o3-mini"
 	O3Mini20250131    = "o3-mini-2025-01-31"
+	O4Mini            = "o4-mini"
+	O4Mini20250416    = "o4-mini-2025-04-16"
 	GPT432K0613       = "gpt-4-32k-0613"
 	GPT432K0314       = "gpt-4-32k-0314"
 	GPT432K           = "gpt-4-32k"
@@ -202,6 +206,8 @@ type CompletionRequest struct {
 	Temperature float32 `json:"temperature,omitempty"`
 	TopP        float32 `json:"top_p,omitempty"`
 	User        string  `json:"user,omitempty"`
+	// Options for streaming response. Only set this when you set stream: true.
+	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
 }
 
 // CompletionChoice represents one of possible completions.
```
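And a sketch of the new StreamOptions field on the legacy completions request; per the comment in the diff, it is only meaningful together with Stream: true (`client` again assumed configured, `errors` and `io` imports elided):

```go
stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{
	Model:     openai.GPT3Dot5TurboInstruct,
	Prompt:    "Say hello",
	MaxTokens: 5,
	Stream:    true,
	// Only set StreamOptions when Stream is true.
	StreamOptions: &openai.StreamOptions{IncludeUsage: true},
})
if err != nil {
	log.Fatal(err)
}
defer stream.Close()

for {
	resp, err := stream.Recv()
	if errors.Is(err, io.EOF) {
		break
	}
	if err != nil {
		log.Fatal(err)
	}
	if len(resp.Choices) > 0 {
		fmt.Print(resp.Choices[0].Text)
	}
	// With IncludeUsage set, the final chunk reports token usage.
	if resp.Usage.TotalTokens > 0 {
		fmt.Println("\ntotal tokens:", resp.Usage.TotalTokens)
	}
}
```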

completion_test.go

Lines changed: 36 additions & 0 deletions
```diff
@@ -33,6 +33,42 @@ func TestCompletionsWrongModel(t *testing.T) {
 	}
 }
 
+// TestCompletionsWrongModelO3 Tests the completions endpoint with the O3 model, which is not supported.
+func TestCompletionsWrongModelO3(t *testing.T) {
+	config := openai.DefaultConfig("whatever")
+	config.BaseURL = "http://localhost/v1"
+	client := openai.NewClientWithConfig(config)
+
+	_, err := client.CreateCompletion(
+		context.Background(),
+		openai.CompletionRequest{
+			MaxTokens: 5,
+			Model:     openai.O3,
+		},
+	)
+	if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
+		t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err)
+	}
+}
+
+// TestCompletionsWrongModelO4Mini Tests the completions endpoint with the O4Mini model, which is not supported.
+func TestCompletionsWrongModelO4Mini(t *testing.T) {
+	config := openai.DefaultConfig("whatever")
+	config.BaseURL = "http://localhost/v1"
+	client := openai.NewClientWithConfig(config)
+
+	_, err := client.CreateCompletion(
+		context.Background(),
+		openai.CompletionRequest{
+			MaxTokens: 5,
+			Model:     openai.O4Mini,
+		},
+	)
+	if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
+		t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err)
+	}
+}
+
 func TestCompletionWithStream(t *testing.T) {
 	config := openai.DefaultConfig("whatever")
 	client := openai.NewClientWithConfig(config)
```