blob: b34a468d0b42e0377bc69522c33a09c2c822e665 [file] [log] [blame] [edit]
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef ML_DBUS_SERVICE_TF_MODEL_GRAPH_EXECUTOR_H_
#define ML_DBUS_SERVICE_TF_MODEL_GRAPH_EXECUTOR_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "chrome/knowledge/assist_ranker/example_preprocessor.pb.h"
#include "chrome/knowledge/assist_ranker/ranker_example.pb.h"
#include "ml/graph_executor_delegate.h"
#include "ml/model_delegate.h"
#include "ml/model_metadata.h"
namespace ml {
using ::chromeos::machine_learning::mojom::BuiltinModelId;
using ::chromeos::machine_learning::mojom::TensorPtr;
// TfModelGraphExecutor does the actual machine-learning work: it vectorizes a
// feature example using a preprocessor config and then runs TFLite model
// inference on the resulting features. An instance is built from a model and
// a preprocessor config; at present only builtin models can be loaded.
class TfModelGraphExecutor {
 public:
  // Test-only factory: builds an executor exactly like the public constructor
  // but uses `asset_dir` in place of the default asset directory.
  static std::unique_ptr<TfModelGraphExecutor> CreateForTesting(
      BuiltinModelId model_id,
      const std::string& preprocessor_file_name,
      const std::string& asset_dir);

  // Builds an executor for the builtin model `model_id`, using the
  // preprocessor config named by `preprocessor_file_name`.
  TfModelGraphExecutor(BuiltinModelId model_id,
                       const std::string& preprocessor_file_name);

  // Non-copyable. Because the copy operations are user-declared, no move
  // operations are implicitly generated either.
  TfModelGraphExecutor(const TfModelGraphExecutor&) = delete;
  TfModelGraphExecutor& operator=(const TfModelGraphExecutor&) = delete;

  ~TfModelGraphExecutor();

  // Reports whether the executor is able to run inference. Callers must
  // confirm this returns true before invoking Execute().
  bool Ready() const;

  // Preprocesses `example` according to `config_`, feeds the vectorized
  // features to the tensorflow graph, then extracts the graph outputs and
  // appends them to `output_tensors`.
  // NOTE(review): `example` is taken as a mutable pointer, so preprocessing
  // may modify it in place; `clear_other_features` presumably controls whether
  // features not covered by the config are dropped, and the return value
  // presumably signals success — confirm against the implementation.
  bool Execute(bool clear_other_features,
               assist_ranker::RankerExample* example,
               std::vector<TensorPtr>* output_tensors) const;

 private:
  // Constructor that lets the caller override the asset directory.
  TfModelGraphExecutor(BuiltinModelId model_id,
                       const std::string& preprocessor_file_name,
                       const std::string& asset_dir);

  // NOTE(review): keep the declaration order below. Members are destroyed in
  // reverse declaration order, so `graph_executor_delegate_` is destroyed
  // before `model_delegate_`; this presumably matters if the graph executor
  // refers to the model — confirm before reordering.
  const std::string asset_dir_;
  std::vector<std::string> output_names_;
  std::unique_ptr<assist_ranker::ExamplePreprocessorConfig> config_;
  std::unique_ptr<ModelDelegate> model_delegate_;
  std::unique_ptr<GraphExecutorDelegate> graph_executor_delegate_;
};
} // namespace ml
#endif // ML_DBUS_SERVICE_TF_MODEL_GRAPH_EXECUTOR_H_