Commit 5905fbb

Allow Anthropic custom models to override temperature (#18160)
Release Notes:

- Allow Anthropic custom models to override "temperature"

This also centralized the defaulting of "temperature" to be inside of each model's `into_x` call instead of being sprinkled around the code.
1 parent 7d62fda commit 5905fbb
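
For context, the new `default_temperature` field comes from a user-configured custom model entry. A minimal sketch of what such an entry could deserialize into, assuming a serde-based settings schema whose field names mirror the `AvailableModel` struct in this diff (the JSON shape and free-standing struct here are illustrative, not taken from the commit):

use serde::Deserialize;

// Field names mirror the AvailableModel additions in this commit; the JSON
// shape below is an illustrative assumption, not the exact Zed settings schema.
#[derive(Debug, Deserialize)]
struct AvailableModel {
    name: String,
    max_output_tokens: Option<u32>,
    default_temperature: Option<f32>,
}

fn main() {
    let entry: AvailableModel = serde_json::from_str(
        r#"{ "name": "my-claude-variant", "default_temperature": 0.2 }"#,
    )
    .unwrap();
    assert_eq!(entry.name, "my-claude-variant");
    // An entry that omits the field leaves the Option as None, so providers
    // fall back to the built-in default of 1.0 (see default_temperature() below).
    assert_eq!(entry.default_temperature, Some(0.2));
}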

File tree

12 files changed (+54 −17 lines changed)


crates/anthropic/src/anthropic.rs

Lines changed: 14 additions & 0 deletions
@@ -49,6 +49,7 @@ pub enum Model {
         /// Indicates whether this custom model supports caching.
         cache_configuration: Option<AnthropicModelCacheConfiguration>,
         max_output_tokens: Option<u32>,
+        default_temperature: Option<f32>,
     },
 }

@@ -124,6 +125,19 @@ impl Model {
         }
     }

+    pub fn default_temperature(&self) -> f32 {
+        match self {
+            Self::Claude3_5Sonnet
+            | Self::Claude3Opus
+            | Self::Claude3Sonnet
+            | Self::Claude3Haiku => 1.0,
+            Self::Custom {
+                default_temperature,
+                ..
+            } => default_temperature.unwrap_or(1.0),
+        }
+    }
+
     pub fn tool_model_id(&self) -> &str {
         if let Self::Custom {
             tool_override: Some(tool_override),
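
With the request side now carrying `Option<f32>` (see crates/language_model/src/request.rs below), the precedence this new method establishes can be stated in one place. A minimal standalone sketch, with an illustrative function name that is not part of the commit:

// Precedence sketch: an explicit per-request temperature wins, then the
// custom model's configured default, then Anthropic's baseline of 1.0.
fn effective_temperature(requested: Option<f32>, model_default: Option<f32>) -> f32 {
    requested.or(model_default).unwrap_or(1.0)
}

fn main() {
    assert_eq!(effective_temperature(None, None), 1.0); // built-in models
    assert_eq!(effective_temperature(None, Some(0.2)), 0.2); // custom model default
    assert_eq!(effective_temperature(Some(0.7), Some(0.2)), 0.7); // request overrides
}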

crates/assistant/src/context.rs

Lines changed: 1 addition & 1 deletion
@@ -2180,7 +2180,7 @@ impl Context {
             messages: Vec::new(),
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: 1.0,
+            temperature: None,
         };
         for message in self.messages(cx) {
             if message.status != MessageStatus::Done {

crates/assistant/src/inline_assistant.rs

Lines changed: 1 addition & 1 deletion
@@ -2732,7 +2732,7 @@ impl CodegenAlternative {
             messages,
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: 1.,
+            temperature: None,
         })
     }

crates/assistant/src/prompt_library.rs

Lines changed: 1 addition & 1 deletion
@@ -796,7 +796,7 @@ impl PromptLibrary {
                     }],
                     tools: Vec::new(),
                     stop: Vec::new(),
-                    temperature: 1.,
+                    temperature: None,
                 },
                 cx,
             )

crates/assistant/src/slash_command/auto_command.rs

Lines changed: 1 addition & 1 deletion
@@ -216,7 +216,7 @@ async fn commands_for_summaries(
         }],
         tools: Vec::new(),
         stop: Vec::new(),
-        temperature: 1.0,
+        temperature: None,
     };

     while let Some(current_summaries) = stack.pop() {

crates/assistant/src/terminal_inline_assistant.rs

Lines changed: 1 addition & 1 deletion
@@ -284,7 +284,7 @@ impl TerminalInlineAssistant {
             messages,
             tools: Vec::new(),
             stop: Vec::new(),
-            temperature: 1.0,
+            temperature: None,
         })
     }

crates/language_model/src/provider/anthropic.rs

Lines changed: 8 additions & 2 deletions
@@ -51,6 +51,7 @@ pub struct AvailableModel {
     /// Configuration of Anthropic's caching API.
     pub cache_configuration: Option<LanguageModelCacheConfiguration>,
     pub max_output_tokens: Option<u32>,
+    pub default_temperature: Option<f32>,
 }

 pub struct AnthropicLanguageModelProvider {

@@ -200,6 +201,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
                         }
                     }),
                     max_output_tokens: model.max_output_tokens,
+                    default_temperature: model.default_temperature,
                 },
             );
         }

@@ -375,8 +377,11 @@ impl LanguageModel for AnthropicModel {
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
-        let request =
-            request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
+        let request = request.into_anthropic(
+            self.model.id().into(),
+            self.model.default_temperature(),
+            self.model.max_output_tokens(),
+        );
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
             let response = request.await.map_err(|err| anyhow!(err))?;

@@ -405,6 +410,7 @@
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let mut request = request.into_anthropic(
             self.model.tool_model_id().into(),
+            self.model.default_temperature(),
             self.model.max_output_tokens(),
         );
         request.tool_choice = Some(anthropic::ToolChoice::Tool {

crates/language_model/src/provider/cloud.rs

Lines changed: 13 additions & 3 deletions
@@ -87,6 +87,8 @@ pub struct AvailableModel {
     pub tool_override: Option<String>,
     /// Indicates whether this custom model supports caching.
     pub cache_configuration: Option<LanguageModelCacheConfiguration>,
+    /// The default temperature to use for this model.
+    pub default_temperature: Option<f32>,
 }

 pub struct CloudLanguageModelProvider {

@@ -255,6 +257,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                         min_total_token: config.min_total_token,
                     }
                 }),
+                default_temperature: model.default_temperature,
                 max_output_tokens: model.max_output_tokens,
             }),
             AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {

@@ -516,7 +519,11 @@ impl LanguageModel for CloudLanguageModel {

         match &self.model {
             CloudModel::Anthropic(model) => {
-                let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
+                let request = request.into_anthropic(
+                    model.id().into(),
+                    model.default_temperature(),
+                    model.max_output_tokens(),
+                );
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {

@@ -642,8 +649,11 @@

         match &self.model {
             CloudModel::Anthropic(model) => {
-                let mut request =
-                    request.into_anthropic(model.tool_model_id().into(), model.max_output_tokens());
+                let mut request = request.into_anthropic(
+                    model.tool_model_id().into(),
+                    model.default_temperature(),
+                    model.max_output_tokens(),
+                );
                 request.tool_choice = Some(anthropic::ToolChoice::Tool {
                     name: tool_name.clone(),
                 });

crates/language_model/src/provider/ollama.rs

Lines changed: 1 addition & 1 deletion
@@ -235,7 +235,7 @@ impl OllamaLanguageModel {
             options: Some(ChatOptions {
                 num_ctx: Some(self.model.max_tokens),
                 stop: Some(request.stop),
-                temperature: Some(request.temperature),
+                temperature: request.temperature.or(Some(1.0)),
                 ..Default::default()
             }),
             tools: vec![],

crates/language_model/src/request.rs

Lines changed: 10 additions & 5 deletions
@@ -236,7 +236,7 @@ pub struct LanguageModelRequest {
     pub messages: Vec<LanguageModelRequestMessage>,
     pub tools: Vec<LanguageModelRequestTool>,
     pub stop: Vec<String>,
-    pub temperature: f32,
+    pub temperature: Option<f32>,
 }

 impl LanguageModelRequest {

@@ -262,7 +262,7 @@ impl LanguageModelRequest {
                 .collect(),
             stream,
             stop: self.stop,
-            temperature: self.temperature,
+            temperature: self.temperature.unwrap_or(1.0),
             max_tokens: max_output_tokens,
             tools: Vec::new(),
             tool_choice: None,

@@ -290,15 +290,20 @@
                 candidate_count: Some(1),
                 stop_sequences: Some(self.stop),
                 max_output_tokens: None,
-                temperature: Some(self.temperature as f64),
+                temperature: self.temperature.map(|t| t as f64).or(Some(1.0)),
                 top_p: None,
                 top_k: None,
             }),
             safety_settings: None,
         }
     }

-    pub fn into_anthropic(self, model: String, max_output_tokens: u32) -> anthropic::Request {
+    pub fn into_anthropic(
+        self,
+        model: String,
+        default_temperature: f32,
+        max_output_tokens: u32,
+    ) -> anthropic::Request {
         let mut new_messages: Vec<anthropic::Message> = Vec::new();
         let mut system_message = String::new();

@@ -400,7 +405,7 @@
             tool_choice: None,
             metadata: None,
             stop_sequences: Vec::new(),
-            temperature: Some(self.temperature),
+            temperature: self.temperature.or(Some(default_temperature)),
             top_k: None,
             top_p: None,
         }
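
Each `into_x` conversion above normalizes the `Option<f32>` for its own wire format, which is the centralization the release notes describe. A self-contained restatement of just that defaulting logic (the free functions here are illustrative; in the commit the logic lives inline in each conversion):

// OpenAI's request type takes a plain f32, so a missing value becomes 1.0.
fn openai_temperature(requested: Option<f32>) -> f32 {
    requested.unwrap_or(1.0)
}

// Google's request type takes Option<f64>; a missing value still maps to Some(1.0).
fn google_temperature(requested: Option<f32>) -> Option<f64> {
    requested.map(|t| t as f64).or(Some(1.0))
}

// Anthropic's conversion falls back to the model's own default, which is
// what lets a custom model override the baseline.
fn anthropic_temperature(requested: Option<f32>, default_temperature: f32) -> Option<f32> {
    requested.or(Some(default_temperature))
}

fn main() {
    assert_eq!(openai_temperature(None), 1.0);
    assert_eq!(google_temperature(None), Some(1.0));
    assert_eq!(anthropic_temperature(None, 0.5), Some(0.5));
    assert_eq!(anthropic_temperature(Some(0.9), 0.5), Some(0.9));
}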
