| // Copyright 2022 The ChromiumOS Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #ifndef LIBHWSEC_MIDDLEWARE_MIDDLEWARE_H_ |
| #define LIBHWSEC_MIDDLEWARE_MIDDLEWARE_H_ |
| |
| #include <concepts> |
| #include <memory> |
| #include <string> |
| #include <tuple> |
| #include <type_traits> |
| #include <utility> |
| #include <variant> |
| |
| #include <absl/base/attributes.h> |
| #include <base/functional/callback.h> |
| #include <base/functional/callback_helpers.h> |
| #include <base/logging.h> |
| #include <base/memory/scoped_refptr.h> |
| #include <base/memory/weak_ptr.h> |
| #include <base/strings/stringprintf.h> |
| #include <base/task/bind_post_task.h> |
| #include <base/task/sequenced_task_runner.h> |
| #include <base/task/single_thread_task_runner.h> |
| #include <base/task/task_runner.h> |
| #include <base/threading/thread.h> |
| |
| #include "libhwsec/backend/backend.h" |
| #include "libhwsec/error/tpm_retry_action.h" |
| #include "libhwsec/error/tpm_retry_handler.h" |
| #include "libhwsec/middleware/function_name.h" |
| #include "libhwsec/middleware/middleware_derivative.h" |
| #include "libhwsec/middleware/middleware_owner.h" |
| #include "libhwsec/middleware/subclass_helper.h" |
| #include "libhwsec/proxy/proxy.h" |
| #include "libhwsec/status.h" |
| |
| #ifndef BUILD_LIBHWSEC |
| #error "Don't include this file outside libhwsec!" |
| #endif |
| |
| // Middleware can be shared by multiple frontends. |
| // Converts asynchronous and synchronous calls to the backend. |
| // And doing some generic error handling, for example: communication error and |
| // auto reload key & session. |
| // |
| // Note: The middleware can maintain a standalone thread, or use the same task |
| // runner as the caller side. |
| // |
| // Note2: The move-only function parameters would not be copied, the other kinds |
| // of function parameters would be copied due to base::BindOnce. |
| |
| namespace hwsec { |
| |
| class Middleware { |
| public: |
| explicit Middleware(MiddlewareDerivative middleware_derivative) |
| : middleware_derivative_(middleware_derivative) {} |
| |
| MiddlewareDerivative Derive() const { return middleware_derivative_; } |
| |
| // Call the backend function synchronously. |
| template <auto Func, typename... Args> |
| requires(BackendMethod<decltype(Func)> && |
| ValidBackendMethodArgs<decltype(Func), Args...>) |
| auto CallSync(Args&&... args) const { |
| return CallSyncInternal<Func>(std::forward<Args>(args)...); |
| } |
| |
| // Call the backend function asynchronously. |
| template <auto Func, typename Callback, typename... Args> |
| requires(BackendMethod<decltype(Func)> && |
| ValidBackendMethodArgs<decltype(Func), Args...>) |
| void CallAsync(Callback callback, Args&&... args) const { |
| CHECK(middleware_derivative_.task_runner); |
| |
| SubClassCallback<decltype(Func)> reply = std::move(callback); |
| reply = base::BindPostTask(GetReplyRunner(), std::move(reply)); |
| base::OnceClosure task = base::BindOnce( |
| &Middleware::CallAsyncInternal<Func, decltype(ForwardParameter( |
| std::declval<Args>()))...>, |
| middleware_derivative_.middleware, std::move(reply), |
| ForwardParameter(std::forward<Args>(args))...); |
| middleware_derivative_.task_runner->PostTask(FROM_HERE, std::move(task)); |
| } |
| |
| // Run a blocking task without return value in the middleware. |
| void RunBlockingTask(base::OnceCallback<void()> task) const { |
| if (middleware_derivative_.thread_id == base::PlatformThread::CurrentId()) { |
| return std::move(task).Run(); |
| } |
| |
| CHECK(middleware_derivative_.task_runner); |
| |
| base::WaitableEvent event(base::WaitableEvent::ResetPolicy::MANUAL, |
| base::WaitableEvent::InitialState::NOT_SIGNALED); |
| |
| base::OnceClosure closure = std::move(task).Then( |
| base::BindOnce(&base::WaitableEvent::Signal, base::Unretained(&event))); |
| middleware_derivative_.task_runner->PostTask(FROM_HERE, std::move(closure)); |
| event.Wait(); |
| return; |
| } |
| |
| // Run a blocking task with return value in the middleware. |
| template <typename Result> |
| requires(std::constructible_from<Result, Status>) |
| Result RunBlockingTask(base::OnceCallback<Result()> task) const { |
| if (middleware_derivative_.thread_id == base::PlatformThread::CurrentId()) { |
| return std::move(task).Run(); |
| } |
| |
| CHECK(middleware_derivative_.task_runner); |
| |
| base::WaitableEvent event(base::WaitableEvent::ResetPolicy::MANUAL, |
| base::WaitableEvent::InitialState::NOT_SIGNALED); |
| |
| using hwsec_foundation::status::MakeStatus; |
| Result result = |
| MakeStatus<TPMError>("Unknown error", TPMRetryAction::kNoRetry); |
| base::OnceClosure closure = |
| std::move(task) |
| .Then(base::BindOnce( |
| [](Result* result_ptr, Result value) { |
| *result_ptr = std::move(value); |
| }, |
| &result)) |
| .Then(base::BindOnce(&base::WaitableEvent::Signal, |
| base::Unretained(&event))); |
| middleware_derivative_.task_runner->PostTask(FROM_HERE, std::move(closure)); |
| event.Wait(); |
| return result; |
| } |
| |
| private: |
| // Get the quick result that is not related to the function itself. |
| template <auto Func> |
| requires(BackendMethod<decltype(Func)>) |
| static std::variant<SubClassResult<decltype(Func)>, |
| SubClassType<decltype(Func)>*> |
| GetQuickResult(base::WeakPtr<MiddlewareOwner> middleware) { |
| using hwsec_foundation::status::MakeStatus; |
| |
| if (!middleware) { |
| return MakeStatus<TPMError>("No middleware", TPMRetryAction::kNoRetry); |
| } |
| |
| #if USE_FUZZER |
| if (middleware->data_provider_) { |
| return FuzzedObject<SubClassResult<decltype(Func)>>()( |
| *middleware->data_provider_); |
| } |
| #endif |
| |
| if (!middleware->GetBackend()) { |
| return MakeStatus<TPMError>("No backend", TPMRetryAction::kNoRetry); |
| } |
| |
| auto* sub = middleware->GetBackend()->Get<SubClassType<decltype(Func)>>(); |
| if (!sub) { |
| return MakeStatus<TPMError>("No sub class in backend", |
| TPMRetryAction::kNoRetry); |
| } |
| |
| return sub; |
| } |
| |
| // Call the synchronous backend call. |
| template <auto Func, typename... Args> |
| requires(SyncBackendMethod<decltype(Func)>) |
| static SubClassResult<decltype(Func)> DoSyncBackendCall( |
| base::WeakPtr<MiddlewareOwner> middleware, Args... args) { |
| using Result = SubClassResult<decltype(Func)>; |
| using Type = SubClassType<decltype(Func)>; |
| |
| std::variant<Result, Type*> quick_result = GetQuickResult<Func>(middleware); |
| if (Result* result = std::get_if<Result>(&quick_result)) { |
| return std::move(*result); |
| } |
| |
| Type* sub = *std::get_if<Type*>(&quick_result); |
| |
| for (TPMRetryHandler retry_handler;;) { |
| SubClassResult<decltype(Func)> result = (sub->*Func)(args...); |
| |
| TrackFuncResult(GetFuncName<Func>(), middleware->GetMetrics(), result); |
| |
| if (retry_handler.HandleResult(result, *middleware->GetBackend(), |
| middleware->GetMetrics(), args...)) { |
| return result; |
| } |
| } |
| } |
| |
| // Call the asynchronous backend call. |
| template <auto Func, typename... Args> |
| requires(AsyncBackendMethod<decltype(Func)>) |
| static void DoAsyncBackendCall(base::WeakPtr<MiddlewareOwner> middleware, |
| SubClassCallback<decltype(Func)> callback, |
| Args... args) { |
| auto retry_handler = std::make_unique<TPMRetryHandler>(); |
| |
| // Using the decay type to make sure we are not putting dangling reference |
| // in the tuple. |
| auto args_tuple = |
| std::make_unique<std::tuple<std::decay_t<Args>...>>(std::move(args)...); |
| |
| DoAsyncBackendCallInternal<Func>( |
| std::move(middleware), std::move(retry_handler), std::move(callback), |
| std::move(args_tuple), std::make_index_sequence<sizeof...(Args)>()); |
| } |
| |
| template <auto Func, typename ArgsTuple, std::size_t... I> |
| requires(AsyncBackendMethod<decltype(Func)>) |
| static void DoAsyncBackendCallInternal( |
| base::WeakPtr<MiddlewareOwner> middleware, |
| std::unique_ptr<TPMRetryHandler> retry_handler, |
| SubClassCallback<decltype(Func)> callback, |
| std::unique_ptr<ArgsTuple> args, |
| std::index_sequence<I...> idx_seq) { |
| using Result = SubClassResult<decltype(Func)>; |
| using Type = SubClassType<decltype(Func)>; |
| using Callback = SubClassCallback<decltype(Func)>; |
| |
| std::variant<Result, Type*> quick_result = GetQuickResult<Func>(middleware); |
| if (Result* result = std::get_if<Result>(&quick_result)) { |
| std::move(callback).Run(std::move(*result)); |
| return; |
| } |
| |
| Type* sub = *std::get_if<Type*>(&quick_result); |
| |
| // Note: The args tuple will be owned by the retry callback. |
| // We will transfer the ownership of the retry callback into the backend |
| // function, so the backend functions should be careful about not using the |
| // args after call or drop the callback. |
| ArgsTuple& args_ref = *args; |
| |
| Callback retry_callback = |
| base::BindOnce(&HandleAsyncBackendCallRetry<Func, ArgsTuple, I...>, |
| std::move(middleware), std::move(retry_handler), |
| std::move(callback), std::move(args), idx_seq); |
| |
| (sub->*Func)(std::move(retry_callback), std::get<I>(args_ref)...); |
| } |
| |
| template <auto Func, typename ArgsTuple, std::size_t... I> |
| requires(AsyncBackendMethod<decltype(Func)>) |
| static void HandleAsyncBackendCallRetry( |
| base::WeakPtr<MiddlewareOwner> middleware, |
| std::unique_ptr<TPMRetryHandler> retry_handler, |
| SubClassCallback<decltype(Func)> callback, |
| std::unique_ptr<ArgsTuple> args, |
| std::index_sequence<I...> idx_seq, |
| SubClassResult<decltype(Func)> result) { |
| using hwsec_foundation::status::MakeStatus; |
| |
| if (!middleware) { |
| std::move(callback).Run( |
| MakeStatus<TPMError>("No middleware", TPMRetryAction::kNoRetry)); |
| return; |
| } |
| |
| TrackFuncResult(GetFuncName<Func>(), middleware->GetMetrics(), result); |
| |
| if (retry_handler->HandleResult(result, *middleware->GetBackend(), |
| middleware->GetMetrics(), |
| std::get<I>(*args)...)) { |
| std::move(callback).Run(std::move(result)); |
| return; |
| } |
| |
| DoAsyncBackendCallInternal<Func>( |
| std::move(middleware), std::move(retry_handler), std::move(callback), |
| std::move(args), idx_seq); |
| } |
| |
| // Call the synchronous backend function synchronously. |
| template <auto Func, typename... Args> |
| requires(SyncBackendMethod<decltype(Func)>) |
| auto CallSyncInternal(Args&&... args) const { |
| // Calling sync backend function. |
| auto task = base::BindOnce( |
| &Middleware::DoSyncBackendCall<Func, decltype(ForwardParameter( |
| std::declval<Args>()))...>, |
| middleware_derivative_.middleware, |
| ForwardParameter(std::forward<Args>(args))...); |
| return RunBlockingTask(std::move(task)); |
| } |
| |
| // Call the asynchronous backend function synchronously. |
| template <auto Func, typename... Args> |
| requires(AsyncBackendMethod<decltype(Func)>) |
| auto CallSyncInternal(Args&&... args) const { |
| // Calling async backend function. |
| using hwsec_foundation::status::MakeStatus; |
| using Result = SubClassResult<decltype(Func)>; |
| using Callback = SubClassCallback<decltype(Func)>; |
| |
| base::WaitableEvent event(base::WaitableEvent::ResetPolicy::MANUAL, |
| base::WaitableEvent::InitialState::NOT_SIGNALED); |
| |
| Result result = |
| MakeStatus<TPMError>("Unknown error", TPMRetryAction::kNoRetry); |
| Callback callback = |
| base::BindOnce([](Result* result_ptr, |
| Result value) { *result_ptr = std::move(value); }, |
| &result) |
| .Then(base::BindOnce(&base::WaitableEvent::Signal, |
| base::Unretained(&event))); |
| |
| base::OnceClosure task = base::BindOnce( |
| &Middleware::DoAsyncBackendCall<Func, decltype(ForwardParameter( |
| std::declval<Args>()))...>, |
| middleware_derivative_.middleware, std::move(callback), |
| ForwardParameter(std::forward<Args>(args))...); |
| |
| middleware_derivative_.task_runner->PostTask(FROM_HERE, std::move(task)); |
| event.Wait(); |
| return result; |
| } |
| |
| // Calling synchronous backend function asynchronously. |
| template <auto Func, typename... Args> |
| requires(SyncBackendMethod<decltype(Func)>) |
| static void CallAsyncInternal(base::WeakPtr<MiddlewareOwner> middleware, |
| SubClassCallback<decltype(Func)> callback, |
| Args... args) { |
| std::move(callback).Run(DoSyncBackendCall<Func, Args...>( |
| std::move(middleware), ForwardParameter(std::move(args))...)); |
| } |
| |
| // Calling asynchronous backend function asynchronously. |
| template <auto Func, typename... Args> |
| requires(AsyncBackendMethod<decltype(Func)>) |
| static void CallAsyncInternal(base::WeakPtr<MiddlewareOwner> middleware, |
| SubClassCallback<decltype(Func)> callback, |
| Args... args) { |
| Middleware::DoAsyncBackendCall<Func, Args...>( |
| std::move(middleware), std::move(callback), |
| ForwardParameter(std::move(args))...); |
| } |
| |
| template <typename Result> |
| requires(std::constructible_from<Result, Status>) |
| static void TrackFuncResult(const std::string& function_name, |
| Metrics* metrics, |
| Result& result) { |
| using hwsec_foundation::status::MakeStatus; |
| |
| std::string sim_name = SimplifyFuncName(function_name); |
| |
| if (metrics) { |
| metrics->SendFuncResultToUMA(sim_name, result.status()); |
| } |
| |
| if (!result.ok()) { |
| Status status = std::move(result).err_status(); |
| TPMRetryAction action = status->ToTPMRetryAction(); |
| status = MakeStatus<TPMError>( |
| base::StringPrintf("%s(%s)", sim_name.c_str(), |
| GetTPMRetryActionName(action)), |
| action) |
| .Wrap(std::move(status)); |
| result = std::move(status); |
| } |
| } |
| |
| static scoped_refptr<base::TaskRunner> GetReplyRunner() { |
| CHECK(base::SequencedTaskRunner::HasCurrentDefault()); |
| return base::SequencedTaskRunner::GetCurrentDefault(); |
| } |
| |
| MiddlewareDerivative middleware_derivative_; |
| }; |
| |
| } // namespace hwsec |
| |
| #endif // LIBHWSEC_MIDDLEWARE_MIDDLEWARE_H_ |