// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef VM_TOOLS_CONCIERGE_FUTURE_H_
#define VM_TOOLS_CONCIERGE_FUTURE_H_

#include <atomic>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include <base/bind.h>
#include <base/optional.h>
#include <base/run_loop.h>
#include <base/task_runner.h>
#include <base/threading/sequenced_task_runner_handle.h>

#include "vm_tools/concierge/apply_impl.h"

// A future class and utilities adapted to the Chrome OS code base. It can be
// used to post jobs to the same/different threads, and to add async support
// to methods. Please refer to FutureTest_Tutorial of future_test.cc on how to
// use this class.
//
// * Regarding "NoReject" (|Then| vs |ThenNoReject|,
//                       and |Async| vs |AsyncNoReject|)
//
// Use the NoReject variant if there is no need to reject a promise in your
// function. It allows you to return the output directly without using
// `return vm_tools::Resolve<T>(val);`, which reduces the amount of boilerplate
// code.
//
// The non "NoReject" func gives the freedom to use both |Resolve| and |Reject|.
//
// * Support of std::tuple and std::array
//
// If the value type of the Future is either std::tuple or std::array, the value
// is unpacked automatically for the next |Then| func. See FutureTest_Tuple
// and FutureTest_Array for example.
namespace vm_tools {

template <typename T, typename Error>
class Future;

template <typename T, typename Error>
class Promise;

// A struct to store both the value and the error for |Future::Get|.
// In a common C++ implementation of futures, this is not need as the exception
// is thrown if any when |Get| is called.
template <typename T,
          typename Error = void,
          class Enable1 = void,
          class Enable2 = void>
struct GetResult;

template <typename T, typename Error>
struct GetResult<T,
                 Error,
                 typename std::enable_if_t<!std::is_void<T>::value>,
                 typename std::enable_if_t<!std::is_void<Error>::value>> {
  T val;
  Error err;
  bool rejected = false;
};

template <typename T, typename Error>
struct GetResult<T,
                 Error,
                 typename std::enable_if_t<std::is_void<T>::value>,
                 typename std::enable_if_t<!std::is_void<Error>::value>> {
  Error err;
  bool rejected = false;
};

template <typename T, typename Error>
struct GetResult<T,
                 Error,
                 typename std::enable_if_t<!std::is_void<T>::value>,
                 typename std::enable_if_t<std::is_void<Error>::value>> {
  T val;
  bool rejected = false;
};

template <typename T, typename Error>
struct GetResult<T,
                 Error,
                 typename std::enable_if_t<std::is_void<T>::value>,
                 typename std::enable_if_t<std::is_void<Error>::value>> {
  bool rejected = false;
};

template <typename T, typename Error = void>
static typename std::enable_if_t<!std::is_void<T>::value, GetResult<T, Error>>
Resolve(T val) {
  GetResult<T, Error> ret;
  ret.val = std::move(val);
  ret.rejected = false;
  return ret;
}

template <typename T, typename Error = void>
static typename std::enable_if_t<std::is_void<T>::value, GetResult<T, Error>>
Resolve() {
  GetResult<void, Error> ret;
  ret.rejected = false;
  return ret;
}

template <typename T, typename Error = void>
typename std::enable_if_t<!std::is_void<Error>::value, GetResult<T, Error>>
Reject(Error err) {
  GetResult<T, Error> ret;
  ret.err = std::move(err);
  ret.rejected = true;
  return ret;
}

template <typename T, typename Error = void>
typename std::enable_if_t<std::is_void<Error>::value, GetResult<T, void>>
Reject() {
  GetResult<T, void> ret;
  ret.rejected = true;
  return ret;
}

namespace internal {

template <typename T, typename Error>
struct SharedState {
  mutable base::Lock mutex;
  GetResult<T, Error> ret;
  bool done = false;
  scoped_refptr<base::TaskRunner> task_runner;
  base::OnceClosure then_func;
};

template <typename F>
struct get_future_type;

template <template <typename, typename> class F, typename T, typename Error>
struct get_future_type<F<T, Error>> {
  using type = T;
};

template <class T, class Error, class U = T>
struct is_future : std::false_type {};

template <class F, class Error>
struct is_future<F, Error, Future<typename get_future_type<F>::type, Error>>
    : std::true_type {};
};  // namespace internal

// Error: User defined error type used in |Reject| and |OnReject|.
template <typename T, typename Error = void>
class Future {
 public:
  Future() = default;
  explicit Future(std::shared_ptr<internal::SharedState<T, Error>> state)
      : state_(std::move(state)) {}
  Future(Future&&) = default;
  Future(const Future&) = delete;
  Future& operator=(const Future&) = delete;

