| // Copyright 2018 The Chromium OS Authors. All rights reserved. | 
 | // Use of this source code is governed by a BSD-style license that can be | 
 | // found in the LICENSE file. | 
 |  | 
 | #include "ml/graph_executor_impl.h" | 
 | #include "ml/request_metrics.h" | 
 |  | 
 | #include <set> | 
 | #include <utility> | 
 | #include <vector> | 
 |  | 
 | #include <base/stl_util.h> | 
 |  | 
 | #include "ml/mojom/tensor.mojom.h" | 
 | #include "ml/tensor_view.h" | 
 |  | 
 | namespace ml { | 
 |  | 
 | namespace { | 
 |  | 
 | using ::chromeos::machine_learning::mojom::ExecuteResult; | 
 | using ::chromeos::machine_learning::mojom::GraphExecutor; | 
 | using ::chromeos::machine_learning::mojom::Int64List; | 
 | using ::chromeos::machine_learning::mojom::Tensor; | 
 | using ::chromeos::machine_learning::mojom::TensorPtr; | 
 | using ::chromeos::machine_learning::mojom::ValueList; | 
 |  | 
 | // Base name for UMA metrics related to graph execution | 
 | constexpr char kMetricsRequestName[] = "ExecuteResult"; | 
 |  | 
 | // Verifies `tensor` is valid (i.e. is of type `TensorType` and of the correct | 
 | // shape for this input) and copies its data into the graph `interpreter` at | 
 | // position `index`. | 
 | template <typename TensorType, typename MemoryType> | 
 | ExecuteResult PopulateInput(const TensorPtr& tensor, | 
 |                             const int index, | 
 |                             tflite::Interpreter* const interpreter) { | 
 |   const TensorView<TensorType> tensor_view(tensor); | 
 |  | 
 |   if (!tensor_view.IsValidType()) | 
 |     return ExecuteResult::INPUT_TYPE_ERROR; | 
 |  | 
 |   if (!tensor_view.IsValidFormat()) | 
 |     return ExecuteResult::INPUT_FORMAT_ERROR; | 
 |  | 
 |   // Check that given input shape matches that expected by TF lite. | 
 |  | 
 |   const TfLiteIntArray& expected_dims = *interpreter->tensor(index)->dims; | 
 |   const std::vector<int64_t>& actual_dims = tensor_view.GetShape(); | 
 |  | 
 |   bool shape_matches = expected_dims.size == actual_dims.size(); | 
 |   for (int i = 0; shape_matches && i < expected_dims.size; ++i) { | 
 |     shape_matches = expected_dims.data[i] == actual_dims[i]; | 
 |   } | 
 |  | 
 |   if (!shape_matches) | 
 |     return ExecuteResult::INPUT_SHAPE_ERROR; | 
 |  | 
 |   MemoryType* const input_memory = interpreter->typed_tensor<MemoryType>(index); | 
 |   const std::vector<TensorType>& tensor_values = tensor_view.GetValues(); | 
 |   for (int i = 0; i < tensor_values.size(); ++i) { | 
 |     input_memory[i] = tensor_values[i]; | 
 |   } | 
 |  | 
 |   return ExecuteResult::OK; | 
 | } | 
 |  | 
 | ExecuteResult InvalidInput(const TensorPtr&, int, tflite::Interpreter*) { | 
 |   return ExecuteResult::EXECUTION_ERROR; | 
 | } | 
 |  | 
 | // A table of functions to validate / populate data for model nodes expecting | 
 | // input of each TF lite type. | 
 | // | 
 | // This table is indexed by TfLiteType, the possible values of which can be | 
 | // found at <tensorflow/lite/context.h>. We make the following | 
 | // assumptions about index values: | 
 | //   1) They will remain consistent across TF lite releases, and | 
 | //   2) They will always start from (close to) 0 and be (mostly) consecutive. | 
 | // | 
 | // Since TfLiteType is part of the stable C API for TF lite, these assumptions | 
 | // seem fair. | 
 | constexpr decltype(&InvalidInput) kPopulateInputFns[] = { | 
 |     &InvalidInput,                     // kTfLiteNoType | 
 |     &PopulateInput<double, float>,     // kTfLiteFloat32 | 
 |     &PopulateInput<int64_t, int32_t>,  // kTfLiteInt32 | 
 |     &PopulateInput<int64_t, uint8_t>,  // kTfLiteUInt8 | 
 |     &PopulateInput<int64_t, int64_t>,  // kTfLiteInt64 | 
 |     &InvalidInput,                     // kTfLiteString | 
 |     &PopulateInput<int64_t, bool>,     // kTfLiteBool | 
 | }; | 
 |  | 
 | // Copies data from position `index` in the graph `interpreter` into the given | 
 | // tensor object. | 
 | template <typename TensorType, typename MemoryType> | 
 | ExecuteResult PopulateOutput(const int index, | 
 |                              const tflite::Interpreter& interpreter, | 
 |                              const TensorPtr& tensor) { | 
 |   TensorView<TensorType> tensor_view(tensor); | 
 |   tensor_view.Allocate(); | 
 |  | 
 |   // Empty output is not valid. | 
 |   const TfLiteIntArray& dims = *interpreter.tensor(index)->dims; | 
 |   if (dims.size == 0) | 
 |     return ExecuteResult::EXECUTION_ERROR; | 
 |  | 
 |   // Copy across size information and calculate the number of elements being | 
 |   // output. | 
 |   int64_t num_entries = 1; | 
 |   std::vector<int64_t>& tensor_dims = tensor_view.GetShape(); | 
 |   tensor_dims.resize(dims.size); | 
 |   for (int i = 0; i < dims.size; ++i) { | 
 |     const int64_t dim_length = dims.data[i]; | 
 |  | 
 |     if (dim_length <= 0) | 
 |       return ExecuteResult::EXECUTION_ERROR; | 
 |  | 
 |     tensor_dims[i] = dim_length; | 
 |     num_entries *= dim_length; | 
 |   } | 
 |  | 
 |   // Populate tensor values. | 
 |   const MemoryType* const output_memory = | 
 |       interpreter.typed_tensor<MemoryType>(index); | 
 |   std::vector<TensorType>& tensor_values = tensor_view.GetValues(); | 
 |   tensor_values.resize(num_entries); | 
 |   for (int i = 0; i < num_entries; ++i) { | 
 |     tensor_values[i] = output_memory[i]; | 
 |   } | 
 |  | 
 |   return ExecuteResult::OK; | 
 | } | 
 |  | 
 | ExecuteResult InvalidOutput(int, const tflite::Interpreter&, const TensorPtr&) { | 
 |   return ExecuteResult::EXECUTION_ERROR; | 
 | } | 
 |  | 
 | // A table of functions to populate data for tensors from output of each TF lite | 
 | // type. | 
 | // | 
 | // This table is indexed by TfLiteType, the possible values of which can be | 
 | // found at <tensorflow/lite/context.h>. See the caveats discussed in | 
 | // the comment above `kPopulateInputFns`. | 
 | constexpr decltype(&InvalidOutput) kPopulateOutputFns[] = { | 
 |     &InvalidOutput,                     // kTfLiteNoType | 
 |     &PopulateOutput<double, float>,     // kTfLiteFloat32 | 
 |     &PopulateOutput<int64_t, int32_t>,  // kTfLiteInt32 | 
 |     &PopulateOutput<int64_t, uint8_t>,  // kTfLiteUInt8 | 
 |     &PopulateOutput<int64_t, int64_t>,  // kTfLiteInt64 | 
 |     &InvalidOutput,                     // kTfLiteString | 
 |     &PopulateOutput<int64_t, bool>,     // kTfLiteBool | 
 | }; | 
 |  | 
 | }  // namespace | 
 |  | 
 | GraphExecutorImpl::GraphExecutorImpl( | 
 |     const std::map<std::string, int>& required_inputs, | 
 |     const std::map<std::string, int>& required_outputs, | 
 |     std::unique_ptr<tflite::Interpreter> interpreter, | 
 |     mojo::PendingReceiver<GraphExecutor> receiver, | 
 |     const std::string& metrics_model_name) | 
 |     : required_inputs_(required_inputs), | 
 |       required_outputs_(required_outputs), | 
 |       interpreter_(std::move(interpreter)), | 
 |       receiver_(this, std::move(receiver)), | 
 |       metrics_model_name_(metrics_model_name) {} | 
 |  | 
 | void GraphExecutorImpl::set_disconnect_handler( | 
 |     base::Closure disconnect_handler) { | 
 |   receiver_.set_disconnect_handler(std::move(disconnect_handler)); | 
 | } | 
 |  | 
 | void GraphExecutorImpl::Execute(base::flat_map<std::string, TensorPtr> tensors, | 
 |                                 const std::vector<std::string>& outputs, | 
 |                                 ExecuteCallback callback) { | 
 |   DCHECK(!metrics_model_name_.empty()); | 
 |  | 
 |   RequestMetrics request_metrics(metrics_model_name_, kMetricsRequestName); | 
 |   request_metrics.StartRecordingPerformanceMetrics(); | 
 |  | 
 |   // Validate input and output names (before executing graph, for efficiency). | 
 |  | 
 |   for (const auto& kv : tensors) { | 
 |     const std::string& cur_input_name = kv.first; | 
 |  | 
 |     const auto name_lookup = required_inputs_.find(cur_input_name); | 
 |     if (name_lookup == required_inputs_.end() || | 
 |         name_lookup->second >= interpreter_->tensors_size()) { | 
 |       std::move(callback).Run(ExecuteResult::UNKNOWN_INPUT_ERROR, | 
 |                               base::nullopt); | 
 |       request_metrics.RecordRequestEvent(ExecuteResult::UNKNOWN_INPUT_ERROR); | 
 |       return; | 
 |     } | 
 |   } | 
 |   if (tensors.size() != required_inputs_.size()) { | 
 |     std::move(callback).Run(ExecuteResult::INPUT_MISSING_ERROR, base::nullopt); | 
 |     request_metrics.RecordRequestEvent(ExecuteResult::INPUT_MISSING_ERROR); | 
 |     return; | 
 |   } | 
 |  | 
 |   std::set<std::string> seen_outputs; | 
 |   for (const auto& cur_output_name : outputs) { | 
 |     const auto name_lookup = required_outputs_.find(cur_output_name); | 
 |     if (name_lookup == required_outputs_.end() || | 
 |         name_lookup->second >= interpreter_->tensors_size()) { | 
 |       std::move(callback).Run(ExecuteResult::UNKNOWN_OUTPUT_ERROR, | 
 |                               base::nullopt); | 
 |       request_metrics.RecordRequestEvent(ExecuteResult::UNKNOWN_OUTPUT_ERROR); | 
 |       return; | 
 |     } | 
 |  | 
 |     // Specifying the same output twice is an error. | 
 |     const auto insert_result = seen_outputs.insert(cur_output_name); | 
 |     if (!insert_result.second) { | 
 |       std::move(callback).Run(ExecuteResult::DUPLICATE_OUTPUT_ERROR, | 
 |                               base::nullopt); | 
 |       request_metrics.RecordRequestEvent(ExecuteResult::DUPLICATE_OUTPUT_ERROR); | 
 |       return; | 
 |     } | 
 |   } | 
 |   if (outputs.size() != required_outputs_.size()) { | 
 |     std::move(callback).Run(ExecuteResult::OUTPUT_MISSING_ERROR, base::nullopt); | 
 |     request_metrics.RecordRequestEvent(ExecuteResult::OUTPUT_MISSING_ERROR); | 
 |     return; | 
 |   } | 
 |  | 
 |   // Copy input data into the interpreter. | 
 |   for (const auto& kv : tensors) { | 
 |     const std::string& cur_input_name = kv.first; | 
 |     const TensorPtr& cur_input = kv.second; | 
 |  | 
 |     // Always valid, by the input name check at the start of this function. | 
 |     const int cur_input_id = required_inputs_.find(cur_input_name)->second; | 
 |  | 
 |     // Check that the current input node is a supported type. | 
 |     const uint32_t cur_input_type = interpreter_->tensor(cur_input_id)->type; | 
 |     if (cur_input_type >= base::size(kPopulateInputFns)) { | 
 |       LOG(ERROR) << "TF lite graph contains invalid input node " << cur_input_id | 
 |                  << " of type " << cur_input_type << "."; | 
 |       std::move(callback).Run(ExecuteResult::EXECUTION_ERROR, base::nullopt); | 
 |       request_metrics.RecordRequestEvent(ExecuteResult::EXECUTION_ERROR); | 
 |       return; | 
 |     } | 
 |  | 
 |     // Attempt to copy input data into the current input node. | 
 |     const ExecuteResult populate_input_result = | 
 |         (*kPopulateInputFns[cur_input_type])(cur_input, cur_input_id, | 
 |                                              interpreter_.get()); | 
 |     if (populate_input_result != ExecuteResult::OK) { | 
 |       std::move(callback).Run(populate_input_result, base::nullopt); | 
 |       request_metrics.RecordRequestEvent(populate_input_result); | 
 |       return; | 
 |     } | 
 |   } | 
 |  | 
 |   // Execute graph. | 
 |   if (interpreter_->Invoke() != kTfLiteOk) { | 
 |     LOG(ERROR) << "TF lite graph execution failed unexpectedly."; | 
 |     std::move(callback).Run(ExecuteResult::EXECUTION_ERROR, base::nullopt); | 
 |     request_metrics.RecordRequestEvent(ExecuteResult::EXECUTION_ERROR); | 
 |     return; | 
 |   } | 
 |  | 
 |   // Extract output. | 
 |   std::vector<chromeos::machine_learning::mojom::TensorPtr> output_tensors; | 
 |   for (const auto& cur_output_name : outputs) { | 
 |     output_tensors.push_back(Tensor::New()); | 
 |  | 
 |     // Always valid, by the output name check at the start of this function. | 
 |     const int cur_output_id = | 
 |         required_outputs_.find(cur_output_name)->second; | 
 |  | 
 |     // Check that the current output node is a supported type. | 
 |     const uint32_t cur_output_type = interpreter_->tensor(cur_output_id)->type; | 
 |     if (cur_output_type >= base::size(kPopulateOutputFns)) { | 
 |       LOG(ERROR) << "TF lite graph contains invalid output node " | 
 |                  << cur_output_id << " of type " << cur_output_type << "."; | 
 |       std::move(callback).Run(ExecuteResult::EXECUTION_ERROR, base::nullopt); | 
 |       request_metrics.RecordRequestEvent(ExecuteResult::EXECUTION_ERROR); | 
 |       return; | 
 |     } | 
 |  | 
 |     // Attempt to extract data from the current output node. | 
 |     const ExecuteResult populate_output_result = | 
 |         (*kPopulateOutputFns[cur_output_type])(cur_output_id, *interpreter_, | 
 |                                                *--output_tensors.end()); | 
 |     if (populate_output_result != ExecuteResult::OK) { | 
 |       std::move(callback).Run(populate_output_result, base::nullopt); | 
 |       request_metrics.RecordRequestEvent(populate_output_result); | 
 |       return; | 
 |     } | 
 |   } | 
 |  | 
 |   std::move(callback).Run(ExecuteResult::OK, std::move(output_tensors)); | 
 |   request_metrics.FinishRecordingPerformanceMetrics(); | 
 |   request_metrics.RecordRequestEvent(ExecuteResult::OK); | 
 | } | 
 |  | 
 | }  // namespace ml |