blob: 58345b865a679883a04c49591cc6f299d58ee79e [file] [log] [blame] [edit]
// Copyright 2022 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef LIBHWSEC_MIDDLEWARE_MIDDLEWARE_H_
#define LIBHWSEC_MIDDLEWARE_MIDDLEWARE_H_
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>
#include <absl/base/attributes.h>
#include <base/functional/callback.h>
#include <base/functional/callback_helpers.h>
#include <base/logging.h>
#include <base/memory/scoped_refptr.h>
#include <base/memory/weak_ptr.h>
#include <base/task/bind_post_task.h>
#include <base/task/sequenced_task_runner.h>
#include <base/task/single_thread_task_runner.h>
#include <base/task/task_runner.h>
#include <base/threading/thread.h>
#include "libhwsec/backend/backend.h"
#include "libhwsec/error/tpm_retry_handler.h"
#include "libhwsec/middleware/function_name.h"
#include "libhwsec/middleware/middleware_derivative.h"
#include "libhwsec/middleware/middleware_owner.h"
#include "libhwsec/proxy/proxy.h"
#include "libhwsec/status.h"
#ifndef BUILD_LIBHWSEC
#error "Don't include this file outside libhwsec!"
#endif
// Middleware can be shared by multiple frontends.
// Converts asynchronous and synchronous calls to the backend.
// And doing some generic error handling, for example: communication error and
// auto reload key & session.
//
// Note: The middleware can maintain a standalone thread, or use the same task
// runner as the caller side.
//
// Note2: The move-only function parameters would not be copied, the other kinds
// of function parameters would be copied due to base::BindOnce.
namespace hwsec {
class Middleware {
public:
explicit Middleware(MiddlewareDerivative middleware_derivative)
: middleware_derivative_(middleware_derivative) {}
MiddlewareDerivative Derive() const { return middleware_derivative_; }
// Call the backend function synchronously.
template <auto Func, typename... Args>
auto CallSync(Args&&... args) const {
if constexpr (SubClassHelper<decltype(Func)>::type == CallType::kSync) {
// Calling sync backend function.
auto task = base::BindOnce(
&Middleware::DoSyncBackendCall<Func, decltype(ForwareParameter(
std::declval<Args>()))...>,
middleware_derivative_.middleware,
ForwareParameter(std::forward<Args>(args))...);
return RunBlockingTask(std::move(task));
} else if constexpr (SubClassHelper<decltype(Func)>::type ==
CallType::kAsync) {
// Calling async backend function.
using hwsec_foundation::status::MakeStatus;
using Result = SubClassResult<decltype(Func)>;
using Callback = SubClassCallback<decltype(Func)>;
base::WaitableEvent event(
base::WaitableEvent::ResetPolicy::MANUAL,
base::WaitableEvent::InitialState::NOT_SIGNALED);
Result result =
MakeStatus<TPMError>("Unknown error", TPMRetryAction::kNoRetry);
Callback callback =
base::BindOnce([](Result* result_ptr,
Result value) { *result_ptr = std::move(value); },
&result)
.Then(base::BindOnce(&base::WaitableEvent::Signal,
base::Unretained(&event)));
base::OnceClosure task = base::BindOnce(
&Middleware::DoAsyncBackendCall<Func, decltype(ForwareParameter(
std::declval<Args>()))...>,
middleware_derivative_.middleware, std::move(callback),
ForwareParameter(std::forward<Args>(args))...);
middleware_derivative_.task_runner->PostTask(FROM_HERE, std::move(task));
event.Wait();
return result;
} else {
static_assert(always_false_v<decltype(Func)>, "Unsupported function!");
}
}
// Call the backend function asynchronously.
template <auto Func, typename Callback, typename... Args>
void CallAsync(Callback callback, Args&&... args) const {
CHECK(middleware_derivative_.task_runner);
SubClassCallback<decltype(Func)> reply = std::move(callback);
reply = base::BindPostTask(GetReplyRunner(), std::move(reply));
base::OnceClosure task = base::BindOnce(
&Middleware::CallAsyncInternal<Func, decltype(ForwareParameter(
std::declval<Args>()))...>,
middleware_derivative_.middleware, std::move(reply),
ForwareParameter(std::forward<Args>(args))...);
middleware_derivative_.task_runner->PostTask(FROM_HERE, std::move(task));
}
// Run a blocking task in the middleware.
template <typename Result>
Result RunBlockingTask(base::OnceCallback<Result()> task) const {
if (middleware_derivative_.thread_id == base::PlatformThread::CurrentId()) {
return std::move(task).Run();
}
CHECK(middleware_derivative_.task_runner);
base::WaitableEvent event(base::WaitableEvent::ResetPolicy::MANUAL,
base::WaitableEvent::InitialState::NOT_SIGNALED);
if constexpr (std::is_same_v<void, Result>) {
base::OnceClosure closure = std::move(task).Then(base::BindOnce(
&base::WaitableEvent::Signal, base::Unretained(&event)));
middleware_derivative_.task_runner->PostTask(FROM_HERE,
std::move(closure));
event.Wait();
return;
} else if constexpr (std::is_convertible_v<Status, Result>) {
using hwsec_foundation::status::MakeStatus;
Result result =
MakeStatus<TPMError>("Unknown error", TPMRetryAction::kNoRetry);
base::OnceClosure closure =
std::move(task)
.Then(base::BindOnce(
[](Result* result_ptr, Result value) {
*result_ptr = std::move(value);
},
&result))
.Then(base::BindOnce(&base::WaitableEvent::Signal,
base::Unretained(&event)));
middleware_derivative_.task_runner->PostTask(FROM_HERE,
std::move(closure));
event.Wait();
return result;
} else {
static_assert(always_false_v<Result>, "Unsupported blocking task!");
}
}
private:
// Type of the backend call.
enum class CallType {
// Synchronous backend call, and the function signature will be:
// Result SubClass::Function(Args...);
kSync,
// Asynchronous backend call, and the function signature will be:
// void SubClass::Function(base::OnceCallback<void(Result)>. Args...);
kAsync,
};
template <typename Func, typename = void>
struct SubClassHelper {
static_assert(sizeof(Func) == -1, "Unknown member function");
};
// SubClass helper for the synchronous backend call.
template <typename R, typename S, typename... Args>
struct SubClassHelper<R (S::*)(Args...),
std::enable_if_t<std::is_convertible_v<Status, R>>> {
inline constexpr static CallType type = CallType::kSync;
using Result = R;
using SubClass = S;
using Callback = base::OnceCallback<void(R)>;
};
// SubClass helper for the asynchronous backend call.
template <typename R, typename S, typename... Args>
struct SubClassHelper<void (S::*)(base::OnceCallback<void(R)>, Args...),
std::enable_if_t<std::is_convertible_v<Status, R>>> {
inline constexpr static CallType type = CallType::kAsync;
using Result = R;
using SubClass = S;
using Callback = base::OnceCallback<void(R)>;
};
template <typename Func>
using SubClassResult = typename SubClassHelper<Func>::Result;
template <typename Func>
using SubClassType = typename SubClassHelper<Func>::SubClass;
template <typename Func>
using SubClassCallback = typename SubClassHelper<Func>::Callback;
template <typename>
inline static constexpr bool always_false_v = false;
// The custom parameter forwarding rules.
template <typename T>
static T ForwareParameter(T&& t) {
// The rvalue should still be rvalue, because we have the ownership.
return t;
}
template <typename T>
static const T& ForwareParameter(T& t) {
// Add const for normal reference, because we don't have the ownership.
// base::BindOnce will copy const reference parameter when binding.
return t;
}
template <typename T>
static const T& ForwareParameter(const T& t) {
// The const reference would still be const reference.
// base::BindOnce will copy const reference parameter when binding.
return t;
}
// Get the quick result that is not related to the function itself.
template <auto Func>
static std::variant<SubClassResult<decltype(Func)>,
SubClassType<decltype(Func)>*>
GetQuickResult(base::WeakPtr<MiddlewareOwner> middleware) {
using hwsec_foundation::status::MakeStatus;
if (!middleware) {
return MakeStatus<TPMError>("No middleware", TPMRetryAction::kNoRetry);
}
#if USE_FUZZER
if (middleware->data_provider_) {
return FuzzedObject<SubClassResult<decltype(Func)>>()(
*middleware->data_provider_);
}
#endif
if (!middleware->backend_) {
return MakeStatus<TPMError>("No backend", TPMRetryAction::kNoRetry);
}
auto* sub = middleware->backend_->Get<SubClassType<decltype(Func)>>();
if (!sub) {
return MakeStatus<TPMError>("No sub class in backend",
TPMRetryAction::kNoRetry);
}
return sub;
}
// Call the synchronous backend call.
template <auto Func, typename... Args>
static SubClassResult<decltype(Func)> DoSyncBackendCall(
base::WeakPtr<MiddlewareOwner> middleware, Args... args) {
using Result = SubClassResult<decltype(Func)>;
using Type = SubClassType<decltype(Func)>;
std::variant<Result, Type*> quick_result = GetQuickResult<Func>(middleware);
if (Result* result = std::get_if<Result>(&quick_result)) {
return std::move(*result);
}
Type* sub = *std::get_if<Type*>(&quick_result);
for (TPMRetryHandler retry_handler;;) {
SubClassResult<decltype(Func)> result = (sub->*Func)(args...);
if (middleware->metrics_) {
middleware->metrics_->SendFuncResultToUMA(GetFuncName<Func>(),
result.status());
}
if (retry_handler.HandleResult(result, *middleware->backend_, args...)) {
return result;
}
}
}
// Call the asynchronous backend call.
template <auto Func, typename... Args>
static void DoAsyncBackendCall(base::WeakPtr<MiddlewareOwner> middleware,
SubClassCallback<decltype(Func)> callback,
Args... args) {
auto retry_handler = std::make_unique<TPMRetryHandler>();
// Using the decay type to make sure we are not putting dangling reference
// in the tuple.
auto args_tuple =
std::make_unique<std::tuple<std::decay_t<Args>...>>(std::move(args)...);
DoAsyncBackendCallInternal<Func>(
std::move(middleware), std::move(retry_handler), std::move(callback),
std::move(args_tuple), std::make_index_sequence<sizeof...(Args)>());
}
template <auto Func, typename ArgsTuple, std::size_t... I>
static void DoAsyncBackendCallInternal(
base::WeakPtr<MiddlewareOwner> middleware,
std::unique_ptr<TPMRetryHandler> retry_handler,
SubClassCallback<decltype(Func)> callback,
std::unique_ptr<ArgsTuple> args,
std::index_sequence<I...> idx_seq) {
using Result = SubClassResult<decltype(Func)>;
using Type = SubClassType<decltype(Func)>;
using Callback = SubClassCallback<decltype(Func)>;
std::variant<Result, Type*> quick_result = GetQuickResult<Func>(middleware);
if (Result* result = std::get_if<Result>(&quick_result)) {
std::move(callback).Run(std::move(*result));
return;
}
Type* sub = *std::get_if<Type*>(&quick_result);
// Note: The args tuple will be owned by the retry callback.
// We will transfer the ownership of the retry callback into the backend
// function, so the backend functions should be careful about not using the
// args after call or drop the callback.
ArgsTuple& args_ref = *args;
Callback retry_callback =
base::BindOnce(&HandleAsyncBackendCallRetry<Func, ArgsTuple, I...>,
std::move(middleware), std::move(retry_handler),
std::move(callback), std::move(args), idx_seq);
(sub->*Func)(std::move(retry_callback), std::get<I>(args_ref)...);
}
template <auto Func, typename ArgsTuple, std::size_t... I>
static void HandleAsyncBackendCallRetry(
base::WeakPtr<MiddlewareOwner> middleware,
std::unique_ptr<TPMRetryHandler> retry_handler,
SubClassCallback<decltype(Func)> callback,
std::unique_ptr<ArgsTuple> args,
std::index_sequence<I...> idx_seq,
SubClassResult<decltype(Func)> result) {
using hwsec_foundation::status::MakeStatus;
if (!middleware) {
std::move(callback).Run(
MakeStatus<TPMError>("No middleware", TPMRetryAction::kNoRetry));
return;
}
if (middleware->metrics_) {
middleware->metrics_->SendFuncResultToUMA(GetFuncName<Func>(),
result.status());
}
if (retry_handler->HandleResult(result, *middleware->backend_,
std::get<I>(*args)...)) {
std::move(callback).Run(std::move(result));
return;
}
DoAsyncBackendCallInternal<Func>(
std::move(middleware), std::move(retry_handler), std::move(callback),
std::move(args), idx_seq);
}
template <auto Func, typename... Args>
static void CallAsyncInternal(base::WeakPtr<MiddlewareOwner> middleware,
SubClassCallback<decltype(Func)> callback,
Args... args) {
if constexpr (SubClassHelper<decltype(Func)>::type == CallType::kSync) {
// Calling sync backend function.
std::move(callback).Run(DoSyncBackendCall<Func, Args...>(
std::move(middleware), ForwareParameter(std::move(args))...));
} else if constexpr (SubClassHelper<decltype(Func)>::type ==
CallType::kAsync) {
// Calling async backend function.
Middleware::DoAsyncBackendCall<Func, Args...>(
std::move(middleware), std::move(callback),
ForwareParameter(std::move(args))...);
} else {
static_assert(always_false_v<decltype(Func)>, "Unsupported function!");
}
}
static scoped_refptr<base::TaskRunner> GetReplyRunner() {
CHECK(base::SequencedTaskRunner::HasCurrentDefault());
return base::SequencedTaskRunner::GetCurrentDefault();
}
MiddlewareDerivative middleware_derivative_;
};
} // namespace hwsec
#endif // LIBHWSEC_MIDDLEWARE_MIDDLEWARE_H_