  Future& operator=(Future&&) = default;

  // |func| will be posted to the task_runner when this future is fulfilled by
  // returning |vm_tools::Resolve<T>(val)| or |vm_tools::Reject<T>(err)|.
  //
  // Returns the future of the posted func. Task_runner of this future class is
  // inherited by the returned future.
  //
  // |OnReject| can be used to handle the rejection if there was a reject. If
  // the rejection is not handled, the subsequent Then funcs will be bypassed
  // and rejection signal will be propagated.
  template <typename T_then, typename... Ts>
  Future<T_then, Error> Then(
      base::OnceCallback<GetResult<T_then, Error>(Ts...)> func);
  template <typename T_then, typename... Ts>
  Future<T_then, Error> ThenNoReject(base::OnceCallback<T_then(Ts...)> func);

  // |func| is triggered if |vm_tools::Reject<T>(err)| is returned by previous
  // |Then| func. The chain can be resumed by returning
  // |vm_tools::Resolve<T>(val)|, or keep skipping by returning
  // |vm_tools::Reject<T>(err)| within |func|.
  //
  // Returns the future of the posted func. task_runner of this future class is
  // inherited by the returned future.
  template <typename... ErrorOrVoid>
  Future<T, Error> OnReject(
      base::OnceCallback<GetResult<T, Error>(ErrorOrVoid...)> func);

  // Use a RunLoop to wait for the promise to be fulfilled, return
  // the result and reset the shared state. In other words, the future becomes
  // invalid after Get() returns.
  GetResult<T, Error> Get(
      base::RunLoop::Type type = base::RunLoop::Type::kNestableTasksAllowed);

  // TODO(woodychow): Implement a wait function with timeout

  // Update the |task_runner|. |Then|/|OnReject| functions will be posted to
  // this task_runner
  Future<T, Error> Via(scoped_refptr<base::TaskRunner> task_runner) {
    DCHECK(task_runner);
    base::AutoLock guard(state_->mutex);
    state_->task_runner = task_runner;
    return std::move(*this);
  }

  // Returns true if the promise has been fulfilled. False otherwise.
  bool IsDone() const {
    base::AutoLock guard(state_->mutex);
    return state_->done;
  }

  // Flatten a nested future. Useful when making an async call within an async
  // function.
  template <typename U = T, typename UError = Error>
  typename std::enable_if_t<internal::is_future<U, UError>::value, T>
  Flatten() {
    return std::move(*this).Then(base::BindOnce([](T f) { return f.Get(); }));
  }

 private:
  // This function should only be called in the internal Then wrapper, where the
  // promise is fulfilled and the val is guaranteed to be there.
  // Wait, this assumes the original func and the Then func are executed in the
  // same sequence
  GetResult<T, Error> GetHelper();

  template <typename U = T, typename UError = Error>
  static typename std::enable_if_t<std::is_void<UError>::value>
  RejectFuncHelper(base::OnceCallback<GetResult<T, void>()> reject_func,
                   Promise<T, void> p,
                   GetResult<T, void> ret) {
    p.SetResult(std::move(reject_func).Run());
  }
  template <typename U = T, typename UError = Error>
  static typename std::enable_if_t<!std::is_void<UError>::value>
  RejectFuncHelper(base::OnceCallback<GetResult<T, UError>(UError)> reject_func,
                   Promise<T, UError> p,
                   GetResult<T, UError> ret) {
    p.SetResult(std::move(reject_func).Run(std::move(ret.err)));
  }

  template <typename T_then, typename UError = Error>
  static typename std::enable_if_t<std::is_void<UError>::value>
  ResolveFuncHelper(base::OnceCallback<GetResult<T_then, Error>()> resolve_func,
                    Promise<T_then, Error> p,
                    GetResult<T, Error> ret) {
    p.SetResult(std::move(resolve_func).Run());
  }

