Skip to content

Commit 53be515

Browse files
authored
Merge branch 'main' into mistral
2 parents 11aa1f9 + d709147 commit 53be515

File tree

3 files changed

+24
-38
lines changed

3 files changed

+24
-38
lines changed

lib/ruby_llm/aliases.rb

+6-6
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ class << self
1818
# @param model_id [String] the model identifier or alias
1919
# @param provider_slug [String, Symbol, nil] optional provider to resolve for
2020
# @return [String] the resolved model ID or the original if no alias exists
21-
def resolve(model_id, provider_slug = nil)
22-
provider_aliases = aliases[model_id]
23-
return model_id unless provider_aliases
21+
def resolve(model_id, provider = nil)
22+
return model_id unless aliases[model_id]
2423

25-
if provider_slug
26-
provider_aliases[provider_slug.to_s] || model_id
24+
if provider
25+
aliases[model_id][provider.to_s] || model_id
2726
else
28-
provider_aliases.values.first || model_id
27+
# Get native provider's version
28+
aliases[model_id].values.first || model_id
2929
end
3030
end
3131

lib/ruby_llm/chat.rb

+4-19
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ class Chat
1515

1616
def initialize(model: nil, provider: nil)
1717
model_id = model || RubyLLM.config.default_model
18-
self.model = model_id
19-
self.provider = provider if provider
18+
with_model(model_id, provider: provider)
2019
@temperature = 0.7
2120
@messages = []
2221
@tools = {}
@@ -48,23 +47,9 @@ def with_tools(*tools)
4847
self
4948
end
5049

51-
def model=(model_id)
52-
@model = Models.find model_id
53-
@provider = Models.provider_for model_id
54-
end
55-
56-
def provider=(provider_slug)
57-
@provider = Provider.providers[provider_slug.to_sym] ||
58-
raise(Error, "Unknown provider: #{provider_slug}")
59-
end
60-
61-
def with_provider(provider_slug)
62-
self.provider = provider_slug
63-
self
64-
end
65-
66-
def with_model(model_id)
67-
self.model = model_id
50+
def with_model(model_id, provider: nil)
51+
@model = Models.find model_id, provider
52+
@provider = Provider.providers[@model.provider.to_sym] || raise(Error, "Unknown provider: #{@model.provider}")
6853
self
6954
end
7055

lib/ruby_llm/models.rb

+14-13
Original file line numberDiff line numberDiff line change
@@ -73,20 +73,13 @@ def each(&)
7373
end
7474

7575
# Find a specific model by ID
76-
def find(model_id)
77-
# Try exact match first
78-
exact_match = all.find { |m| m.id == model_id }
79-
return exact_match if exact_match
80-
81-
# Try to resolve via alias
82-
resolved_id = Aliases.resolve(model_id)
83-
if resolved_id != model_id
84-
alias_match = all.find { |m| m.id == resolved_id }
85-
return alias_match if alias_match
86-
end
76+
def find(model_id, provider = nil)
77+
return find_with_provider(model_id, provider) if provider
8778

88-
# Not found
89-
raise ModelNotFoundError, "Unknown model: #{model_id}"
79+
# Find native model
80+
all.find { |m| m.id == model_id } ||
81+
all.find { |m| m.id == Aliases.resolve(model_id) } ||
82+
raise(ModelNotFoundError, "Unknown model: #{model_id}")
9083
end
9184

9285
# Filter to only chat models
@@ -123,5 +116,13 @@ def by_provider(provider)
123116
def refresh!
124117
self.class.refresh!
125118
end
119+
120+
private
121+
122+
def find_with_provider(model_id, provider)
123+
provider_id = Aliases.resolve(model_id, provider)
124+
all.find { |m| m.id == provider_id && m.provider == provider.to_s } ||
125+
raise(ModelNotFoundError, "Unknown model: #{model_id} for provider: #{provider}")
126+
end
126127
end
127128
end

0 commit comments

Comments
 (0)