blob: 29c007048dff55afe24a15524cfbd1da5267d561 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "ml/graph_executor_impl.h"
#include <utility>
#include <base/stl_util.h>
#include "ml/mojom/tensor.mojom.h"
namespace ml {

namespace {

// File-local shorthands for the mojom types referenced below.
using ::chromeos::machine_learning::mojom::ExecuteResult;
using ::chromeos::machine_learning::mojom::GraphExecutor;
using ::chromeos::machine_learning::mojom::TensorPtr;

}  // namespace

// Takes ownership of |graph_executor_delegate|, which performs the actual
// graph execution, and binds the pending |receiver| to this object so that
// GraphExecutor mojo calls arriving on it are handled by this instance.
GraphExecutorImpl::GraphExecutorImpl(
    std::unique_ptr<GraphExecutorDelegate> graph_executor_delegate,
    mojo::PendingReceiver<GraphExecutor> receiver)
    : graph_executor_delegate_(std::move(graph_executor_delegate)),
      receiver_(this, std::move(receiver)) {}

// Hands |disconnect_handler| to the underlying mojo receiver, which invokes it
// when the connection to the remote GraphExecutor client goes away.
void GraphExecutorImpl::set_disconnect_handler(
    base::Closure disconnect_handler) {
  receiver_.set_disconnect_handler(std::move(disconnect_handler));
}

// Runs the graph via |graph_executor_delegate_|, passing along the input
// |tensors| (keyed by name) and |outputs| (presumably the names of the tensors
// the caller wants back).
//
// |callback| always receives the delegate's ExecuteResult. Output tensors are
// returned only when that result is OK. For any other result the callback
// gets base::nullopt, and whatever the delegate placed in the local output
// vector is discarded.
void GraphExecutorImpl::Execute(base::flat_map<std::string, TensorPtr> tensors,
                                const std::vector<std::string>& outputs,
                                ExecuteCallback callback) {
  // Filled in by the delegate. NOTE(review): assumed to be a by-reference
  // out-parameter of GraphExecutorDelegate::Execute; confirm in its header.
  std::vector<TensorPtr> produced_tensors;
  const auto result = graph_executor_delegate_->Execute(
      std::move(tensors), outputs, produced_tensors);

  if (result == ExecuteResult::OK) {
    std::move(callback).Run(result, std::move(produced_tensors));
    return;
  }

  // Any non-OK status: report the status alone, with no output tensors.
  std::move(callback).Run(result, base::nullopt);
}

}  // namespace ml