  template <typename T_then, typename... Ts, typename UError = Error>
  static typename std::enable_if_t<std::is_void<UError>::value>
  ResolveFuncHelper(
      base::OnceCallback<GetResult<T_then, Error>(Ts...)> resolve_func,
      Promise<T_then, Error> p,
      GetResult<T, Error> ret) {
    p.SetResult(internal::Apply(std::move(resolve_func), std::move(ret.val)));
  }

  template <typename T_then, typename UError = Error>
  static typename std::enable_if_t<std::is_void<UError>::value> PropagateReject(
      Promise<T_then, void> p, GetResult<T, void> ret) {
    p.SetResult(Reject<T_then>());
  }

  template <typename T_then, typename UError = Error>
  static typename std::enable_if_t<!std::is_void<UError>::value>
  PropagateReject(Promise<T_then, UError> p, GetResult<T, UError> ret) {
    p.SetResult(Reject<T_then, Error>(ret.err));
  }

  template <typename T_then, typename... Ts>
  Future<T_then, Error> ThenHelper(
      scoped_refptr<base::TaskRunner> task_runner,
      base::OnceCallback<void(Promise<T_then, Error>, GetResult<T, Error>)>
          resolve_func,
      base::OnceCallback<void(Promise<T_then, Error>, GetResult<T, Error>)>
          reject_func);

  std::shared_ptr<internal::SharedState<T, Error>> state_;
};

template <typename T, typename Error = void>
class Promise {
 public:
  Promise() { state_ = std::make_shared<internal::SharedState<T, Error>>(); }
  explicit Promise(std::shared_ptr<internal::SharedState<T, Error>> state)
      : state_(std::move(state)) {}
  Promise(Promise&&) = default;
  Promise(const Promise&) = delete;
  Promise& operator=(const Promise&) = delete;

  Promise& operator=(Promise&&) = default;

  // Returns a future that can be used to wait for this promise to be fulfilled.
  Future<T, Error> GetFuture(scoped_refptr<base::TaskRunner> task_runner) {
    base::AutoLock guard(state_->mutex);
    state_->task_runner = task_runner;
    return Future<T, Error>(state_);
  }

  // Fulfill this promise. The shared state will be released upon this.
  template <typename U = T>
  typename std::enable_if_t<std::is_void<U>::value> SetValue();
  template <typename U = T>
  typename std::enable_if_t<!std::is_void<U>::value> SetValue(U val);

  // Reject this promise. The shared state will be released upon this.
  template <typename UError = Error>
  typename std::enable_if_t<std::is_void<UError>::value> Reject();
  template <typename UError = Error>
  typename std::enable_if_t<!std::is_void<UError>::value> Reject(UError err);

  void SetResult(GetResult<T, Error> ret);

 private:
  // Lock state_.mutex before calling this
  void SetValueHelperLocked();

