Skip to content

Commit 27f310a

Browse files
qtnxRobitx
authored andcommitted
feat: handle gpt o1-preview, o1-mini models
1 parent 8ce48fd commit 27f310a

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

lua/gp/dispatcher.lua

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,14 +184,29 @@ D.prepare_payload = function(messages, model, provider)
184184
model.model = "gpt-4o-2024-05-13"
185185
end
186186

187-
return {
187+
local output = {
188188
model = model.model,
189189
stream = true,
190190
messages = messages,
191191
max_tokens = model.max_tokens or 4096,
192192
temperature = math.max(0, math.min(2, model.temperature or 1)),
193193
top_p = math.max(0, math.min(1, model.top_p or 1)),
194194
}
195+
196+
if provider == "openai" and model.model:sub(1, 2) == "o1" then
197+
for i = #messages, 1, -1 do
198+
if messages[i].role == "system" then
199+
table.remove(messages, i)
200+
end
201+
end
202+
-- remove max_tokens, top_p, temperature for o1 models. https://platform.openai.com/docs/guides/reasoning/beta-limitations
203+
output.max_tokens = nil
204+
output.temperature = nil
205+
output.top_p = nil
206+
output.stream = false
207+
end
208+
209+
return output
195210
end
196211

197212
-- gpt query
@@ -268,6 +283,7 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
268283
end
269284
end
270285

286+
271287
if content and type(content) == "string" then
272288
qt.response = qt.response .. content
273289
handler(qid, content)
@@ -301,6 +317,19 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
301317
if #buffer > 0 then
302318
process_lines(buffer)
303319
end
320+
local raw_response = qt.raw_response
321+
local content = qt.response
322+
if qt.provider == 'openai' and content == "" and raw_response:match('choices') and raw_response:match("content") then
323+
local response = vim.json.decode(raw_response)
324+
if response.choices and response.choices[1] and response.choices[1].message and response.choices[1].message.content then
325+
content = response.choices[1].message.content
326+
end
327+
if content and type(content) == "string" then
328+
qt.response = qt.response .. content
329+
handler(qid, content)
330+
end
331+
end
332+
304333

305334
if qt.response == "" then
306335
logger.error(qt.provider .. " response is empty: \n" .. vim.inspect(qt.raw_response))
@@ -382,7 +411,8 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
382411
}
383412
end
384413

385-
local temp_file = D.query_dir .. "/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
414+
local temp_file = D.query_dir ..
415+
"/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
386416
helpers.table_to_file(payload, temp_file)
387417

388418
local curl_params = vim.deepcopy(D.config.curl_params or {})

0 commit comments

Comments
 (0)