blob: 89a13d9fd75fd248a411ce3206e76d5565787f00 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef ML_GRAPH_EXECUTOR_IMPL_H_
#define ML_GRAPH_EXECUTOR_IMPL_H_
#include <memory>
#include <string>
#include <vector>
#include <base/callback_forward.h>
#include <base/containers/flat_map.h>
#include <base/macros.h>
#include <mojo/public/cpp/bindings/pending_receiver.h>
#include <mojo/public/cpp/bindings/receiver.h>
#include "ml/graph_executor_delegate.h"
#include "ml/mojom/graph_executor.mojom.h"
namespace ml {
// Allows execution of TensorFlow lite graphs using input / output specified
// with Mojo types.
//
// Holds as little state as possible (with the remainder living in the parent
// Model object and shared between all sibling GraphExecutors). Hence, a
// GraphExecutor becomes invalid when its parent Model object is destroyed.
//
// A given GraphExecutorImpl may not be used concurrently from different
// sequences.
// Example usage:
//   auto interpreter = std::make_unique<tflite::Interpreter>();
//   const std::string metrics_model_name = xxx;
//   mojo::Remote<GraphExecutor> graph_executor;
//   const GraphExecutorImpl graph_executor_impl(
//       std::make_unique<GraphExecutorDelegate>(
//           input_names, output_names, std::move(interpreter),
//           metrics_model_name),
//       graph_executor.BindNewPipeAndPassReceiver());
class GraphExecutorImpl
    : public chromeos::machine_learning::mojom::GraphExecutor {
 public:
  // Takes ownership of `graph_executor_delegate` to do the actual work of
  // calling TFLite, and creates an instance bound to `receiver`.
  GraphExecutorImpl(
      std::unique_ptr<GraphExecutorDelegate> graph_executor_delegate,
      mojo::PendingReceiver<chromeos::machine_learning::mojom::GraphExecutor>
          receiver);
  GraphExecutorImpl(const GraphExecutorImpl&) = delete;
  GraphExecutorImpl& operator=(const GraphExecutorImpl&) = delete;

  // Registers `disconnect_handler` to run when the Mojo connection for this
  // executor goes away.
  // NOTE(review): presumably forwarded to `receiver_`'s disconnect handler;
  // the definition lives in the .cc file — confirm there.
  void set_disconnect_handler(base::Closure disconnect_handler);

 private:
  // chromeos::machine_learning::mojom::GraphExecutor:
  // Marked `override` so that any mismatch between this signature and the
  // generated mojom interface becomes a compile error, rather than silently
  // declaring an unrelated overload.
  void Execute(
      base::flat_map<std::string, chromeos::machine_learning::mojom::TensorPtr>
          inputs,
      const std::vector<std::string>& output_names,
      ExecuteCallback callback) override;

  // The delegate that actually runs the TFLite graph.
  std::unique_ptr<GraphExecutorDelegate> graph_executor_delegate_;

  // Intentionally declared after `graph_executor_delegate_`. Members are
  // destroyed in reverse declaration order, so `receiver_` (and with it the
  // Mojo binding that dispatches Execute()) is torn down before the delegate
  // it forwards work to.
  mojo::Receiver<chromeos::machine_learning::mojom::GraphExecutor> receiver_;
};
} // namespace ml
#endif // ML_GRAPH_EXECUTOR_IMPL_H_