  std::shared_ptr<internal::SharedState<T, Error>> state_;
};

// ------ Promise impl ------

template <typename T, typename Error>
void Promise<T, Error>::SetValueHelperLocked() {
  DCHECK(!state_->done);
  state_->done = true;

  // Handle "then"
  if (!state_->then_func.is_null()) {
    std::move(state_->then_func).Run();
  }
}

template <typename T, typename Error>
template <typename U>
typename std::enable_if_t<std::is_void<U>::value>
Promise<T, Error>::SetValue() {
  base::AutoLock guard(state_->mutex);
  state_->ret = vm_tools::Resolve<void, Error>();
  SetValueHelperLocked();
}

template <typename T, typename Error>
template <typename U>
typename std::enable_if_t<!std::is_void<U>::value> Promise<T, Error>::SetValue(
    U val) {
  base::AutoLock guard(state_->mutex);
  state_->ret = vm_tools::Resolve<T, Error>(std::move(val));
  SetValueHelperLocked();
}

template <typename T, typename Error>
template <typename UError>
typename std::enable_if_t<std::is_void<UError>::value>
Promise<T, Error>::Reject() {
  base::AutoLock guard(state_->mutex);
  state_->ret = vm_tools::Reject<T, void>();
  state_->done = true;

  if (!state_->then_func.is_null()) {
    std::move(state_->then_func).Run();
  }
}

template <typename T, typename Error>
template <typename UError>
typename std::enable_if_t<!std::is_void<UError>::value>
Promise<T, Error>::Reject(UError err) {
  base::AutoLock guard(state_->mutex);
  state_->ret = vm_tools::Reject<T, Error>(std::move(err));
  state_->done = true;

  if (!state_->then_func.is_null()) {
    std::move(state_->then_func).Run();
  }
}

template <typename T, typename Error>
void Promise<T, Error>::SetResult(GetResult<T, Error> ret) {
  base::AutoLock guard(state_->mutex);
  state_->ret = std::move(ret);
  SetValueHelperLocked();
}

// ------ Future impl ------

template <typename T, typename Error>
template <typename T_then, typename... Ts>
Future<T_then, Error> Future<T, Error>::ThenHelper(
    scoped_refptr<base::TaskRunner> task_runner,
    base::OnceCallback<void(Promise<T_then, Error>, GetResult<T, Error>)>
        resolve_func,
    base::OnceCallback<void(Promise<T_then, Error>, GetResult<T, Error>)>
        reject_func) {
  base::AutoLock guard(state_->mutex);
  if (task_runner) {
    state_->task_runner = task_runner;
  }
  Promise<T_then, Error> promise;
  Future<T_then, Error> future = promise.GetFuture(state_->task_runner);
  internal::SharedState<T, Error>* pState = state_.get();
  base::OnceCallback<void(Future<T, Error>, Promise<T_then, Error>)>
      wrapped_func = base::BindOnce(
          [](base::OnceCallback<void(Promise<T_then, Error>,
                                     GetResult<T, Error>)> resolve_func,
             base::OnceCallback<void(Promise<T_then, Error>,
                                     GetResult<T, Error>)> reject_func,
             Future<T, Error> old_future, Promise<T_then, Error> p) {
            GetResult<T, Error> ret = old_future.Get();
            if (ret.rejected) {
              std::move(reject_func).Run(std::move(p), std::move(ret));
            } else {
              std::move(resolve_func).Run(std::move(p), std::move(ret));
            }
          },
          std::move(resolve_func), std::move(reject_func));

  base::OnceClosure post_func = base::BindOnce(
      [](scoped_refptr<base::TaskRunner> task_runner, base::OnceClosure func) {
        CHECK(task_runner);
        task_runner->PostTask(FROM_HERE, std::move(func));
      },
      state_->task_runner,
      base::BindOnce(std::move(wrapped_func), std::move(*this),
                     std::move(promise)));
  if (pState->done) {
    // post immediately
    std::move(post_func).Run();
  } else {
    // Promise::SetValue/Reject will run this func
    pState->then_func = std::move(post_func);
  }
  return future;
}

template <typename T, typename Error>
template <typename T_then, typename... Ts>
Future<T_then, Error> Future<T, Error>::Then(
    base::OnceCallback<GetResult<T_then, Error>(Ts...)> func) {
  return ThenHelper<T_then, Ts...>(
      nullptr,
      base::BindOnce(
          [](base::OnceCallback<GetResult<T_then, Error>(Ts...)> resolve_func,
             Promise<T_then, Error> p, GetResult<T, Error> ret) {
            ResolveFuncHelper(std::move(resolve_func), std::move(p),
                              std::move(ret));
          },
          std::move(func)),
      base::BindOnce(&Future<T, Error>::PropagateReject<T_then>));
}

namespace internal {

template <typename Error, typename T_then, typename... Ts>
typename std::enable_if_t<!std::is_void<T_then>::value,
                          base::OnceCallback<GetResult<T_then, Error>(Ts...)>>
FutureBind(base::OnceCallback<T_then(Ts...)> func) {
  return base::BindOnce(
      [](base::OnceCallback<T_then(Ts...)> func, Ts... args) {
        return Resolve<T_then, Error>(std::move(func).Run(std::move(args)...));
      },
      std::move(func));
}

template <typename Error, typename T_then, typename... Ts>
typename std::enable_if_t<std::is_void<T_then>::value,
                          base::OnceCallback<GetResult<T_then, Error>(Ts...)>>
FutureBind(base::OnceCallback<T_then(Ts...)> func) {
  return base::BindOnce(
      [](base::OnceCallback<void(Ts...)> func, Ts... args) {
        std::move(func).Run(std::move(args)...);
        return Resolve<void, Error>();
      },
      std::move(func));
}
};  // namespace internal

template <typename T, typename Error>
template <typename T_then, typename... Ts>
Future<T_then, Error> Future<T, Error>::ThenNoReject(
    base::OnceCallback<T_then(Ts...)> func) {
  return Then(internal::FutureBind<Error>(std::move(func)));
}

template <typename T, typename Error>
template <typename... ErrorOrVoid>
Future<T, Error> Future<T, Error>::OnReject(
    base::OnceCallback<GetResult<T, Error>(ErrorOrVoid...)> func) {
  return ThenHelper<T, T>(
      nullptr, base::BindOnce([](Promise<T, Error> p, GetResult<T, Error> ret) {
        p.SetResult(std::move(ret));
      }),
      base::BindOnce(
          [](base::OnceCallback<GetResult<T, Error>(ErrorOrVoid...)> func,
             Promise<T, Error> p, GetResult<T, Error> ret) {
            Future<T, Error>::RejectFuncHelper(std::move(func), std::move(p),
                                               std::move(ret));
          },
          std::move(func)));
}

template <typename T, typename Error>
GetResult<T, Error> Future<T, Error>::Get(base::RunLoop::Type type) {
  {
    base::ReleasableAutoLock lock(&state_->mutex);
    if (!state_->done) {
      base::RunLoop loop(type);
      DCHECK(state_->then_func.is_null());
      state_->then_func = loop.QuitClosure();
      lock.Release();
      loop.Run();
    }
  }
  return GetHelper();
}

template <typename T, typename Error>
GetResult<T, Error> Future<T, Error>::GetHelper() {
  GetResult<T, Error> ret;
  {
    base::AutoLock guard(state_->mutex);
    DCHECK(state_->done);
    ret = std::move(state_->ret);
  }
  state_.reset();
  return ret;
}

/* ------ Non class method declarations ------*/

// Post |func| to |task_runner|, and return a future that will be ready upon
// completion of the posted |func|
template <typename T, typename Error = void>
Future<T, Error> Async(scoped_refptr<base::TaskRunner> task_runner,
                       base::OnceCallback<GetResult<T, Error>()> func);
template <typename T, typename Error = void>
Future<T, Error> AsyncNoReject(scoped_refptr<base::TaskRunner> task_runner,
                               base::OnceCallback<T()> func);

// Returns a future that will be ready when all the given futures are ready
// If any of the given futures is rejected, the returned future will be rejected
// as well
template <typename T, typename Error>
Future<std::vector<T>, Error> Collect(
    scoped_refptr<base::TaskRunner> task_runner,
    std::vector<Future<T, Error>> futures);

// Returns a future that has already been resolved with the given |val|.
// This is useful for removing boilerplate code
template <typename T, typename Error = void>
std::enable_if_t<!std::is_void<T>::value, Future<T, Error>> ResolvedFuture(
    T val,
    scoped_refptr<base::TaskRunner> task_runner =
        base::SequencedTaskRunnerHandle::Get());
template <typename T, typename Error = void>
std::enable_if_t<std::is_void<T>::value, Future<T, Error>> ResolvedFuture(
    scoped_refptr<base::TaskRunner> task_runner =
        base::SequencedTaskRunnerHandle::Get());

/* ------ Non class method implementation ------ */

template <typename T, typename Error>
Future<T, Error> Flatten(Future<Future<T, Error>, Error> f) {
  return f.Then(base::BindOnce([](Future<T, Error> f) { return f.Get(); }));
}

template <typename T, typename Error>
Future<T, Error> Async(scoped_refptr<base::TaskRunner> task_runner,
                       base::OnceCallback<GetResult<T, Error>()> func) {
  Promise<T, Error> p;
  Future<T, Error> future = p.GetFuture(task_runner);
  task_runner->PostTask(FROM_HERE,
                        base::BindOnce(
                            [](Promise<T, Error> p,
                               base::OnceCallback<GetResult<T, Error>()> func) {
                              p.SetResult(std::move(func).Run());
                            },
                            std::move(p), std::move(func)));
  return future;
}

template <typename T, typename Error>
Future<T, Error> AsyncNoReject(scoped_refptr<base::TaskRunner> task_runner,
                               base::OnceCallback<T()> func) {
  Promise<T, Error> p;
  Future<T, Error> future = p.GetFuture(task_runner);
  task_runner->PostTask(
      FROM_HERE,
      base::BindOnce(
          [](Promise<T, Error> p,
             base::OnceCallback<GetResult<T, Error>()> func) {
            p.SetResult(std::move(func).Run());
          },
          std::move(p), internal::FutureBind<Error>(std::move(func))));
  return future;
}

template <typename T, typename Error>
Future<std::vector<T>, Error> Collect(
    scoped_refptr<base::TaskRunner> task_runner,
    std::vector<Future<T, Error>> futures) {
  struct Context {
    explicit Context(size_t n) { values.resize(n); }
    void Reject(Error e) {
      if (!rejected.fetch_or(1)) {  // only reject once
        promise.Reject(std::move(e));
      }
    }
    ~Context() {
      if (!rejected.load()) {
        promise.SetValue(std::move(values));
      }
    }
    std::atomic<uint8_t> rejected{0};
    Promise<std::vector<T>> promise;
    std::vector<T> values;
  };

  std::shared_ptr<Context> ctx = std::make_shared<Context>(futures.size());

  for (size_t i = 0; i < futures.size(); ++i) {
    futures[i]
        .Via(task_runner)
        .Then(base::BindOnce(
            [](std::shared_ptr<Context> ctx, size_t i, T val) {
              ctx->values[i] = std::move(val);
              return Resolve<void>();
            },
            ctx, i))
        .OnReject(base::BindOnce(
            [](std::shared_ptr<Context> ctx, Error e) {
              ctx->Reject(std::move(e));
              // Whatever as the future returned by this OnReject is not used
              return Resolve<void>();
            },
            ctx));
  }

  return ctx->promise.GetFuture(task_runner);
}

template <typename T>
Future<std::vector<T>, void> Collect(
    scoped_refptr<base::TaskRunner> task_runner,
    std::vector<Future<T, void>> futures) {
  struct Context {
    explicit Context(size_t n) { values.resize(n); }
    void Reject() {
      if (!rejected.fetch_or(1)) {  // only reject once
        promise.Reject();
      }
    }
    ~Context() {
      if (!rejected.load()) {
        promise.SetValue(std::move(values));
      }
    }
    std::atomic<uint8_t> rejected{0};
    Promise<std::vector<T>> promise;
    std::vector<T> values;
  };

  std::shared_ptr<Context> ctx = std::make_shared<Context>(futures.size());

  for (size_t i = 0; i < futures.size(); ++i) {
    futures[i]
        .Via(task_runner)
        .Then(base::BindOnce(
            [](std::shared_ptr<Context> ctx, size_t i, T val) {
              ctx->values[i] = std::move(val);
              return Resolve<void>();
            },
            ctx, i))
        .OnReject(base::BindOnce(
            [](std::shared_ptr<Context> ctx) {
              ctx->Reject();
              // Whatever as the future returned by this OnReject is not used
              return Resolve<void>();
            },
            ctx));
  }

  return ctx->promise.GetFuture(task_runner);
}

template <typename T, typename Error>
std::enable_if_t<!std::is_void<T>::value, Future<T, Error>> ResolvedFuture(
    T val, scoped_refptr<base::TaskRunner> task_runner) {
  Promise<T, Error> promise;
  Future<T, Error> future = promise.GetFuture(task_runner);
  promise.SetValue(std::move(val));
  return future;
}

template <typename T, typename Error>
std::enable_if_t<std::is_void<T>::value, Future<T, Error>> ResolvedFuture(
    scoped_refptr<base::TaskRunner> task_runner) {
  Promise<void, Error> promise;
  Future<void, Error> future = promise.GetFuture(task_runner);
  promise.SetValue();
  return future;
}

}  // namespace vm_tools

#endif  // VM_TOOLS_CONCIERGE_FUTURE_H_
