|
1 |
| ----@diagnostic disable |
2 |
| ---- Small async library for Neovim plugins |
3 |
| - |
4 |
| -local function validate_callback(func, callback) |
5 |
| - if callback and type(callback) ~= 'function' then |
6 |
| - local info = debug.getinfo(func, 'nS') |
7 |
| - error( |
8 |
| - string.format( |
9 |
| - 'Callback is not a function for %s, got: %s', |
10 |
| - info.short_src .. ':' .. info.linedefined, |
11 |
| - vim.inspect(callback) |
12 |
| - ) |
13 |
| - ) |
14 |
| - end |
| 1 | +-- Copied from: <https://github.com/neovim/neovim/issues/19624#issuecomment-1202405058> |
| 2 | + |
| 3 | +local co = coroutine |
| 4 | + |
| 5 | +local async_thread = { |
| 6 | + threads = {}, |
| 7 | +} |
| 8 | + |
| 9 | +local function threadtostring(x) |
| 10 | + if jit then |
| 11 | + return string.format('%p', x) |
| 12 | + else |
| 13 | + return tostring(x):match('thread: (.*)') |
| 14 | + end |
15 | 15 | end
|
16 | 16 |
|
17 |
| --- Coroutine.running() was changed between Lua 5.1 and 5.2: |
18 |
| --- - 5.1: Returns the running coroutine, or nil when called by the main thread. |
19 |
| --- - 5.2: Returns the running coroutine plus a boolean, true when the running |
20 |
| --- coroutine is the main one. |
21 |
| --- |
22 |
| --- For LuaJIT, 5.2 behaviour is enabled with LUAJIT_ENABLE_LUA52COMPAT |
23 |
| --- |
24 |
| --- We need to handle both. |
25 |
| -local _main_co_or_nil = coroutine.running() |
26 |
| - |
27 |
| ---- Executes a future with a callback when it is done |
28 |
| ---- @param func function |
29 |
| ---- @param callback function? |
30 |
| ---- @param ... any |
31 |
| -local function run(func, callback, ...) |
32 |
| - validate_callback(func, callback) |
| 17 | +function async_thread.running() |
| 18 | + local thread = co.running() |
| 19 | + local id = threadtostring(thread) |
| 20 | + return async_thread.threads[id] |
| 21 | +end |
33 | 22 |
|
34 |
| - local co = coroutine.create(func) |
| 23 | +function async_thread.create(fn) |
| 24 | + local thread = co.create(fn) |
| 25 | + local id = threadtostring(thread) |
| 26 | + async_thread.threads[id] = true |
| 27 | + return thread |
| 28 | +end |
35 | 29 |
|
36 |
| - local function step(...) |
37 |
| - local ret = { coroutine.resume(co, ...) } |
38 |
| - local stat = ret[1] |
| 30 | +function async_thread.finished(x) |
| 31 | + if co.status(x) == 'dead' then |
| 32 | + local id = threadtostring(x) |
| 33 | + async_thread.threads[id] = nil |
| 34 | + return true |
| 35 | + end |
| 36 | + return false |
| 37 | +end |
39 | 38 |
|
40 |
| - if not stat then |
41 |
| - local err = ret[2] --[[@as string]] |
42 |
| - error( |
43 |
| - string.format('The coroutine failed with this message: %s\n%s', err, debug.traceback(co)) |
44 |
| - ) |
45 |
| - end |
| 39 | +--- @param async_fn function |
| 40 | +--- @param ... any |
| 41 | +local function execute(async_fn, ...) |
| 42 | + local thread = async_thread.create(async_fn) |
| 43 | + |
| 44 | + local function step(...) |
| 45 | + local ret = { co.resume(thread, ...) } |
| 46 | + local stat, err_or_fn, nargs = unpack(ret) |
| 47 | + |
| 48 | + if not stat then |
| 49 | + error(string.format("The coroutine failed with this message: %s\n%s", |
| 50 | + err_or_fn, debug.traceback(thread))) |
| 51 | + end |
46 | 52 |
|
47 |
| - if coroutine.status(co) == 'dead' then |
48 |
| - if callback then |
49 |
| - callback(unpack(ret, 2, table.maxn(ret))) |
| 53 | + if async_thread.finished(thread) then |
| 54 | + return |
50 | 55 | end
|
51 |
| - return |
52 |
| - end |
53 | 56 |
|
54 |
| - --- @type integer, fun(...: any): any |
55 |
| - local nargs, fn = ret[2], ret[3] |
56 |
| - assert(type(fn) == 'function', 'type error :: expected func') |
| 57 | + assert(type(err_or_fn) == "function", "The 1st parameter must be a lua function") |
57 | 58 |
|
58 |
| - --- @type any[] |
59 |
| - local args = { unpack(ret, 4, table.maxn(ret)) } |
60 |
| - args[nargs] = step |
61 |
| - fn(unpack(args, 1, nargs)) |
62 |
| - end |
| 59 | + local ret_fn = err_or_fn |
| 60 | + local args = { select(4, unpack(ret)) } |
| 61 | + args[nargs] = step |
| 62 | + ret_fn(unpack(args, 1, nargs --[[@as integer]])) |
| 63 | + end |
63 | 64 |
|
64 |
| - step(...) |
| 65 | + step(...) |
65 | 66 | end
|
66 | 67 |
|
67 | 68 | local M = {}
|
68 | 69 |
|
69 |
| ----Use this to create a function which executes in an async context but |
70 |
| ----called from a non-async context. Inherently this cannot return anything |
71 |
| ----since it is non-blocking |
72 |
| ---- @generic F: function |
73 |
| ---- @param argc integer |
74 |
| ---- @param func async F |
75 |
| ---- @return F |
76 |
| -function M.sync(argc, func) |
77 |
| - return function(...) |
78 |
| - assert(not coroutine.running()) |
79 |
| - local callback = select(argc + 1, ...) |
80 |
| - run(func, callback, unpack({ ... }, 1, argc)) |
81 |
| - end |
82 |
| -end |
83 |
| - |
84 |
| ---- @param argc integer |
85 | 70 | --- @param func function
|
86 |
| ---- @param ... any |
87 |
| ---- @return any ... |
88 |
| -function M.wait(argc, func, ...) |
89 |
| - -- Always run the wrapped functions in xpcall and re-raise the error in the |
90 |
| - -- coroutine. This makes pcall work as normal. |
91 |
| - local function pfunc(...) |
92 |
| - local args = { ... } --- @type any[] |
93 |
| - local cb = args[argc] |
94 |
| - args[argc] = function(...) |
95 |
| - cb(true, ...) |
96 |
| - end |
97 |
| - xpcall(func, function(err) |
98 |
| - cb(false, err, debug.traceback()) |
99 |
| - end, unpack(args, 1, argc)) |
100 |
| - end |
101 |
| - |
102 |
| - local ret = { coroutine.yield(argc, pfunc, ...) } |
103 |
| - |
104 |
| - local ok = ret[1] |
105 |
| - if not ok then |
106 |
| - --- @type string, string |
107 |
| - local err, traceback = ret[2], ret[3] |
108 |
| - error(string.format('Wrapped function failed: %s\n%s', err, traceback)) |
109 |
| - end |
110 |
| - |
111 |
| - return unpack(ret, 2, table.maxn(ret)) |
112 |
| -end |
113 |
| - |
114 |
| -function M.run(func, ...) |
115 |
| - return run(func, nil, ...) |
116 |
| -end |
117 |
| - |
118 |
| ---- Creates an async function with a callback style function. |
119 | 71 | --- @param argc integer
|
120 |
| ---- @param func function |
121 | 72 | --- @return function
|
122 |
| -function M.wrap(argc, func) |
123 |
| - assert(type(argc) == 'number') |
124 |
| - assert(type(func) == 'function') |
125 |
| - return function(...) |
126 |
| - return M.wait(argc, func, ...) |
127 |
| - end |
128 |
| -end |
129 |
| - |
130 |
| ---- @generic R |
131 |
| ---- @param n integer Mx number of jobs to run concurrently |
132 |
| ---- @param thunks (fun(cb: function): R)[] |
133 |
| ---- @param interrupt_check fun()? |
134 |
| ---- @param callback fun(ret: R[][]) |
135 |
| -M.join = M.wrap(4, function(n, thunks, interrupt_check, callback) |
136 |
| - n = math.min(n, #thunks) |
137 |
| - |
138 |
| - local ret = {} --- @type any[][] |
139 |
| - |
140 |
| - if #thunks == 0 then |
141 |
| - callback(ret) |
142 |
| - return |
143 |
| - end |
144 |
| - |
145 |
| - local remaining = { unpack(thunks, n + 1) } |
146 |
| - local to_go = #thunks |
147 |
| - |
148 |
| - local function cb(...) |
149 |
| - ret[#ret + 1] = { ... } |
150 |
| - to_go = to_go - 1 |
151 |
| - if to_go == 0 then |
152 |
| - callback(ret) |
153 |
| - elseif not interrupt_check or not interrupt_check() then |
154 |
| - if #remaining > 0 then |
155 |
| - local next_thunk = table.remove(remaining, 1) |
156 |
| - next_thunk(cb) |
| 73 | +M.wrap = function(func, argc) |
| 74 | + return function(...) |
| 75 | + if not async_thread.running() then |
| 76 | + return func(...) |
157 | 77 | end
|
158 |
| - end |
159 |
| - end |
160 |
| - |
161 |
| - for i = 1, n do |
162 |
| - thunks[i](cb) |
163 |
| - end |
164 |
| -end) |
| 78 | + return co.yield(func, argc, ...) |
| 79 | + end |
| 80 | +end |
165 | 81 |
|
166 |
| ----Useful for partially applying arguments to an async function |
167 |
| ---- @param fn function |
168 |
| ---- @param ... any |
| 82 | +--- @param func function |
169 | 83 | --- @return function
|
170 |
| -function M.curry(fn, ...) |
171 |
| - --- @type integer, any[] |
172 |
| - local nargs, args = select('#', ...), { ... } |
173 |
| - |
174 |
| - return function(...) |
175 |
| - local other = { ... } |
176 |
| - for i = 1, select('#', ...) do |
177 |
| - args[nargs + i] = other[i] |
178 |
| - end |
179 |
| - return fn(unpack(args)) |
180 |
| - end |
| 84 | +M.void = function(func) |
| 85 | + return function(...) |
| 86 | + if async_thread.running() then |
| 87 | + return func(...) |
| 88 | + end |
| 89 | + execute(func, ...) |
| 90 | + end |
181 | 91 | end
|
182 | 92 |
|
183 |
| -if vim.schedule then |
184 |
| - --- An async function that when called will yield to the Neovim scheduler to be |
185 |
| - --- able to call the API. |
186 |
| - M.schedule = M.wrap(1, vim.schedule) |
187 |
| -end |
| 93 | +M.schedule = M.wrap(vim.schedule, 1) |
188 | 94 |
|
189 | 95 | return M
|
0 commit comments