Skip to content

Commit 07e22e2

Browse files
authored
feat(provider): advanced model caching to prevent fetching models every time.
* Remove deprecated tests * Remove providers for the sake of config provider setup * Current state of providers v2 * Make more robust * docs: update provider examples in readme * Remove thinking stuff and prepare for new solution * Fix test and bug * Attempt to fix pipeline * Rename provider - fix tests - add new checks to verify provider configs * Polish multi provider * feat: online model caching with fallback to pre-defined models * Add model cache expiration config to readme * Move import to top * Use logger and add todo
1 parent ad1e873 commit 07e22e2

File tree

9 files changed

+871
-50
lines changed

9 files changed

+871
-50
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,10 @@ to consider a visual selection within an API request.
373373
-- if false it also frees up the buffer cursor for further editing elsewhere
374374
command_auto_select_response = true,
375375

376+
-- Time in hours until the model cache is refreshed
377+
-- Set to 0 to deactive model caching
378+
model_cache_expiry_hours = 48,
379+
376380
-- fzf_lua options for PrtModel and PrtChatFinder when plugin is installed
377381
fzf_lua_opts = {
378382
["--ansi"] = true,

lua/parrot/chat_handler.lua

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,10 +1176,19 @@ function ChatHandler:model(params)
11761176
local has_fzf, fzf_lua = pcall(require, "fzf-lua")
11771177
local has_telescope, telescope = pcall(require, "telescope")
11781178

1179+
-- Get models with caching support
1180+
local models
1181+
if prov:online_model_fetching() and self.options.model_cache_expiry_hours > 0 then
1182+
local spinner = self.options.enable_spinner and Spinner:new(self.options.spinner_type) or nil
1183+
models = prov:get_available_models_cached(self.state, self.options.model_cache_expiry_hours, spinner)
1184+
else
1185+
models = prov:get_available_models()
1186+
end
1187+
11791188
if model_name ~= "" then
11801189
self:switch_model(is_chat, model_name, prov)
11811190
elseif has_fzf then
1182-
fzf_lua.fzf_exec(prov:get_available_models(), {
1191+
fzf_lua.fzf_exec(models, {
11831192
prompt = "Model selection ❯",
11841193
fzf_opts = self.options.fzf_lua_opts,
11851194
actions = {
@@ -1203,7 +1212,7 @@ function ChatHandler:model(params)
12031212
.new({}, {
12041213
prompt_title = "Model selection",
12051214
finder = finders.new_table({
1206-
results = prov:get_available_models(),
1215+
results = models,
12071216
}),
12081217
sorter = sorters.values.generic_sorter({}),
12091218
attach_mappings = function(_, map)
@@ -1225,7 +1234,7 @@ function ChatHandler:model(params)
12251234
})
12261235
:find()
12271236
else
1228-
vim.ui.select(prov:get_available_models(), {
1237+
vim.ui.select(models, {
12291238
prompt = "Select your model:",
12301239
}, function(selected_model)
12311240
self:switch_model(is_chat, selected_model, prov)

lua/parrot/config.lua

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
local ChatHandler = require("parrot.chat_handler")
22
local init_provider = require("parrot.provider").init_provider
3+
local utils = require("parrot.utils")
4+
local Spinner = require("parrot.spinner")
5+
local State = require("parrot.state")
36

47
local M = {
58
ui = require("parrot.ui"),
@@ -56,6 +59,7 @@ local defaults = {
5659
style_popup_max_width = 160,
5760
command_prompt_prefix_template = "🤖 {{llm}} ~ ",
5861
command_auto_select_response = true,
62+
model_cache_expiry_hours = 48,
5963
fzf_lua_opts = {
6064
["--ansi"] = true,
6165
["--sort"] = "",
@@ -282,17 +286,46 @@ function M.setup(opts)
282286

283287
M.available_providers = vim.tbl_keys(M.providers)
284288

289+
-- Initialize state early to enable caching
290+
local temp_state = State:new(M.options.state_dir)
291+
292+
-- Clean up cache for removed providers
293+
temp_state:cleanup_cache(M.available_providers)
294+
285295
local available_models = {}
296+
297+
-- Check each provider individually and fetch models
286298
for _, prov_name in ipairs(M.available_providers) do
287299
-- Create the new provider config format
288300
local provider_config = vim.tbl_deep_extend("force", {
289301
name = prov_name,
290302
}, M.providers[prov_name])
291303
local _prov = init_provider(provider_config)
292304

293-
-- do not make an API call on startup
294-
available_models[prov_name] = _prov.models -- or _prov:get_available_models()
305+
-- Use cached model fetching if provider has model_endpoint
306+
if _prov:online_model_fetching() and M.options.model_cache_expiry_hours > 0 then
307+
-- Check cache validity for this specific provider
308+
local endpoint_hash = utils.generate_endpoint_hash(_prov)
309+
local needs_update = not temp_state:is_cache_valid(prov_name, M.options.model_cache_expiry_hours, endpoint_hash)
310+
311+
-- Show spinner only for this provider if needed
312+
local spinner = nil
313+
if needs_update and M.options.enable_spinner then
314+
spinner = Spinner:new(M.options.spinner_type)
315+
M.logger.info("Updating model cache for " .. prov_name)
316+
end
317+
318+
available_models[prov_name] =
319+
_prov:get_available_models_cached(temp_state, M.options.model_cache_expiry_hours, spinner)
320+
else
321+
-- Fall back to static models for providers without model_endpoint
322+
available_models[prov_name] = _prov.models
323+
end
295324
end
325+
326+
-- Now refresh the state with all available models
327+
temp_state:refresh(M.available_providers, available_models)
328+
296329
M.available_models = available_models
297330

298331
table.sort(M.available_providers)
@@ -382,7 +415,12 @@ M.add_default_commands = function(commands, hooks, options)
382415
return completions[cmd]
383416
end
384417
if cmd == "Model" then
385-
return M.available_models[M.chat_handler.state:get_provider()]
418+
-- TODO: Should detect the respective mode --
419+
local current_provider = M.chat_handler.state:get_provider(true) -- Use chat provider by default
420+
if current_provider and M.available_models[current_provider] then
421+
return M.available_models[current_provider]
422+
end
423+
return {}
386424
elseif cmd == "Provider" then
387425
return M.available_providers
388426
end

lua/parrot/provider/multi_provider.lua

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,10 @@ function MultiProvider:new(config)
238238
return self
239239
end
240240

241+
function MultiProvider:online_model_fetching()
242+
return self.model_endpoint and self.model_endpoint ~= ""
243+
end
244+
241245
-- Validates the MultiProvider configuration
242246
function MultiProvider:validate_config()
243247
local logger = require("parrot.logger")
@@ -264,7 +268,7 @@ function MultiProvider:validate_config()
264268
end
265269

266270
-- Validate model endpoint format if provided (allow functions)
267-
if self.model_endpoint and self.model_endpoint ~= "" then
271+
if self:online_model_fetching() then
268272
if type(self.model_endpoint) == "string" and not self.model_endpoint:match("^https?://") then
269273
logger.error(vim.inspect({
270274
msg = "Invalid model endpoint format for provider",
@@ -411,7 +415,7 @@ end
411415
-- Returns the list of available models
412416
---@return string[]
413417
function MultiProvider:get_available_models()
414-
if self.model_endpoint ~= "" and self:verify() then
418+
if self:online_model_fetching() and self:verify() then
415419
local hdrs = type(self.headers) == "function" and self.headers(self) or (self.headers or {})
416420

417421
-- Handle model_endpoint as function or string/table
@@ -440,4 +444,51 @@ function MultiProvider:get_available_models()
440444
return self.models
441445
end
442446

447+
-- Returns the list of available models with caching support
448+
---@param state table # State object for caching
449+
---@param cache_expiry_hours number # Cache expiry time in hours
450+
---@param spinner table|nil # Optional spinner for loading indication
451+
---@return string[]
452+
function MultiProvider:get_available_models_cached(state, cache_expiry_hours, spinner)
453+
-- Only use caching if model_endpoint is configured
454+
-- otherwise return fallback models
455+
if not self:online_model_fetching() and cache_expiry_hours == 0 then
456+
return self.models
457+
end
458+
459+
-- Generate endpoint hash for cache validation
460+
local endpoint_hash = utils.generate_endpoint_hash(self)
461+
462+
-- Try to get from cache first
463+
local cached_models = state:get_cached_models(self.name, cache_expiry_hours, endpoint_hash)
464+
if cached_models then
465+
return cached_models
466+
end
467+
468+
-- Cache miss or expired - fetch fresh models
469+
if spinner then
470+
spinner:start("Fetching models for " .. self.name .. "...")
471+
end
472+
473+
local fresh_models = self:get_available_models()
474+
475+
if spinner then
476+
spinner:stop()
477+
end
478+
479+
-- Ensure we always have models - fallback to static if fresh fetch failed
480+
if not fresh_models or #fresh_models == 0 then
481+
fresh_models = self.models
482+
end
483+
484+
-- Cache the fresh models if we successfully fetched from API and they differ from static
485+
if #fresh_models > 0 and not vim.deep_equal(fresh_models, self.models) then
486+
state:set_cached_models(self.name, fresh_models, endpoint_hash)
487+
state:save()
488+
end
489+
490+
-- Final safety check - always return at least static models
491+
return fresh_models and #fresh_models > 0 and fresh_models or self.models
492+
end
493+
443494
return MultiProvider

lua/parrot/state.lua

Lines changed: 147 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,22 @@ function State:init_file_state(available_providers)
2121
self.file_state[prov] = {
2222
chat_model = nil,
2323
command_model = nil,
24+
cached_models = {},
2425
}
2526
end
27+
else
28+
-- Ensure existing providers have cached_models initialized
29+
for _, prov in ipairs(available_providers) do
30+
if self.file_state[prov] then
31+
self.file_state[prov].cached_models = self.file_state[prov].cached_models or {}
32+
else
33+
self.file_state[prov] = {
34+
chat_model = nil,
35+
command_model = nil,
36+
cached_models = {},
37+
}
38+
end
39+
end
2640
end
2741
self.file_state.current_provider = self.file_state.current_provider or { chat = nil, command = nil }
2842
end
@@ -33,12 +47,23 @@ end
3347
function State:init_state(available_providers, available_models)
3448
self._state.current_provider = self._state.current_provider or { chat = nil, command = nil }
3549
for _, provider in ipairs(available_providers) do
36-
self._state[provider] = self._state[provider] or {
37-
chat_model = nil,
38-
command_model = nil,
39-
}
40-
self:load_models(provider, "chat_model", available_models)
41-
self:load_models(provider, "command_model", available_models)
50+
self._state[provider] = self._state[provider]
51+
or {
52+
chat_model = nil,
53+
command_model = nil,
54+
cached_models = {},
55+
}
56+
57+
-- Copy cached_models from file_state if they exist
58+
if self.file_state[provider] and self.file_state[provider].cached_models then
59+
self._state[provider].cached_models = self.file_state[provider].cached_models
60+
end
61+
62+
-- Only load models if the provider has available models
63+
if available_models[provider] and #available_models[provider] > 0 then
64+
self:load_models(provider, "chat_model", available_models)
65+
self:load_models(provider, "command_model", available_models)
66+
end
4267
end
4368
end
4469

@@ -47,6 +72,11 @@ end
4772
--- @param model_type string # Type of model (e.g., "chat_model", "command_model").
4873
--- @param available_models table
4974
function State:load_models(provider, model_type, available_models)
75+
-- Ensure provider exists in available_models and has models
76+
if not available_models[provider] or not available_models[provider][1] then
77+
return
78+
end
79+
5080
local state_model = self.file_state and self.file_state[provider] and self.file_state[provider][model_type]
5181
local is_valid_model = state_model and utils.contains(available_models[provider], state_model)
5282

@@ -85,6 +115,13 @@ end
85115

86116
--- Saves the current state to the state file.
87117
function State:save()
118+
-- Merge cached_models from file_state into _state before saving
119+
for provider, data in pairs(self.file_state) do
120+
if type(data) == "table" and data.cached_models and self._state[provider] then
121+
self._state[provider].cached_models = data.cached_models
122+
end
123+
end
124+
88125
futils.table_to_file(self._state, self.state_file)
89126
end
90127

@@ -141,4 +178,108 @@ function State:get_last_chat()
141178
return self._state.last_chat
142179
end
143180

181+
--- Sets cached models for a provider with timestamp
182+
--- @param provider string # Provider name
183+
--- @param models table # Array of model names
184+
--- @param endpoint_hash string|nil # Hash of the endpoint configuration for validation
185+
function State:set_cached_models(provider, models, endpoint_hash)
186+
-- Ensure provider exists in file_state
187+
if not self.file_state[provider] then
188+
self.file_state[provider] = {
189+
chat_model = nil,
190+
command_model = nil,
191+
cached_models = {},
192+
}
193+
end
194+
195+
-- Ensure cached_models table exists for this provider
196+
self.file_state[provider].cached_models = self.file_state[provider].cached_models or {}
197+
198+
local cache_entry = {
199+
models = models,
200+
timestamp = os.time(),
201+
endpoint_hash = endpoint_hash,
202+
}
203+
204+
self.file_state[provider].cached_models = cache_entry
205+
206+
-- Also sync to _state if it exists for immediate availability
207+
if self._state[provider] then
208+
self._state[provider].cached_models = cache_entry
209+
end
210+
end
211+
212+
--- Gets cached models for a provider if they exist and are valid
213+
--- @param provider string # Provider name
214+
--- @param cache_expiry_hours number # Cache expiry time in hours
215+
--- @param endpoint_hash string|nil # Current endpoint hash for validation
216+
--- @return table|nil # Array of cached model names or nil if cache is invalid/expired
217+
function State:get_cached_models(provider, cache_expiry_hours, endpoint_hash)
218+
if not self.file_state[provider] or not self.file_state[provider].cached_models then
219+
return nil
220+
end
221+
222+
local cached = self.file_state[provider].cached_models
223+
-- If cached_models is empty table, return nil
224+
if not cached.models or not cached.timestamp then
225+
return nil
226+
end
227+
228+
local now = os.time()
229+
local expiry_seconds = cache_expiry_hours * 3600
230+
231+
-- Check if cache is expired
232+
if (now - cached.timestamp) > expiry_seconds then
233+
return nil
234+
end
235+
236+
-- Check if endpoint configuration changed (if hash is provided)
237+
if endpoint_hash and cached.endpoint_hash and cached.endpoint_hash ~= endpoint_hash then
238+
return nil
239+
end
240+
241+
return cached.models
242+
end
243+
244+
--- Checks if cached models are valid for a provider
245+
--- @param provider string # Provider name
246+
--- @param cache_expiry_hours number # Cache expiry time in hours
247+
--- @param endpoint_hash string|nil # Current endpoint hash for validation
248+
--- @return boolean
249+
function State:is_cache_valid(provider, cache_expiry_hours, endpoint_hash)
250+
return self:get_cached_models(provider, cache_expiry_hours, endpoint_hash) ~= nil
251+
end
252+
253+
--- Clears cached models for a provider or all providers
254+
--- @param provider string|nil # Provider name, or nil to clear all caches
255+
function State:clear_cache(provider)
256+
if provider then
257+
-- Clear cache for specific provider
258+
if self.file_state[provider] and self.file_state[provider].cached_models then
259+
self.file_state[provider].cached_models = {}
260+
end
261+
else
262+
-- Clear all caches
263+
for prov_name, prov_data in pairs(self.file_state) do
264+
if type(prov_data) == "table" and prov_data.cached_models then
265+
prov_data.cached_models = {}
266+
end
267+
end
268+
end
269+
end
270+
271+
--- Cleans up cache entries for providers that no longer exist
272+
--- @param available_providers table # Current list of available providers
273+
function State:cleanup_cache(available_providers)
274+
-- Remove entire provider entries that no longer exist
275+
for prov_name, _ in pairs(self.file_state) do
276+
-- Skip special keys like current_provider
277+
if prov_name ~= "current_provider" and prov_name ~= "last_chat" then
278+
if not utils.contains(available_providers, prov_name) then
279+
self.file_state[prov_name] = nil
280+
end
281+
end
282+
end
283+
end
284+
144285
return State

0 commit comments

Comments
 (0)