Skip to content

Commit d3a1078

Browse files
committed
Normalize temperature for o1 and o3 models
1 parent daa0258 commit d3a1078

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

lib/ruby_llm/chat.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class Chat
1616
def initialize(model: nil)
1717
model_id = model || RubyLLM.config.default_model
1818
self.model = model_id
19-
@temperature = @model.metadata['family'] == 'o1' ? 1 : 0.7
19+
@temperature = 0.7
2020
@messages = []
2121
@tools = {}
2222
@on = {

lib/ruby_llm/provider.rb

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,18 @@ module Provider
88
# Common functionality for all LLM providers. Implements the core provider
99
# interface so specific providers only need to implement a few key methods.
1010
module Methods # rubocop:disable Metrics/ModuleLength
11-
def complete(messages, tools:, temperature:, model:, &block)
12-
payload = render_payload messages, tools: tools, temperature: temperature, model: model, stream: block_given?
11+
def complete(messages, tools:, temperature:, model:, &block) # rubocop:disable Metrics/MethodLength
12+
normalized_temperature = if capabilities.respond_to?(:normalize_temperature)
13+
capabilities.normalize_temperature(temperature, model)
14+
else
15+
temperature
16+
end
17+
18+
payload = render_payload(messages,
19+
tools: tools,
20+
temperature: normalized_temperature,
21+
model: model,
22+
stream: block_given?)
1323

1424
if block_given?
1525
stream_response payload, &block

lib/ruby_llm/providers/openai/capabilities.rb

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,16 @@ def apply_special_formatting(name) # rubocop:disable Metrics/MethodLength
257257
.gsub('Omni Moderation', 'Omni-Moderation')
258258
.gsub('Text Moderation', 'Text-Moderation')
259259
end
260+
261+
def normalize_temperature(temperature, model_id)
262+
if model_id.match?(/o[13]/)
263+
# O1/O3 models always use temperature 1.0
264+
RubyLLM.logger.debug "Model #{model_id} requires temperature=1.0, ignoring provided value"
265+
1.0
266+
else
267+
temperature
268+
end
269+
end
260270
end
261271
end
262272
end

0 commit comments

Comments
 (0)