Skip to content

Commit 400ddb3

Browse files
authored
feat(remotes): use inputlist on multiple git remotes (#267)
1 parent 296ad98 commit 400ddb3

File tree

4 files changed

+131
-167
lines changed

4 files changed

+131
-167
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ PRs are welcomed for other git host websites!
5454
2. New Features:
5555
- Windows (+wsl2) support.
5656
- Blame support.
57-
- Full [git protocols](https://git-scm.com/book/en/v2/Git-on-the-Server-The-Protocols) support.
5857
- Respect ssh host alias.
5958
- Add `?plain=1` for markdown files.
59+
- Fully customizable [git url](https://git-scm.com/book/en/v2/Git-on-the-Server-The-Protocols) generation.
6060
3. Improvements:
6161
- Use git `stderr` output as error message.
6262
- Async child process IO via coroutine and `uv.spawn`.

lua/gitlinker/commons/async.lua

Lines changed: 71 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -1,189 +1,95 @@
1-
---@diagnostic disable
2-
--- Small async library for Neovim plugins
3-
4-
local function validate_callback(func, callback)
5-
if callback and type(callback) ~= 'function' then
6-
local info = debug.getinfo(func, 'nS')
7-
error(
8-
string.format(
9-
'Callback is not a function for %s, got: %s',
10-
info.short_src .. ':' .. info.linedefined,
11-
vim.inspect(callback)
12-
)
13-
)
14-
end
1+
-- Copied from: <https://github.com/neovim/neovim/issues/19624#issuecomment-1202405058>
2+
3+
local co = coroutine
4+
5+
local async_thread = {
6+
threads = {},
7+
}
8+
9+
local function threadtostring(x)
10+
if jit then
11+
return string.format('%p', x)
12+
else
13+
return tostring(x):match('thread: (.*)')
14+
end
1515
end
1616

17-
-- Coroutine.running() was changed between Lua 5.1 and 5.2:
18-
-- - 5.1: Returns the running coroutine, or nil when called by the main thread.
19-
-- - 5.2: Returns the running coroutine plus a boolean, true when the running
20-
-- coroutine is the main one.
21-
--
22-
-- For LuaJIT, 5.2 behaviour is enabled with LUAJIT_ENABLE_LUA52COMPAT
23-
--
24-
-- We need to handle both.
25-
local _main_co_or_nil = coroutine.running()
26-
27-
--- Executes a future with a callback when it is done
28-
--- @param func function
29-
--- @param callback function?
30-
--- @param ... any
31-
local function run(func, callback, ...)
32-
validate_callback(func, callback)
17+
function async_thread.running()
18+
local thread = co.running()
19+
local id = threadtostring(thread)
20+
return async_thread.threads[id]
21+
end
3322

34-
local co = coroutine.create(func)
23+
function async_thread.create(fn)
24+
local thread = co.create(fn)
25+
local id = threadtostring(thread)
26+
async_thread.threads[id] = true
27+
return thread
28+
end
3529

36-
local function step(...)
37-
local ret = { coroutine.resume(co, ...) }
38-
local stat = ret[1]
30+
function async_thread.finished(x)
31+
if co.status(x) == 'dead' then
32+
local id = threadtostring(x)
33+
async_thread.threads[id] = nil
34+
return true
35+
end
36+
return false
37+
end
3938

40-
if not stat then
41-
local err = ret[2] --[[@as string]]
42-
error(
43-
string.format('The coroutine failed with this message: %s\n%s', err, debug.traceback(co))
44-
)
45-
end
39+
--- @param async_fn function
40+
--- @param ... any
41+
local function execute(async_fn, ...)
42+
local thread = async_thread.create(async_fn)
43+
44+
local function step(...)
45+
local ret = { co.resume(thread, ...) }
46+
local stat, err_or_fn, nargs = unpack(ret)
47+
48+
if not stat then
49+
error(string.format("The coroutine failed with this message: %s\n%s",
50+
err_or_fn, debug.traceback(thread)))
51+
end
4652

47-
if coroutine.status(co) == 'dead' then
48-
if callback then
49-
callback(unpack(ret, 2, table.maxn(ret)))
53+
if async_thread.finished(thread) then
54+
return
5055
end
51-
return
52-
end
5356

54-
--- @type integer, fun(...: any): any
55-
local nargs, fn = ret[2], ret[3]
56-
assert(type(fn) == 'function', 'type error :: expected func')
57+
assert(type(err_or_fn) == "function", "The 1st parameter must be a lua function")
5758

58-
--- @type any[]
59-
local args = { unpack(ret, 4, table.maxn(ret)) }
60-
args[nargs] = step
61-
fn(unpack(args, 1, nargs))
62-
end
59+
local ret_fn = err_or_fn
60+
local args = { select(4, unpack(ret)) }
61+
args[nargs] = step
62+
ret_fn(unpack(args, 1, nargs --[[@as integer]]))
63+
end
6364

64-
step(...)
65+
step(...)
6566
end
6667

6768
local M = {}
6869

69-
---Use this to create a function which executes in an async context but
70-
---called from a non-async context. Inherently this cannot return anything
71-
---since it is non-blocking
72-
--- @generic F: function
73-
--- @param argc integer
74-
--- @param func async F
75-
--- @return F
76-
function M.sync(argc, func)
77-
return function(...)
78-
assert(not coroutine.running())
79-
local callback = select(argc + 1, ...)
80-
run(func, callback, unpack({ ... }, 1, argc))
81-
end
82-
end
83-
84-
--- @param argc integer
8570
--- @param func function
86-
--- @param ... any
87-
--- @return any ...
88-
function M.wait(argc, func, ...)
89-
-- Always run the wrapped functions in xpcall and re-raise the error in the
90-
-- coroutine. This makes pcall work as normal.
91-
local function pfunc(...)
92-
local args = { ... } --- @type any[]
93-
local cb = args[argc]
94-
args[argc] = function(...)
95-
cb(true, ...)
96-
end
97-
xpcall(func, function(err)
98-
cb(false, err, debug.traceback())
99-
end, unpack(args, 1, argc))
100-
end
101-
102-
local ret = { coroutine.yield(argc, pfunc, ...) }
103-
104-
local ok = ret[1]
105-
if not ok then
106-
--- @type string, string
107-
local err, traceback = ret[2], ret[3]
108-
error(string.format('Wrapped function failed: %s\n%s', err, traceback))
109-
end
110-
111-
return unpack(ret, 2, table.maxn(ret))
112-
end
113-
114-
function M.run(func, ...)
115-
return run(func, nil, ...)
116-
end
117-
118-
--- Creates an async function with a callback style function.
11971
--- @param argc integer
120-
--- @param func function
12172
--- @return function
122-
function M.wrap(argc, func)
123-
assert(type(argc) == 'number')
124-
assert(type(func) == 'function')
125-
return function(...)
126-
return M.wait(argc, func, ...)
127-
end
128-
end
129-
130-
--- @generic R
131-
--- @param n integer Mx number of jobs to run concurrently
132-
--- @param thunks (fun(cb: function): R)[]
133-
--- @param interrupt_check fun()?
134-
--- @param callback fun(ret: R[][])
135-
M.join = M.wrap(4, function(n, thunks, interrupt_check, callback)
136-
n = math.min(n, #thunks)
137-
138-
local ret = {} --- @type any[][]
139-
140-
if #thunks == 0 then
141-
callback(ret)
142-
return
143-
end
144-
145-
local remaining = { unpack(thunks, n + 1) }
146-
local to_go = #thunks
147-
148-
local function cb(...)
149-
ret[#ret + 1] = { ... }
150-
to_go = to_go - 1
151-
if to_go == 0 then
152-
callback(ret)
153-
elseif not interrupt_check or not interrupt_check() then
154-
if #remaining > 0 then
155-
local next_thunk = table.remove(remaining, 1)
156-
next_thunk(cb)
73+
M.wrap = function(func, argc)
74+
return function(...)
75+
if not async_thread.running() then
76+
return func(...)
15777
end
158-
end
159-
end
160-
161-
for i = 1, n do
162-
thunks[i](cb)
163-
end
164-
end)
78+
return co.yield(func, argc, ...)
79+
end
80+
end
16581

166-
---Useful for partially applying arguments to an async function
167-
--- @param fn function
168-
--- @param ... any
82+
--- @param func function
16983
--- @return function
170-
function M.curry(fn, ...)
171-
--- @type integer, any[]
172-
local nargs, args = select('#', ...), { ... }
173-
174-
return function(...)
175-
local other = { ... }
176-
for i = 1, select('#', ...) do
177-
args[nargs + i] = other[i]
178-
end
179-
return fn(unpack(args))
180-
end
84+
M.void = function(func)
85+
return function(...)
86+
if async_thread.running() then
87+
return func(...)
88+
end
89+
execute(func, ...)
90+
end
18191
end
18292

183-
if vim.schedule then
184-
--- An async function that when called will yield to the Neovim scheduler to be
185-
--- able to call the API.
186-
M.schedule = M.wrap(1, vim.schedule)
187-
end
93+
M.schedule = M.wrap(vim.schedule, 1)
18894

18995
return M

lua/gitlinker/commons/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
26.0.0
1+
27.0.0

lua/gitlinker/git.lua

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
local logging = require("gitlinker.commons.logging")
22
local spawn = require("gitlinker.commons.spawn")
33
local uv = require("gitlinker.commons.uv")
4+
local str = require("gitlinker.commons.str")
45

56
local async = require("gitlinker.async")
67

@@ -358,12 +359,65 @@ local function get_root(cwd)
358359
return result.stdout[1]
359360
end
360361

362+
-- --- NOTE: async functions for `vim.ui.select`.
363+
-- local _run_select = async.wrap(function(remotes, callback)
364+
-- vim.ui.select(remotes, {
365+
-- prompt = "Detect multiple git remotes:",
366+
-- format_item = function(item)
367+
-- return item
368+
-- end,
369+
-- }, function(choice)
370+
-- callback(choice)
371+
-- end)
372+
-- end, 2)
373+
--
374+
-- -- wrap the select function.
375+
-- --- @package
376+
-- --- @type fun(remotes:string[]):string?
377+
-- local function run_select(remotes)
378+
-- return _run_select(remotes)
379+
-- end
380+
381+
--- @package
382+
--- @param remotes string[]
383+
--- @param cwd string?
384+
--- @return string?
385+
local function _select_remotes(remotes, cwd)
386+
local logger = logging.get("gitlinker")
387+
-- local result = run_select(remotes)
388+
389+
local formatted_remotes = { "Please select remote index:" }
390+
for i, remote in ipairs(remotes) do
391+
local remote_url = get_remote_url(remote, cwd)
392+
table.insert(formatted_remotes, string.format("%d. %s (%s)", i, remote, remote_url))
393+
end
394+
395+
async.scheduler()
396+
local result = vim.fn.inputlist(formatted_remotes)
397+
-- logger:debug(string.format("inputlist:%s(%s)", vim.inspect(result), vim.inspect(type(result))))
398+
399+
if type(result) ~= "number" or result < 1 or result > #remotes then
400+
logger:err("fatal: user cancelled multiple git remotes")
401+
return nil
402+
end
403+
404+
for i, remote in ipairs(remotes) do
405+
if result == i then
406+
return remote
407+
end
408+
end
409+
410+
logger:err("fatal: user cancelled multiple git remotes, please select an index")
411+
return nil
412+
end
413+
361414
--- @param cwd string?
362415
--- @return string?
363416
local function get_branch_remote(cwd)
364417
local logger = logging.get("gitlinker")
365418
-- origin/upstream
366419
local remotes = _get_remote(cwd)
420+
logger:debug(string.format("git remotes:%s", vim.inspect(remotes)))
367421
if not remotes then
368422
return nil
369423
end
@@ -372,6 +426,10 @@ local function get_branch_remote(cwd)
372426
return remotes[1]
373427
end
374428

429+
if #remotes > 1 then
430+
return _select_remotes(remotes, cwd)
431+
end
432+
375433
-- origin/linrongbin16/add-rule2
376434
local upstream_branch = _get_rev_name("@{u}", cwd)
377435
if not upstream_branch then

0 commit comments

Comments
 (0)