blob: 141bea7ad1089a3d1be4218fabc98ce412f7fc7f [file] [log] [blame]
// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// A simplified interface to the ML service. Used to implement the ml_cmdline
// tool.
#include "ml/simple.h"
#include <string>
#include <utility>
#include <vector>
#include <base/bind.h>
#include <base/run_loop.h>
#include <mojo/public/cpp/bindings/remote.h>
#include "ml/machine_learning_service_impl.h"
#include "ml/mojom/graph_executor.mojom.h"
#include "ml/mojom/machine_learning_service.mojom.h"
#include "ml/mojom/model.mojom.h"
#include "ml/tensor_view.h"
using ::chromeos::machine_learning::mojom::BuiltinModelId;
using ::chromeos::machine_learning::mojom::BuiltinModelSpec;
using ::chromeos::machine_learning::mojom::BuiltinModelSpecPtr;
using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
using ::chromeos::machine_learning::mojom::ExecuteResult;
using ::chromeos::machine_learning::mojom::GraphExecutor;
using ::chromeos::machine_learning::mojom::GraphExecutorOptions;
using ::chromeos::machine_learning::mojom::LoadModelResult;
using ::chromeos::machine_learning::mojom::MachineLearningService;
using ::chromeos::machine_learning::mojom::Model;
using ::chromeos::machine_learning::mojom::TensorPtr;
namespace ml {
namespace simple {
namespace {
// Builds a tensor that holds exactly one double: its shape is set to {1} and
// its value list to {value}.
TensorPtr NewSingleValueTensor(const double value) {
  TensorPtr single = chromeos::machine_learning::mojom::Tensor::New();

  // The view allocates the tensor's storage and then exposes its shape and
  // values as assignable containers.
  TensorView<double> view(single);
  view.Allocate();
  view.GetShape() = {1};
  view.GetValues() = {value};

  return single;
}
} // namespace
// Adds `x` and `y` by running the built-in TEST_MODEL through an ML service
// instance that is created and bound inside this call.
//
// Parameters:
//   x, y       - the two operands, fed to the model as inputs "x" and "y".
//   use_nnapi  - forwarded to GraphExecutorOptions when creating the executor.
//
// Returns an AddResult whose `status` is "OK" and whose `sum` holds the model
// output "z" on success. On failure `status` names the stage that failed
// (model load, graph executor creation, or inference) and `sum` keeps its
// initial value of -1.0.
AddResult Add(const double x, const double y, const bool use_nnapi) {
  AddResult result = {"Not completed.", -1.0};

  // Create ML Service
  mojo::Remote<MachineLearningService> ml_service;
  const MachineLearningServiceImpl ml_service_impl(
      ml_service.BindNewPipeAndPassReceiver().PassPipe(), base::Closure());

  // Load model.
  BuiltinModelSpecPtr spec = BuiltinModelSpec::New();
  spec->id = BuiltinModelId::TEST_MODEL;
  mojo::Remote<Model> model;
  bool model_load_ok = false;
  ml_service->LoadBuiltinModel(
      std::move(spec), model.BindNewPipeAndPassReceiver(),
      base::Bind(
          [](bool* const model_load_ok, const LoadModelResult result) {
            *model_load_ok = result == LoadModelResult::OK;
          },
          &model_load_ok));
  // Drain pending tasks so the callback above gets a chance to run; the flag
  // stays false if it never fires.
  base::RunLoop().RunUntilIdle();
  if (!model_load_ok) {
    result.status = "Failed to load model.";
    return result;
  }

  // Get graph executor for model.
  mojo::Remote<GraphExecutor> graph_executor;
  bool graph_executor_ok = false;
  auto options = GraphExecutorOptions::New(use_nnapi);
  model->CreateGraphExecutorWithOptions(
      std::move(options), graph_executor.BindNewPipeAndPassReceiver(),
      base::Bind(
          [](bool* const graph_executor_ok,
             const CreateGraphExecutorResult result) {
            *graph_executor_ok = result == CreateGraphExecutorResult::OK;
          },
          &graph_executor_ok));
  base::RunLoop().RunUntilIdle();
  // Check the executor-creation flag here. Re-checking model_load_ok would
  // always pass at this point and hide a failed executor creation.
  if (!graph_executor_ok) {
    result.status = "Failed to get graph executor";
    return result;
  }

  // Construct input to graph executor and perform inference
  base::flat_map<std::string, TensorPtr> inputs;
  inputs.emplace("x", NewSingleValueTensor(x));
  inputs.emplace("y", NewSingleValueTensor(y));
  std::vector<std::string> outputs({"z"});
  bool inference_ok = false;
  graph_executor->Execute(
      std::move(inputs), std::move(outputs),
      base::Bind(
          [](bool* const inference_ok, double* const sum,
             const ExecuteResult execute_result,
             base::Optional<std::vector<TensorPtr>> outputs) {
            // Check that the inference succeeded and gave the expected number
            // of outputs.
            *inference_ok = execute_result == ExecuteResult::OK &&
                            outputs.has_value() && outputs->size() == 1;
            if (!*inference_ok) {
              return;
            }
            // Get value from output
            const TensorView<double> out_tensor((*outputs)[0]);
            *sum = out_tensor.GetValues()[0];
          },
          &inference_ok, &result.sum));
  base::RunLoop().RunUntilIdle();
  if (!inference_ok) {
    result.status = "Inference failed.";
    return result;
  }

  result.status = "OK";
  return result;
}
} // namespace simple
} // namespace ml