
Commit 69b1536

misberner authored and malte-prophet committed
Add a TemperatureOpt field that can be set to explicit zero.
1 parent 93a611c · commit 69b1536

File tree

chat.go
chat_test.go
reasoning_validator.go

3 files changed (+153 −11 lines changed)


chat.go

Lines changed: 51 additions & 10 deletions

@@ -232,16 +232,21 @@ type ChatCompletionRequest struct {
     MaxTokens int `json:"max_tokens,omitempty"`
     // MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion,
     // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning
-    MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
-    Temperature float32 `json:"temperature,omitempty"`
-    TopP float32 `json:"top_p,omitempty"`
-    N int `json:"n,omitempty"`
-    Stream bool `json:"stream,omitempty"`
-    Stop []string `json:"stop,omitempty"`
-    PresencePenalty float32 `json:"presence_penalty,omitempty"`
-    ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"`
-    Seed *int `json:"seed,omitempty"`
-    FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
+    MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
+
+    // Deprecated: Use TemperatureOpt instead. When TemperatureOpt is set, Temperature is ignored
+    // regardless of its value. Otherwise (if TemperatureOpt is nil), Temperature is used when
+    // non-zero.
+    Temperature    float32  `json:"-"`
+    TemperatureOpt *float32 `json:"temperature,omitempty"`
+    TopP float32 `json:"top_p,omitempty"`
+    N int `json:"n,omitempty"`
+    Stream bool `json:"stream,omitempty"`
+    Stop []string `json:"stop,omitempty"`
+    PresencePenalty float32 `json:"presence_penalty,omitempty"`
+    ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"`
+    Seed *int `json:"seed,omitempty"`
+    FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
     // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string.
     // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}`
     // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias
@@ -277,6 +282,42 @@ type ChatCompletionRequest struct {
     Prediction *Prediction `json:"prediction,omitempty"`
 }
 
+func (r *ChatCompletionRequest) UnmarshalJSON(data []byte) error {
+    type plainChatCompletionRequest ChatCompletionRequest
+    if err := json.Unmarshal(data, (*plainChatCompletionRequest)(r)); err != nil {
+        return err
+    }
+    if r.TemperatureOpt != nil {
+        r.Temperature = *r.TemperatureOpt
+        // Link TemperatureOpt to temperature. This ensures that code modifying temperature
+        // after unmarshaling (i.e., when TemperatureOpt might be set) will continue to
+        // work correctly.
+        r.TemperatureOpt = &r.Temperature
+    } else if r.Temperature != 0 {
+        r.TemperatureOpt = &r.Temperature
+    }
+    return nil
+}
+
+func (r ChatCompletionRequest) MarshalJSON() ([]byte, error) {
+    type plainChatCompletionRequest ChatCompletionRequest
+    plainR := plainChatCompletionRequest(r)
+    if plainR.TemperatureOpt == nil && plainR.Temperature != 0 {
+        plainR.TemperatureOpt = &plainR.Temperature
+    }
+    return json.Marshal(&plainR)
+}
+
+func (r *ChatCompletionRequest) GetTemperature() *float32 {
+    if r.TemperatureOpt != nil {
+        return r.TemperatureOpt
+    }
+    if r.Temperature != 0 {
+        return &r.Temperature
+    }
+    return nil
+}
+
 type StreamOptions struct {
     // If set, an additional chunk will be streamed before the data: [DONE] message.
     // The usage field on this chunk shows the token usage statistics for the entire request,
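For context, here is a short usage sketch of the new field. It is not part of the commit, and it assumes the module is imported as github.com/sashabaranov/go-openai (matching the openai package qualifier used in the tests below). With the legacy Temperature field, a zero value is indistinguishable from "not set" and is dropped from the JSON; pointing TemperatureOpt at an explicit zero keeps it in the serialized request.

package main

import (
    "encoding/json"
    "fmt"

    openai "github.com/sashabaranov/go-openai" // assumed import path, not shown in this commit
)

func main() {
    zero := float32(0)

    // Legacy field: a zero Temperature is treated as "unset" and omitted from the JSON.
    legacy := openai.ChatCompletionRequest{Temperature: 0}

    // New field: an explicit pointer to zero is preserved as "temperature":0.
    explicit := openai.ChatCompletionRequest{TemperatureOpt: &zero}

    legacyJSON, _ := json.Marshal(legacy)
    explicitJSON, _ := json.Marshal(explicit)

    fmt.Println(string(legacyJSON))   // no "temperature" key
    fmt.Println(string(explicitJSON)) // contains "temperature":0
}
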
chat_test.go

Lines changed: 101 additions & 0 deletions

@@ -123,6 +123,23 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
             },
             expectedError: openai.ErrReasoningModelLimitationsOther,
         },
+        {
+            name: "set_temperature_unsupported_new",
+            in: openai.ChatCompletionRequest{
+                MaxCompletionTokens: 1000,
+                Model:               openai.O1Mini,
+                Messages: []openai.ChatCompletionMessage{
+                    {
+                        Role: openai.ChatMessageRoleUser,
+                    },
+                    {
+                        Role: openai.ChatMessageRoleAssistant,
+                    },
+                },
+                TemperatureOpt: &[]float32{2}[0],
+            },
+            expectedError: openai.ErrReasoningModelLimitationsOther,
+        },
         {
             name: "set_top_unsupported",
             in: openai.ChatCompletionRequest{
@@ -946,3 +963,87 @@ func TestFinishReason(t *testing.T) {
         }
     }
 }
+
+func TestTemperature(t *testing.T) {
+    tests := []struct {
+        name                string
+        in                  openai.ChatCompletionRequest
+        expectedTemperature *float32
+    }{
+        {
+            name:                "not_set",
+            in:                  openai.ChatCompletionRequest{},
+            expectedTemperature: nil,
+        },
+        {
+            name: "set_legacy",
+            in: openai.ChatCompletionRequest{
+                Temperature: 0.5,
+            },
+            expectedTemperature: &[]float32{0.5}[0],
+        },
+        {
+            name: "set_new",
+            in: openai.ChatCompletionRequest{
+                TemperatureOpt: &[]float32{0.5}[0],
+            },
+            expectedTemperature: &[]float32{0.5}[0],
+        },
+        {
+            name: "set_both",
+            in: openai.ChatCompletionRequest{
+                Temperature:    0.4,
+                TemperatureOpt: &[]float32{0.5}[0],
+            },
+            expectedTemperature: &[]float32{0.5}[0],
+        },
+    }
+
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
+            data, err := json.Marshal(tt.in)
+            checks.NoError(t, err, "failed to marshal request to JSON")
+
+            var req openai.ChatCompletionRequest
+            err = json.Unmarshal(data, &req)
+            checks.NoError(t, err, "failed to unmarshal request from JSON")
+
+            temp := req.GetTemperature()
+            if tt.expectedTemperature == nil {
+                if temp != nil {
+                    t.Error("expected temperature to be nil")
+                }
+            } else {
+                if temp == nil {
+                    t.Error("expected temperature to be set")
+                } else if *tt.expectedTemperature != *temp {
+                    t.Errorf("expected temperature to be %v but was %v", *tt.expectedTemperature, *temp)
+                }
+            }
+        })
+    }
+}
+
+func TestTemperature_ModifyLegacyAfterUnmarshal(t *testing.T) {
+    req := openai.ChatCompletionRequest{
+        TemperatureOpt: &[]float32{0.5}[0],
+    }
+
+    data, err := json.Marshal(req)
+    checks.NoError(t, err, "failed to marshal request to JSON")
+
+    var req2 openai.ChatCompletionRequest
+    err = json.Unmarshal(data, &req2)
+    checks.NoError(t, err, "failed to unmarshal request from JSON")
+
+    if temp := req2.GetTemperature(); temp == nil || *temp != 0.5 {
+        t.Errorf("expected temperature to be 0.5 but was %v", temp)
+    }
+
+    // Modify the legacy temperature field
+    req2.Temperature = 0.4
+
+    if temp := req2.GetTemperature(); temp == nil || *temp != 0.4 {
+        t.Errorf("expected temperature to be 0.4 but was %v", temp)
+    }
+}

reasoning_validator.go

Lines changed: 1 addition & 1 deletion

@@ -61,7 +61,7 @@ func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletion
     if request.LogProbs {
         return ErrReasoningModelLimitationsLogprobs
     }
-    if request.Temperature > 0 && request.Temperature != 1 {
+    if temp := request.GetTemperature(); temp != nil && *temp != 1 {
         return ErrReasoningModelLimitationsOther
     }
     if request.TopP > 0 && request.TopP != 1 {
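The practical effect of this one-line change: an unset or legacy-zero temperature still passes the reasoning-model validation (GetTemperature returns nil), while a temperature explicitly set through TemperatureOpt to anything other than 1, including zero, is now reported as ErrReasoningModelLimitationsOther. A small sketch of that behavior, again my own example rather than part of the commit, with the import path assumed as above:

package main

import (
    "fmt"

    openai "github.com/sashabaranov/go-openai" // assumed import path
)

// rejected mirrors the condition added above: any explicitly set temperature
// other than 1 is considered unsupported for reasoning models.
func rejected(req openai.ChatCompletionRequest) bool {
    temp := req.GetTemperature()
    return temp != nil && *temp != 1
}

func main() {
    zero := float32(0)

    fmt.Println(rejected(openai.ChatCompletionRequest{}))                      // false: nothing set
    fmt.Println(rejected(openai.ChatCompletionRequest{Temperature: 0}))        // false: legacy zero still means "unset"
    fmt.Println(rejected(openai.ChatCompletionRequest{TemperatureOpt: &zero})) // true: explicit zero differs from 1
}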
