|
| 1 | +---@diagnostic disable: luadoc-miss-module-name, undefined-doc-name |
| 2 | +--- Small async library for Neovim plugins |
| 3 | +--- @module async |
| 4 | +-- Store all the async threads in a weak table so we don't prevent them from |
| 5 | +-- being garbage collected |
| 6 | +local handles = setmetatable({}, { __mode = "k" }) |
| 7 | +local M = {} |
| 8 | +-- Note: coroutine.running() was changed between Lua 5.1 and 5.2: |
| 9 | +-- - 5.1: Returns the running coroutine, or nil when called by the main thread. |
| 10 | +-- - 5.2: Returns the running coroutine plus a boolean, true when the running |
| 11 | +-- coroutine is the main one. |
| 12 | +-- |
| 13 | +-- For LuaJIT, 5.2 behaviour is enabled with LUAJIT_ENABLE_LUA52COMPAT |
| 14 | +-- |
| 15 | +-- We need to handle both. |
| 16 | +--- Returns whether the current execution context is async. |
| 17 | +--- |
| 18 | +--- @treturn boolean? |
| 19 | +function M.running() |
| 20 | + local current = coroutine.running() |
| 21 | + if current and handles[current] then |
| 22 | + return true |
| 23 | + end |
| 24 | +end |
| 25 | +local function is_Async_T(handle) |
| 26 | + if |
| 27 | + handle |
| 28 | + and type(handle) == "table" |
| 29 | + and vim.is_callable(handle.cancel) |
| 30 | + and vim.is_callable(handle.is_cancelled) |
| 31 | + then |
| 32 | + return true |
| 33 | + end |
| 34 | +end |
| 35 | +local Async_T = {} |
| 36 | +-- Analogous to uv.close |
| 37 | +function Async_T:cancel(cb) |
| 38 | + -- Cancel anything running on the event loop |
| 39 | + if self._current and not self._current:is_cancelled() then |
| 40 | + self._current:cancel(cb) |
| 41 | + end |
| 42 | +end |
| 43 | +function Async_T.new(co) |
| 44 | + local handle = setmetatable({}, { __index = Async_T }) |
| 45 | + handles[co] = handle |
| 46 | + return handle |
| 47 | +end |
| 48 | +-- Analogous to uv.is_closing |
| 49 | +function Async_T:is_cancelled() |
| 50 | + return self._current and self._current:is_cancelled() |
| 51 | +end |
| 52 | +--- Run a function in an async context. |
| 53 | +--- @tparam function func |
| 54 | +--- @tparam function callback |
| 55 | +--- @tparam any ... Arguments for func |
| 56 | +--- @treturn async_t Handle |
| 57 | +function M.run(func, callback, ...) |
| 58 | + vim.validate({ |
| 59 | + func = { func, "function" }, |
| 60 | + callback = { callback, "function", true }, |
| 61 | + }) |
| 62 | + local co = coroutine.create(func) |
| 63 | + local handle = Async_T.new(co) |
| 64 | + local function step(...) |
| 65 | + local ret = { coroutine.resume(co, ...) } |
| 66 | + local ok = ret[1] |
| 67 | + if not ok then |
| 68 | + local err = ret[2] |
| 69 | + error( |
| 70 | + string.format("The coroutine failed with this message:\n%s\n%s", err, debug.traceback(co)) |
| 71 | + ) |
| 72 | + end |
| 73 | + if coroutine.status(co) == "dead" then |
| 74 | + if callback then |
| 75 | + callback(unpack(ret, 4, table.maxn(ret))) |
| 76 | + end |
| 77 | + return |
| 78 | + end |
| 79 | + local nargs, fn = ret[2], ret[3] |
| 80 | + local args = { select(4, unpack(ret)) } |
| 81 | + assert(type(fn) == "function", "type error :: expected func") |
| 82 | + args[nargs] = step |
| 83 | + local r = fn(unpack(args, 1, nargs)) |
| 84 | + if is_Async_T(r) then |
| 85 | + handle._current = r |
| 86 | + end |
| 87 | + end |
| 88 | + step(...) |
| 89 | + return handle |
| 90 | +end |
| 91 | +local function wait(argc, func, ...) |
| 92 | + vim.validate({ |
| 93 | + argc = { argc, "number" }, |
| 94 | + func = { func, "function" }, |
| 95 | + }) |
| 96 | + -- Always run the wrapped functions in xpcall and re-raise the error in the |
| 97 | + -- coroutine. This makes pcall work as normal. |
| 98 | + local function pfunc(...) |
| 99 | + local args = { ... } |
| 100 | + local cb = args[argc] |
| 101 | + args[argc] = function(...) |
| 102 | + cb(true, ...) |
| 103 | + end |
| 104 | + xpcall(func, function(err) |
| 105 | + cb(false, err, debug.traceback()) |
| 106 | + end, unpack(args, 1, argc)) |
| 107 | + end |
| 108 | + local ret = { coroutine.yield(argc, pfunc, ...) } |
| 109 | + local ok = ret[1] |
| 110 | + if not ok then |
| 111 | + local _, err, traceback = unpack(ret) |
| 112 | + error(string.format("Wrapped function failed: %s\n%s", err, traceback)) |
| 113 | + end |
| 114 | + return unpack(ret, 2, table.maxn(ret)) |
| 115 | +end |
| 116 | +--- Wait on a callback style function |
| 117 | +--- |
| 118 | +--- @tparam integer? argc The number of arguments of func. |
| 119 | +--- @tparam function func callback style function to execute |
| 120 | +--- @tparam any ... Arguments for func |
| 121 | +function M.wait(...) |
| 122 | + if type(select(1, ...)) == "number" then |
| 123 | + return wait(...) |
| 124 | + end |
| 125 | + -- Assume argc is equal to the number of passed arguments. |
| 126 | + return wait(select("#", ...) - 1, ...) |
| 127 | +end |
| 128 | +--- Use this to create a function which executes in an async context but |
| 129 | +--- called from a non-async context. Inherently this cannot return anything |
| 130 | +--- since it is non-blocking |
| 131 | +--- @tparam function func |
| 132 | +--- @tparam number argc The number of arguments of func. Defaults to 0 |
| 133 | +--- @tparam boolean strict Error when called in non-async context |
| 134 | +--- @treturn function(...):async_t |
| 135 | +function M.create(func, argc, strict) |
| 136 | + vim.validate({ |
| 137 | + func = { func, "function" }, |
| 138 | + argc = { argc, "number", true }, |
| 139 | + }) |
| 140 | + argc = argc or 0 |
| 141 | + return function(...) |
| 142 | + if M.running() then |
| 143 | + if strict then |
| 144 | + error("This function must run in a non-async context") |
| 145 | + end |
| 146 | + return func(...) |
| 147 | + end |
| 148 | + local callback = select(argc + 1, ...) |
| 149 | + return M.run(func, callback, unpack({ ... }, 1, argc)) |
| 150 | + end |
| 151 | +end |
| 152 | +--- Create a function which executes in an async context but |
| 153 | +--- called from a non-async context. |
| 154 | +--- @tparam function func |
| 155 | +--- @tparam boolean strict Error when called in non-async context |
| 156 | +function M.void(func, strict) |
| 157 | + vim.validate({ func = { func, "function" } }) |
| 158 | + return function(...) |
| 159 | + if M.running() then |
| 160 | + if strict then |
| 161 | + error("This function must run in a non-async context") |
| 162 | + end |
| 163 | + return func(...) |
| 164 | + end |
| 165 | + return M.run(func, nil, ...) |
| 166 | + end |
| 167 | +end |
| 168 | +--- Creates an async function with a callback style function. |
| 169 | +--- |
| 170 | +--- @tparam function func A callback style function to be converted. The last argument must be the callback. |
| 171 | +--- @tparam integer argc The number of arguments of func. Must be included. |
| 172 | +--- @tparam boolean strict Error when called in non-async context |
| 173 | +--- @treturn function Returns an async function |
| 174 | +function M.wrap(func, argc, strict) |
| 175 | + vim.validate({ |
| 176 | + argc = { argc, "number" }, |
| 177 | + }) |
| 178 | + return function(...) |
| 179 | + if not M.running() then |
| 180 | + if strict then |
| 181 | + error("This function must run in an async context") |
| 182 | + end |
| 183 | + return func(...) |
| 184 | + end |
| 185 | + return M.wait(argc, func, ...) |
| 186 | + end |
| 187 | +end |
| 188 | +--- Run a collection of async functions (`thunks`) concurrently and return when |
| 189 | +--- all have finished. |
| 190 | +--- @tparam function[] thunks |
| 191 | +--- @tparam integer n Max number of thunks to run concurrently |
| 192 | +--- @tparam function interrupt_check Function to abort thunks between calls |
| 193 | +function M.join(thunks, n, interrupt_check) |
| 194 | + local function run(finish) |
| 195 | + if #thunks == 0 then |
| 196 | + return finish() |
| 197 | + end |
| 198 | + local remaining = { select(n + 1, unpack(thunks)) } |
| 199 | + local to_go = #thunks |
| 200 | + local ret = {} |
| 201 | + local function cb(...) |
| 202 | + ret[#ret + 1] = { ... } |
| 203 | + to_go = to_go - 1 |
| 204 | + if to_go == 0 then |
| 205 | + finish(ret) |
| 206 | + elseif not interrupt_check or not interrupt_check() then |
| 207 | + if #remaining > 0 then |
| 208 | + local next_task = table.remove(remaining) |
| 209 | + next_task(cb) |
| 210 | + end |
| 211 | + end |
| 212 | + end |
| 213 | + for i = 1, math.min(n, #thunks) do |
| 214 | + thunks[i](cb) |
| 215 | + end |
| 216 | + end |
| 217 | + if not M.running() then |
| 218 | + return run |
| 219 | + end |
| 220 | + return M.wait(1, false, run) |
| 221 | +end |
| 222 | +--- Partially applying arguments to an async function |
| 223 | +--- @tparam function fn |
| 224 | +--- @param ... arguments to apply to `fn` |
| 225 | +function M.curry(fn, ...) |
| 226 | + local args = { ... } |
| 227 | + local nargs = select("#", ...) |
| 228 | + return function(...) |
| 229 | + local other = { ... } |
| 230 | + for i = 1, select("#", ...) do |
| 231 | + args[nargs + i] = other[i] |
| 232 | + end |
| 233 | + fn(unpack(args)) |
| 234 | + end |
| 235 | +end |
| 236 | +--- An async function that when called will yield to the Neovim scheduler to be |
| 237 | +--- able to call the neovim API. |
| 238 | +M.scheduler = M.wrap(vim.schedule, 1, false) |
| 239 | +return M |
0 commit comments