-- async.lua

---------------------------------------------------------------------------
--- Utilities to work with asynchronous callback-style control flow.
--
-- All callbacks must adhere to the callback signature: `function(err, values...)`.
-- The first parameter is the error value. It will be `nil` if no error occurred, otherwise
-- the error value depends on the function that the callback was passed to.
-- If no error occurred, an arbitrary number of return values may be received as second parameter and onward.
--
-- Depending on the particular implementation of a function that a callback is passed to, it may be possible to
-- receive non-`nil` return values, even when an error occurred. Using such return values should be considered undefined
-- behavior, unless explicitly documented by the calling function.
--
-- @module async
-- @license GPL v3.0
---------------------------------------------------------------------------

local util = require("async.internal.util")
-- Compatibility shims: `table.pack` and `table.unpack` are used when they exist,
-- otherwise a plain table constructor and the global `unpack` take their place.
-- Note that the fallback for `pack` does not set an `n` field on the returned table.
local pack = table.pack or function(...) return {...} end
local unpack = table.unpack or unpack

-- The module table. All public functions are attached to it.
local async = {}


--- Creates a guard around a function that lets only the first call through.
--
-- The first invocation of the returned function forwards all of its arguments
-- to `fn`. Every later invocation does nothing. Values returned by `fn` are
-- discarded; the returned function never returns anything.
-- When `fn` is omitted, a function that does nothing is guarded instead.
--
-- @tparam[opt] function fn The function to guard.
-- @treturn function The guarded function.
function async.once(fn)
    local target = fn or function() end
    local called = false

    return function(...)
        -- TODO: Decide if repeated calls should throw or log an error instead
        -- of being dropped silently.
        if called then
            return
        end

        called = true
        target(...)
    end
end


--- Turns an asynchronous function into a blocking operation.
--
-- `fn` is run inside a fresh coroutine and is passed a single parameter: a callback function.
-- Everything that is passed to this callback, including the leading error value, is returned
-- by `wrap_sync` as regular return values. `nil` values and their positions are preserved.
--
-- The callback yields the coroutine that `fn` runs in, so it has to be invoked from within
-- that coroutine, i.e. before `fn` returns. If `fn` returns without invoking the callback,
-- `wrap_sync` returns no values.
--
-- Panics that happened inside the asynchronous function will be captured and re-thrown.
--
-- @tparam function fn An asynchronous function: `function(cb)`
-- @treturn any Any values as passed to the callback by the wrapped function
function async.wrap_sync(fn)
    local co = coroutine.create(function()
        fn(function(...)
            coroutine.yield(...)
        end)
    end)

    -- Forward the results of `coroutine.resume` as varargs instead of going through a table.
    -- The callback values regularly contain `nil` (the error value is `nil` on success), and
    -- neither `table.remove` nor `unpack` without explicit bounds is reliable on a table
    -- with holes.
    local function finish(ok, ...)
        if not ok then
            -- On failure the first extra value is the error raised inside the coroutine.
            error((...))
        end

        return ...
    end

    return finish(coroutine.resume(co))
end


