blob: 0da57dbd7563a140a33afeee1af44aa603137552 [file] [log] [blame]
// Copyright 2022 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "faced/enrollment_session.h"
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
#include <absl/status/status.h>
#include <base/functional/bind.h>
#include <base/location.h>
#include <base/task/sequenced_task_runner.h>
#include "faced/common/face_status.h"
#include "faced/util/task.h"
namespace faced {
using ::chromeos::faceauth::mojom::EnrollmentCompleteMessage;
using ::chromeos::faceauth::mojom::EnrollmentCompleteMessagePtr;
using ::chromeos::faceauth::mojom::EnrollmentSessionConfigPtr;
using ::chromeos::faceauth::mojom::EnrollmentUpdateMessage;
using ::chromeos::faceauth::mojom::EnrollmentUpdateMessagePtr;
using ::chromeos::faceauth::mojom::FaceEnrollmentSessionDelegate;
using ::chromeos::faceauth::mojom::FaceOperationStatus;
using ::chromeos::faceauth::mojom::SessionError;
using ::faceauth::eora::AbortEnrollmentRequest;
using ::faceauth::eora::AbortEnrollmentResponse;
using ::faceauth::eora::CompleteEnrollmentRequest;
using ::faceauth::eora::CompleteEnrollmentResponse;
using ::faceauth::eora::ProcessFrameForEnrollmentRequest;
using ::faceauth::eora::ProcessFrameForEnrollmentResponse;
using ::faceauth::eora::StartEnrollmentRequest;
using ::faceauth::eora::StartEnrollmentResponse;
namespace {
// Translates a `Frame::Format` into the corresponding Eora frame type.
// Any format other than MJPEG or YUV NV12 maps to `UNKNOWN`.
faceauth::eora::FrameType FormatToFrameType(Frame::Format format) {
  if (format == Frame::Format::kMjpeg) {
    return faceauth::eora::FrameType::MJPG;
  }
  if (format == Frame::Format::kYuvNv12) {
    return faceauth::eora::FrameType::YUV_NV12;
  }
  return faceauth::eora::FrameType::UNKNOWN;
}
// Builds the `CameraFrame` proto carried in frame-processing requests.
//
// The dimensions and format are copied from `frame`; the pixel data is moved
// out of `frame` into the proto payload. `frame` is dereferenced without a
// null check.
faceauth::eora::CameraFrame FrameToCameraFrame(std::unique_ptr<Frame> frame) {
  faceauth::eora::CameraFrame camera_frame;
  camera_frame.set_type(FormatToFrameType(frame->format));
  camera_frame.set_width(frame->width);
  camera_frame.set_height(frame->height);
  *camera_frame.mutable_payload() = std::move(frame->data);
  return camera_frame;
}
} // namespace
// Factory for `EnrollmentSession`.
//
// Generates a session id from `bitgen`, constructs the session from the mojo
// endpoints, the gRPC client lease and the frame stream reader, and installs
// `TryCancel` as the disconnect handler on both mojo endpoints. `config` is
// accepted but not read by this function.
//
// The disconnect handlers hold an unretained pointer to the returned session.
absl::StatusOr<std::unique_ptr<EnrollmentSession>> EnrollmentSession::Create(
    absl::BitGen& bitgen,
    mojo::PendingReceiver<FaceEnrollmentSession> receiver,
    mojo::PendingRemote<FaceEnrollmentSessionDelegate> delegate,
    EnrollmentSessionConfigPtr config,
    Lease<brillo::AsyncGrpcClient<faceauth::eora::FaceService>> client,
    std::unique_ptr<StreamReader<InputFrame>> stream_reader) {
  const uint64_t new_session_id = GenerateSessionId(bitgen);

  // The constructor is private, so the object is allocated with `new` and
  // handed straight to a `std::unique_ptr`.
  std::unique_ptr<EnrollmentSession> created(new EnrollmentSession(
      new_session_id, std::move(receiver), std::move(delegate),
      std::move(client), std::move(stream_reader)));
  EnrollmentSession* const raw_session = created.get();

  // Losing either end of the mojo connection requests cancellation.
  raw_session->delegate_.set_disconnect_handler(base::BindOnce(
      &EnrollmentSession::TryCancel, base::Unretained(raw_session)));
  raw_session->receiver_.set_disconnect_handler(base::BindOnce(
      &EnrollmentSession::TryCancel, base::Unretained(raw_session)));

  return created;
}
// Constructs a session with the given `session_id`.
//
// `receiver_` is constructed with `this` as its implementation together with
// the pending `receiver`; the delegate remote, the gRPC client lease and the
// frame stream reader are moved into their respective members. No disconnect
// handlers are installed here; `Create()` does that after construction.
EnrollmentSession::EnrollmentSession(
uint64_t session_id,
mojo::PendingReceiver<FaceEnrollmentSession> receiver,
mojo::PendingRemote<FaceEnrollmentSessionDelegate> delegate,
Lease<brillo::AsyncGrpcClient<faceauth::eora::FaceService>> client,
std::unique_ptr<StreamReader<InputFrame>> stream_reader)
: session_id_(session_id),
receiver_(this, std::move(receiver)),
delegate_(std::move(delegate)),
rpc_client_(std::move(client)),
stream_reader_(std::move(stream_reader)) {}
// Sends an enrollment progress update carrying `status` to the delegate.
// The pose list in the message is always empty.
void EnrollmentSession::NotifyUpdate(FaceOperationStatus status) {
  delegate_->OnEnrollmentUpdate(
      EnrollmentUpdateMessage::New(status, /*poses=*/std::vector<bool>()));
}
void EnrollmentSession::NotifyComplete() {
delegate_->OnEnrollmentComplete(EnrollmentCompleteMessage::New());
FinishSession(absl::OkStatus());
}
void EnrollmentSession::NotifyCancelled() {
delegate_->OnEnrollmentCancelled();
FinishSession(absl::CancelledError());
}
// Reports a failure to the delegate, then finishes the session with `error`.
//
// The delegate currently always receives `SessionError::UNKNOWN`, whatever
// `error` contains.
void EnrollmentSession::NotifyError(absl::Status error) {
  // TODO(bkersten): map absl::Status to SessionError
  delegate_->OnEnrollmentError(SessionError::UNKNOWN);
  FinishSession(error);
}
// Starts the enrollment session by issuing the StartEnrollment RPC.
//
// `completion_callback` is stored and later posted by `FinishSession()`.
// `start_callback` is forwarded to `CompleteStartEnrollment()`, which runs
// when the RPC replies.
void EnrollmentSession::Start(StartCallback start_callback,
                              CompletionCallback completion_callback) {
  completion_callback_ = std::move(completion_callback);

  auto on_start_reply =
      base::BindOnce(&EnrollmentSession::CompleteStartEnrollment,
                     base::Unretained(this), std::move(start_callback));
  StartEnrollmentRequest request;
  (*rpc_client_)
      ->CallRpc(&faceauth::eora::FaceService::Stub::AsyncStartEnrollment,
                request, std::move(on_start_reply));
}
// Reply handler for the StartEnrollment RPC.
//
// On a gRPC failure, or when the response carries a non-OK status, the
// session is finished with that error; `callback` is not run and the delegate
// is not notified. On success, `callback` is posted to the current sequence
// and the first frame read is issued.
void EnrollmentSession::CompleteStartEnrollment(
    StartCallback callback,
    grpc::Status status,
    std::unique_ptr<StartEnrollmentResponse> response) {
  // Transport-level failure of the RPC itself.
  if (!status.ok()) {
    FinishSession(absl::UnavailableError(status.error_message()));
    return;
  }

  // Failure reported by the service inside the response.
  if (absl::Status start_status = ToAbslStatus(response->status());
      !start_status.ok()) {
    FinishSession(start_status);
    return;
  }

  PostToCurrentSequence(std::move(callback));

  // Kick off the frame loop: each frame read is handed to
  // ProcessAvailableFrame(), whose RPC reply lands in CompleteProcessFrame().
  auto on_frame_processed = base::BindOnce(
      &EnrollmentSession::CompleteProcessFrame, base::Unretained(this));
  auto on_frame_available = base::BindOnce(
      &EnrollmentSession::ProcessAvailableFrame, base::Unretained(this),
      std::move(on_frame_processed));
  stream_reader_->Read(std::move(on_frame_available));
}
// Handles one item delivered by the frame stream.
//
// - If cancellation was requested, aborts the enrollment and reports
//   cancellation; `callback` is dropped.
// - If the stream ended (no value), this is treated as fatal via CHECK.
// - If the stream delivered an error, aborts the enrollment and reports that
//   error; `callback` is dropped.
// - Otherwise converts the frame and sends it in a ProcessFrameForEnrollment
//   RPC whose reply is delivered to `callback`.
void EnrollmentSession::ProcessAvailableFrame(ProcessFrameCallback callback,
                                              StreamValue<InputFrame> frame) {
  if (abort_requested_) {
    auto on_aborted = base::BindOnce(
        &EnrollmentSession::CompleteCancelEnrollment, base::Unretained(this));
    AbortEnrollment(std::move(on_aborted));
    return;
  }

  CHECK(frame.value != std::nullopt) << "Camera stream unexpectedly closed";

  auto& frame_or_error = *frame.value;
  if (!frame_or_error.ok()) {
    auto on_aborted =
        base::BindOnce(&EnrollmentSession::CompleteAbortEnrollment,
                       base::Unretained(this), frame_or_error.status());
    AbortEnrollment(std::move(on_aborted));
    return;
  }

  // Take ownership of the frame and wrap it in the request proto.
  ProcessFrameForEnrollmentRequest request;
  *request.mutable_frame() = FrameToCameraFrame(std::move(*frame_or_error));

  (*rpc_client_)
      ->CallRpc(
          &faceauth::eora::FaceService::Stub::AsyncProcessFrameForEnrollment,
          request, std::move(callback));
}
// Reply handler for the ProcessFrameForEnrollment RPC.
//
// Checks are applied in this order:
//   1. gRPC failure: report the error to the delegate (no abort RPC is sent).
//   2. Non-OK status in the response: abort, then report that status.
//   3. Cancellation requested: abort, then report cancellation.
//   4. Enrollment completed: issue the CompleteEnrollment RPC.
//   5. Otherwise: read the next frame.
void EnrollmentSession::CompleteProcessFrame(
    grpc::Status status,
    std::unique_ptr<ProcessFrameForEnrollmentResponse> response) {
  if (!status.ok()) {
    NotifyError(absl::UnavailableError(status.error_message()));
    return;
  }

  if (absl::Status frame_status = ToAbslStatus(response->status());
      !frame_status.ok()) {
    auto on_aborted =
        base::BindOnce(&EnrollmentSession::CompleteAbortEnrollment,
                       base::Unretained(this), frame_status);
    AbortEnrollment(std::move(on_aborted));
    return;
  }

  if (abort_requested_) {
    auto on_aborted = base::BindOnce(
        &EnrollmentSession::CompleteCancelEnrollment, base::Unretained(this));
    AbortEnrollment(std::move(on_aborted));
    return;
  }

  if (response->enrollment_completed()) {
    auto on_completed = base::BindOnce(
        &EnrollmentSession::CompleteCompleteEnrollment, base::Unretained(this));
    CompleteEnrollment(std::move(on_completed));
    return;
  }

  // Not finished yet: schedule the next frame through the same pipeline.
  auto on_frame_processed = base::BindOnce(
      &EnrollmentSession::CompleteProcessFrame, base::Unretained(this));
  auto on_frame_available = base::BindOnce(
      &EnrollmentSession::ProcessAvailableFrame, base::Unretained(this),
      std::move(on_frame_processed));
  stream_reader_->Read(std::move(on_frame_available));
}
// Issues the CompleteEnrollment RPC with an empty request; the reply is
// delivered to `callback`.
void EnrollmentSession::CompleteEnrollment(CompleteCallback callback) {
  CompleteEnrollmentRequest request;
  (*rpc_client_)
      ->CallRpc(&faceauth::eora::FaceService::Stub::AsyncCompleteEnrollment,
                request, std::move(callback));
}
// Reply handler for the CompleteEnrollment RPC.
//
// A gRPC failure or a non-OK status in the response is reported to the
// delegate as an error; otherwise the delegate is told that enrollment
// completed.
void EnrollmentSession::CompleteCompleteEnrollment(
    grpc::Status status, std::unique_ptr<CompleteEnrollmentResponse> response) {
  if (!status.ok()) {
    NotifyError(absl::UnavailableError(status.error_message()));
    return;
  }

  if (absl::Status complete_status = ToAbslStatus(response->status());
      !complete_status.ok()) {
    NotifyError(complete_status);
    return;
  }

  NotifyComplete();
}
// Requests cancellation of the session. Installed by `Create()` as the
// disconnect handler for both the delegate remote and the receiver.
//
// This only records the request and closes the frame stream; the abort RPC
// and the cancellation notification are issued from ProcessAvailableFrame()
// or CompleteProcessFrame() once they observe `abort_requested_`.
void EnrollmentSession::TryCancel() {
// Set the flag before touching the reader so the frame handlers see it.
abort_requested_ = true;
// NOTE(review): presumably closing the reader causes any pending Read() to
// complete so the flag is observed promptly; confirm against StreamReader.
stream_reader_->Close();
}
// Issues the AbortEnrollment RPC with an empty request; the reply is
// delivered to `callback`.
void EnrollmentSession::AbortEnrollment(AbortCallback callback) {
  AbortEnrollmentRequest request;
  (*rpc_client_)
      ->CallRpc(&faceauth::eora::FaceService::Stub::AsyncAbortEnrollment,
                request, std::move(callback));
}
// Reply handler for an AbortEnrollment RPC that was issued because of
// `error`.
//
// Failures of the abort itself (gRPC failure or a non-OK status in the
// response) are only logged; the delegate is always notified of the original
// `error`.
//
// `status`   - gRPC status of the abort call.
// `response` - abort response; only inspected when `status` is OK and the
//              pointer is non-null, since a failed call may deliver no
//              response object.
void EnrollmentSession::CompleteAbortEnrollment(
    absl::Status error,
    grpc::Status status,
    std::unique_ptr<AbortEnrollmentResponse> response) {
  // gRPC and abort operation errors are not propagated back to the client.
  // Instead the client will be notified of the original error.
  if (!status.ok()) {
    LOG(WARNING) << status.error_message();
  } else if (response) {
    const absl::Status rpc_status = ToAbslStatus(response->status());
    if (!rpc_status.ok()) {
      LOG(WARNING) << rpc_status.message();
    }
  }

  NotifyError(error);
}
// Reply handler for an AbortEnrollment RPC that was issued because
// cancellation was requested.
//
// Failures of the abort itself (gRPC failure or a non-OK status in the
// response) are only logged; the delegate is always notified of the
// cancellation.
//
// `status`   - gRPC status of the abort call.
// `response` - abort response; only inspected when `status` is OK and the
//              pointer is non-null, since a failed call may deliver no
//              response object.
void EnrollmentSession::CompleteCancelEnrollment(
    grpc::Status status, std::unique_ptr<AbortEnrollmentResponse> response) {
  // gRPC and abort operation errors are not propagated back to the client.
  // Instead the client will be notified of cancellation.
  if (!status.ok()) {
    LOG(WARNING) << status.error_message();
  } else if (response) {
    const absl::Status rpc_status = ToAbslStatus(response->status());
    if (!rpc_status.ok()) {
      LOG(WARNING) << rpc_status.message();
    }
  }

  NotifyCancelled();
}
// Ends the session: resets both mojo endpoints and posts the stored
// completion callback, bound to `status`, to the current sequence.
//
// The completion callback is moved out of the member, so it is consumed by
// this call.
void EnrollmentSession::FinishSession(absl::Status status) {
  // Drop the connections to the enrollment session interfaces first.
  delegate_.reset();
  receiver_.reset();

  auto report_completion =
      base::BindOnce(std::move(completion_callback_), status);
  PostToCurrentSequence(std::move(report_completion));
}
} // namespace faced