// blob: b58d9cd532a67e159b2b5054d34f54a85ec64ce3 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <base/callback_forward.h>
#include <base/containers/flat_map.h>
#include <base/macros.h>
#include <mojo/public/cpp/bindings/pending_receiver.h>
#include <mojo/public/cpp/bindings/receiver.h>
#include <tensorflow/lite/model.h>
#include "ml/mojom/graph_executor.mojom.h"
namespace ml {
// Allows execution of TensorFlow lite graphs using input / output specified
// with Mojo types.
// Holds as little state as possible (with the remainder living in the parent
// Model object and shared between all sibling GraphExecutors). Hence, a
// GraphExecutor becomes invalid when its parent Model object is destroyed.
// A given GraphExecutorImpl may not be used concurrently from different
// sequences.
class GraphExecutorImpl
: public chromeos::machine_learning::mojom::GraphExecutor {
// Creates an instance bound to `receiver`.
// The `required_inputs` and `required_outputs` arguments specify a mapping
// from required input / output tensor names to their indices in the TF lite
// graph, and must outlive this object.
// UMA metrics will be logged with the specified `metrics_model_name`.
// As is standard, `interpreter` must outlive the model with which it was
// constructed.
const std::map<std::string, int>& required_inputs,
const std::map<std::string, int>& required_outputs,
std::unique_ptr<tflite::Interpreter> interpreter,
const std::string& metrics_model_name);
GraphExecutorImpl(const GraphExecutorImpl&) = delete;
GraphExecutorImpl& operator=(const GraphExecutorImpl&) = delete;
void set_disconnect_handler(base::Closure disconnect_handler);
// chromeos::machine_learning::mojom::GraphExecutor:
void Execute(
base::flat_map<std::string, chromeos::machine_learning::mojom::TensorPtr>
const std::vector<std::string>& output_names,
ExecuteCallback callback);
const std::map<std::string, int>& required_inputs_;
const std::map<std::string, int>& required_outputs_;
const std::unique_ptr<tflite::Interpreter> interpreter_;
mojo::Receiver<chromeos::machine_learning::mojom::GraphExecutor> receiver_;
// Model name as it should appear in UMA histogram names.
const std::string metrics_model_name_;
} // namespace ml