--- Executes a list of asynchronous functions in series.
--
-- `waterfall` accepts an arbitrary list of asynchronous functions (tasks) and calls them in series.
-- Each function waits for the previous one to finish and will be given the previous function's return values,
-- followed by a callback as the last argument. `nil` values keep their position, so the callback is always
-- the argument right after the values passed on by the previous task.
--
-- If an error occurs in any task, execution is stopped immediately, and the final callback is called
-- with the error value.
-- If all tasks complete successfully, the final callback will be called with the return values of the last
-- task in the list.
-- If the task list is empty, the final callback is called without any arguments.
--
-- All tasks must adhere to the callback signature: `function(err, ...)`.
--
-- @async
-- @tparam table tasks The asynchronous tasks to execute in series.
-- @tparam function final_callback Called when all tasks have finished.
-- @treturn any The error returned by a failing task.
-- @treturn any Values as returned by the last task.
function async.waterfall(tasks, final_callback)
    final_callback = async.once(final_callback)

    -- Bail early if there is nothing to do
    if not next(tasks) then
        final_callback()
        return
    end

    local i = 0
    local _continue

    -- Starts the next task with the values produced by the previous one.
    local function _run(...)
        i = i + 1
        local task = tasks[i]
        if not task then
            -- We've reached the bottom of the waterfall, time to exit
            final_callback(nil, ...)
            return
        end

        -- Count the values explicitly. The length operator and `table.insert` are not
        -- reliable when the values contain `nil`, which would misplace the callback.
        local n = select("#", ...)
        local args = { ... }
        args[n + 1] = _continue
        task(unpack(args, 1, n + 1))
    end

    -- The callback handed to every task: stops on error, otherwise moves on.
    _continue = function(err, ...)
        if err then
            final_callback(err)
            return
        end

        _run(...)
    end

    _continue()
end


--- Runs all tasks in parallel and collects the results.
--
-- Every task is started right away and is given a callback: `function(err, ...)`.
-- Once all tasks have called back successfully, `final_callback` is called with a `results` list.
-- `results[i]` holds the packed values that the task at position `i` passed to its callback.
-- For an empty task list, `final_callback` is called with an empty `results` list.
--
-- If any task produces an error, `final_callback` will be called immediately
-- with that error and remaining tasks will not be tracked.
--
-- Only the first invocation of a task's callback is taken into account.
--
-- @async
-- @tparam table tasks A list of asynchronous functions. They will be given a
--  callback parameter: `function(err, ...)`.
-- @tparam function final_callback
-- @treturn any Any error from a failing task.
-- @treturn table The packed results of all tasks, indexed by task position.
function async.all(tasks, final_callback)
    final_callback = async.once(final_callback)

    local len = #tasks
    if len == 0 then
        -- Pass a `results` table even when there was nothing to do, so that
        -- callers always receive the same shape on success.
        final_callback(nil, {})
        return
    end

    local results = {}
    local done = 0
    local cancelled = false

    for i, task in ipairs(tasks) do
        -- Guard against tasks that call back more than once. A duplicate call would
        -- otherwise be counted as another finished task and end the run early.
        task(async.once(function(err, ...)
            if cancelled then
                return
            end

            if err then
                cancelled = true
                final_callback(err)
                return
            end

            done = done + 1
            results[i] = pack(...)

            if done == len then
                final_callback(nil, results)
            end
        end))
    end
end


