blob: d21ee71f11a30bd16cbfc98feb0f7b36ac5d340a [file] [log] [blame] [edit]
// Copyright 2024 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "ml/heatmap_processor.h"
#include <map>
#include <string>
#include <utility>
#include <base/logging.h>
#include "ml/request_metrics.h"
using ::chromeos::machine_learning::mojom::ExecuteResult;
using ::chromeos::machine_learning::mojom::FloatList;
using ::chromeos::machine_learning::mojom::GpuDelegateApi;
using ::chromeos::machine_learning::mojom::HeatmapPalmRejectionClient;
using ::chromeos::machine_learning::mojom::HeatmapPalmRejectionConfigPtr;
using ::chromeos::machine_learning::mojom::Int64List;
using ::chromeos::machine_learning::mojom::LoadHeatmapPalmRejectionResult;
using ::chromeos::machine_learning::mojom::Tensor;
using ::chromeos::machine_learning::mojom::TensorPtr;
using ::chromeos::machine_learning::mojom::ValueList;
namespace ml {
namespace {
// TFLite graph node names. Start() pairs these with config->input_node /
// config->output_node when building the ModelDelegate, and Process() uses
// them as the keys of the input map and the requested-output list passed to
// GraphExecutorDelegate::Execute().
constexpr char kInputNodeName[] = "input";
constexpr char kOutputNodeName[] = "output";
// Base name for UMA metrics related to graph execution; Process() passes it
// to RequestMetrics together with the "HeatmapPalmRejection" model name.
constexpr char kMetricsRequestName[] = "ExecuteResult";
} // namespace
// Defaulted constructor: any default member initializers declared in
// ml/heatmap_processor.h apply. The processor starts not ready (Process() is a
// no-op) until Start() succeeds — NOTE(review): assumes ready_ is initialized
// to false in the header; confirm.
HeatmapProcessor::HeatmapProcessor() = default;
// Loads the TFLite model named by |config|, builds a graph executor for it and
// binds |client| as the receiver of HeatmapProcessedEvents.
//
// Returns:
//   LOAD_MODEL_ERROR            if the model file cannot be loaded;
//   CREATE_GRAPH_EXECUTOR_ERROR if the graph executor cannot be created;
//   OK                          otherwise.
// On any failure the processor is left not ready, so Process() does nothing
// until a later Start() succeeds. On success, any previously bound client is
// replaced.
LoadHeatmapPalmRejectionResult HeatmapProcessor::Start(
    mojo::PendingRemote<HeatmapPalmRejectionClient> client,
    HeatmapPalmRejectionConfigPtr config) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Stop serving Process() until the new model is fully set up.
  ready_ = false;
  // Tear down the executor from any previous Start() *before* the model
  // delegate it was created from is replaced below. The executor may reference
  // memory owned by that delegate (TFLite interpreters must not outlive their
  // FlatBufferModel), so it must be destroyed first.
  graph_executor_delegate_.reset();

  // Attempt to load model.
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromFile(config->tf_model_path.c_str());
  if (model == nullptr) {
    LOG(ERROR) << "Failed to load model file '" << config->tf_model_path
               << "'.";
    return LoadHeatmapPalmRejectionResult::LOAD_MODEL_ERROR;
  }
  model_delegate_ = std::make_unique<ModelDelegate>(
      std::map<std::string, int>({{kInputNodeName, config->input_node}}),
      std::map<std::string, int>({{kOutputNodeName, config->output_node}}),
      std::move(model), "PonchoPalmRejectionModel");
  std::unique_ptr<GraphExecutorDelegate> graph_executor_delegate;
  if (model_delegate_->CreateGraphExecutorDelegate(
          false, false, GpuDelegateApi::UNKNOWN, &graph_executor_delegate) !=
      CreateGraphExecutorResult::OK) {
    LOG(ERROR) << "Failed to create graph executor";
    return LoadHeatmapPalmRejectionResult::CREATE_GRAPH_EXECUTOR_ERROR;
  }
  // Commit the new configuration only once every setup step has succeeded.
  graph_executor_delegate_ = std::move(graph_executor_delegate);
  palm_threshold_ = config->palm_threshold;
  client_ = mojo::Remote(std::move(client));
  ready_ = true;
  return LoadHeatmapPalmRejectionResult::OK;
}
// Runs the palm rejection model on one heatmap frame and reports the outcome
// to the bound client through ReportResult().
//
// |heatmap_data| is passed to the model as a float tensor of shape
// {1, height, width, 1}. NOTE(review): presumably it holds height * width
// values; size validation is left to the graph executor — confirm.
// Does nothing if Start() has not succeeded. Otherwise a result is always
// reported: is_palm is true only when the model produced exactly one score
// strictly greater than the configured palm threshold, and false on any
// execution or output error (which is also logged and recorded in UMA).
void HeatmapProcessor::Process(const std::vector<double>& heatmap_data,
                               int height,
                               int width,
                               base::Time timestamp) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!ready_) {
    return;
  }
  RequestMetrics request_metrics("HeatmapPalmRejection", kMetricsRequestName);
  request_metrics.StartRecordingPerformanceMetrics();

  // Build the single input tensor.
  base::flat_map<std::string, TensorPtr> inputs;
  auto tensor = Tensor::New();
  tensor->shape = Int64List::New();
  tensor->shape->value = std::vector<int64_t>{1, height, width, 1};
  tensor->data = ValueList::NewFloatList(FloatList::New(heatmap_data));
  inputs.emplace(kInputNodeName, std::move(tensor));

  std::vector<std::string> outputs({kOutputNodeName});
  std::vector<TensorPtr> output_tensors;
  const auto result = graph_executor_delegate_->Execute(
      std::move(inputs), outputs, output_tensors);
  request_metrics.FinishRecordingPerformanceMetrics();

  bool is_palm = false;
  if (result != ExecuteResult::OK) {
    request_metrics.RecordRequestEvent(result);
    LOG(ERROR) << "Heatmap palm rejection model execution failed with error "
               << result;
  } else if (!output_tensors.empty() && output_tensors[0] &&
             output_tensors[0]->data &&
             output_tensors[0]->data->is_float_list() &&
             output_tensors[0]->data->get_float_list()->value.size() == 1) {
    // Every level of the output was checked above before dereferencing, so a
    // malformed "OK" result cannot index past the end or follow a null ptr.
    const double prediction =
        output_tensors[0]->data->get_float_list()->value[0];
    is_palm = prediction > palm_threshold_;
    request_metrics.RecordRequestEvent(ExecuteResult::OK);
  } else {
    request_metrics.RecordRequestEvent(ExecuteResult::OUTPUT_MISSING_ERROR);
    LOG(ERROR)
        << "Heatmap palm rejection model returns unexpected output data";
  }
  ReportResult(is_palm, timestamp);
}
void HeatmapProcessor::ReportResult(bool is_palm, base::Time timestamp) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
auto event = chromeos::machine_learning::mojom::HeatmapProcessedEvent::New();
event->timestamp = timestamp;
event->is_palm = is_palm;
client_->OnHeatmapProcessedEvent(std::move(event));
}
// Returns the process-wide HeatmapProcessor. The instance is created lazily on
// first call; base::NoDestructor keeps it alive for the rest of the process
// (its destructor is never run), so the returned pointer never dangles.
HeatmapProcessor* HeatmapProcessor::GetInstance() {
  static base::NoDestructor<HeatmapProcessor> shared_processor;
  return shared_processor.get();
}
} // namespace ml