--- Resolves a DAG (Directed Acyclic Graph) of asynchronous dependencies.
--
-- The task list is a key-value map, where the key defines the task name and the value is the task definition.
-- A task definition is either a plain asynchronous function or a list that consists of dependency names
-- (which may be absent) followed by an asynchronous function as the last element.
-- Any task name may be used as dependency for any other task, as long as no loops are created.
-- A task's function will be called once all of its dependencies have become available. Every task function,
-- with or without dependencies, is called as `function(results, cb)`, where `results` contains the values
-- returned by all tasks so far.
--
-- If any task passes an error to its callback, no further tasks are started, tracking for all other tasks
-- stops and `final_callback` is called with that error value. Otherwise, `final_callback` will be called once
-- all tasks have completed, with the results of all tasks.
--
-- The `results` table uses the task name as key and provides a `table.pack`ed list of task results as value.
--
-- Only the first invocation of a task's callback is taken into account.
--
-- If tasks are left waiting for dependencies while nothing is queued or running (e.g. because of a
-- dependency loop or an unknown dependency name), an error is raised with `error`. It is not passed
-- to `final_callback`.
--
-- @usage
--    async.dag(
--        {
--            get_data = { function(_, cb)
--                local f = fs.open("/tmp/foo.txt")
--                f:read(cb)
--            end },
--            make_folder = { function(_, cb)
--                fs.make_folder("/tmp/bar", cb)
--            end },
--            write_data = { "get_data", "make_folder", function(results, cb)
--                local data = table.unpack(results.get_data)
--                local f = fs.open("/tmp/bar/foo.txt")
--                f:write(data, cb)
--            end },
--        },
--        function(err, results)
--            if err ~= nil then
--                error(err)
--            else
--                print("success")
--            end
--        end
--    )
--
-- @async
-- @tparam table tasks A map of asynchronous tasks.
-- @tparam function final_callback
-- @treturn any Any error from a failing task.
-- @treturn table Results of all resolved tasks.
function async.dag(tasks, final_callback)
    final_callback = async.once(final_callback)

    -- Short-circuit if there is nothing to do.
    -- To provide consistent API, pass a `results` table
    if not next(tasks) then
        final_callback(nil, {})
        return
    end

    -- Packed callback values of finished tasks, by task name
    local results = {}
    -- Task functions that are ready to run, by task name
    local queue = {}
    local queue_len = 0
    -- Number of started tasks that have not called back yet
    local running = 0
    -- Task definitions that still wait for dependencies, by task name
    local pending = {}
    -- Set once any task reported an error
    local cancelled = false

    local _run_queue

    -- Marks a task as ready to run.
    local function _enqueue(name, fn)
        if queue[name] then
            error(string.format("task with name '%s' already in queue", name))
        end

        queue[name] = fn
        queue_len = queue_len + 1
        -- When queued for execution, it is no longer waiting for dependencies
        pending[name] = nil
    end

    -- Queues a task if all of its dependencies have results, otherwise keeps it pending.
    local function _initialize(name, task)
        -- Short-circuit for tasks without dependencies
        if type(task) == "function" then
            _enqueue(name, task)
            return
        elseif #task == 1 then
            _enqueue(name, task[1])
            return
        end

        -- Presumably everything but the last element, i.e. the dependency names.
        -- NOTE(review): confirm against `util.slice`.
        local dependencies = util.slice(task, 1, -1)
        local ready = util.all(dependencies, function(dependency)
            return results[dependency] ~= nil
        end)

        if ready then
            _enqueue(name, task[#task])
        else
            pending[name] = task
        end
    end

    -- (Re-)evaluates the given task definitions and runs whatever became ready.
    local function _check_pending(candidates)
        for name, task in pairs(candidates) do
            _initialize(name, task)
        end

        -- When there are tasks waiting for dependencies, but none in the queue
        -- and none actively running, we must have reached a deadlock
        if queue_len == 0 and running == 0 and next(pending) then
            local err = "deadlock detected. the following tasks are waiting for dependencies: "
            for name in pairs(pending) do
                err = err .. string.format(" %s", name)
            end
            error(err)
        end

        _run_queue()
    end

    _run_queue = function()
        -- Swap in a fresh queue before iterating. Task callbacks may run synchronously
        -- and enqueue follow-up tasks, and new keys must not be added to a table while
        -- it is being traversed with `pairs`.
        local batch = queue
        queue = {}

        for name, fn in pairs(batch) do
            -- A task that was started earlier in this loop may already have failed
            -- synchronously. Do not start any further tasks in that case.
            if cancelled then
                return
            end

            batch[name] = nil
            queue_len = queue_len - 1
            running = running + 1

            -- `async.once` keeps a duplicate callback from corrupting the `running` counter.
            fn(results, async.once(function(err, ...)
                if cancelled then
                    -- Another, concurrent task already finished with an error
                    return
                elseif err then
                    cancelled = true
                    final_callback(err, results)
                    return
                end

                results[name] = pack(...)
                running = running - 1

                -- If all lists are empty, we must have run all tasks
                if queue_len == 0 and running == 0 and not next(pending) then
                    final_callback(nil, results)
                else
                    _check_pending(pending)
                end
            end))
        end
    end

    _check_pending(tasks)
end


--- Repeatedly calls `iteratee` and `test` until stopped.
--
-- `iteratee` is called repeatedly. It is passed a callback
-- (`function(err, ...)`), which should be called with either an error or any
-- results of the iteration.
--
-- `test` is called once per iteration, after `iteratee`. It is passed any
-- non-error values from `iteratee`, followed by a callback
-- (`function(err, continue)`). `nil` values from `iteratee` keep their
-- position, so the callback is always the argument right after them.
-- The callback should be called with either an error or a boolean value.
-- Iteration continues while `test` passes a truthy value, and stops when
-- `test` passes a falsy value or when an error is passed by either callback.
--
-- When `test` stops the iteration, `final_callback` will be called with the
-- latest results from `iteratee`. When an error stops the iteration,
-- `final_callback` will be called with only that error.
--
-- This is, in concept, analogous to a `do {} while ()` construct, where `iteratee`
-- is the `do` block and `test` is the `while` test.
--
-- @async
-- @tparam function iteratee Called repeatedly. Signature: `function(cb)`.
-- @tparam function test Called once per iteration, after `iteratee`.
--  Signature: `function(..., cb)`.
-- @tparam function final_callback Called once, when `test` indicates to stop
--  the iteration or an error occurred.
-- @treturn any Any error from `iteratee` or `test`.
-- @treturn any Values passed by the most recent execution of `iteratee`.
function async.do_while(iteratee, test, final_callback)
    final_callback = async.once(final_callback)
    -- Values of the most recent `iteratee` run. The count is tracked separately,
    -- because the length operator is not reliable when the values contain `nil`.
    local results = {}
    local result_count = 0
    local _next

    -- Wraps `test` to break on errors and capture results, where `results` are what
    -- the `iteratee` passed to its callback.
    local function _test(err, ...)
        if err then
            return final_callback(err)
        end

        result_count = select("#", ...)
        results = { ... }

        -- Place the callback right after the values, regardless of any `nil` among them
        local args = { ... }
        args[result_count + 1] = _next
        test(unpack(args, 1, result_count + 1))
    end

    -- Calls `iteratee` for the next iteration, unless stopped
    _next = function(err, keep_going)
        if err then
            return final_callback(err)
        end

        if not keep_going then
            return final_callback(nil, unpack(results, 1, result_count))
        end

        iteratee(_test)
    end

    iteratee(_test)
end


--- Wrap a function with arguments for use as callback.
--
-- This may be used to wrap a function or table method as a callback, providing a (partial)
-- argument list.
-- Arguments to this call are passed through to the provided function when it is called,
-- arguments from the final caller are appended after those. `nil` values in either list
-- keep their position.
--
-- If the function is actually a method (i.e. it expects a `self` parameter or is called with `:`),
-- the `self` table can be passed as the first argument. Otherwise, `nil` should be passed, in
-- which case the wrapped function receives no `self` argument at all.
--
-- @todo Optimize the common use cases of only having a few outer arguments
-- by hardcoding those cases.
--
-- @tparam[opt] table object The object to call the method on.
-- @tparam function fn The function to wrap.
-- @tparam any ... Arbitrary arguments to pass through to the wrapped function.
-- @treturn function A function that forwards to `fn` and returns its results.
function async.callback(object, fn, ...)
    -- Argument counts are tracked explicitly. Callbacks are routinely called with `nil`
    -- (e.g. `cb(nil, value)`), and the length of a table with holes is not reliable.
    local outer_count = select("#", ...)
    local outer = { ... }

    return function(...)
        local args = {}
        local n = 0

        if object ~= nil then
            n = 1
            args[1] = object
        end

        for i = 1, outer_count do
            args[n + i] = outer[i]
        end
        n = n + outer_count

        local inner_count = select("#", ...)
        local inner = { ... }
        for i = 1, inner_count do
            args[n + i] = inner[i]
        end
        n = n + inner_count

        return fn(unpack(args, 1, n))
    end
end

return async