blob: 3a0abd3d936e5e04d4ae03e7a5aa046f5a2fabf2 [file] [log] [blame]
diff --git a/tensorflow/lite/examples/label_image/BUILD b/tensorflow/lite/examples/label_image/BUILD
index 88e5fd2c..3c229f0e 100644
--- a/tensorflow/lite/examples/label_image/BUILD
+++ b/tensorflow/lite/examples/label_image/BUILD
@@ -29,6 +29,7 @@ cc_binary(
":bitmap_helpers",
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
+ "//tensorflow/lite/delegates/nnapi:nnapi_delegate",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/profiling:profiler",
],
diff --git a/tensorflow/lite/examples/label_image/label_image.cc b/tensorflow/lite/examples/label_image/label_image.cc
index ac84e270..dd358e6b 100644
--- a/tensorflow/lite/examples/label_image/label_image.cc
+++ b/tensorflow/lite/examples/label_image/label_image.cc
@@ -26,12 +26,14 @@ limitations under the License.
#include <fstream>
#include <iomanip>
#include <iostream>
+#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
+#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/examples/label_image/bitmap_helpers.h"
#include "tensorflow/lite/examples/label_image/get_top_n.h"
#include "tensorflow/lite/kernels/register.h"
@@ -47,6 +49,24 @@ namespace label_image {
double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); }
+using TfLiteDelegatePtr = tflite::Interpreter::TfLiteDelegatePtr;
+using TfLiteDelegatePtrMap = std::map<std::string, TfLiteDelegatePtr>;
+
+Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
+ return Interpreter::TfLiteDelegatePtr(
+ NnApiDelegate(),
+ // NnApiDelegate() returns a singleton, so provide a no-op deleter.
+ [](TfLiteDelegate*) {});
+}
+
+TfLiteDelegatePtrMap GetDelegates(Settings* s) {
+ TfLiteDelegatePtrMap delegates;
+ if (s->accel) {
+ delegates.emplace("NNAPI", CreateNNAPIDelegate());
+ }
+ return delegates;
+}
+
// Takes a file name, and loads a list of labels from it, one per line, and
// returns a vector of the strings. It pads with empty strings so the length
// of the result is a multiple of 16, because our model expects that.
@@ -100,6 +120,7 @@ void RunInference(Settings* s) {
LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n";
exit(-1);
}
+ s->model = model.get();
LOG(INFO) << "Loaded model " << s->model_name << "\n";
model->error_reporter();
LOG(INFO) << "resolved reporter\n";
@@ -112,7 +133,7 @@ void RunInference(Settings* s) {
exit(-1);
}
- interpreter->UseNNAPI(s->accel);
+ interpreter->UseNNAPI(s->old_accel);
interpreter->SetAllowFp16PrecisionForFp32(s->allow_fp16);
if (s->verbose) {
@@ -153,6 +174,16 @@ void RunInference(Settings* s) {
LOG(INFO) << "number of outputs: " << outputs.size() << "\n";
}
+ auto delegates_ = GetDelegates(s);
+ for (const auto& delegate : delegates_) {
+ if (interpreter->ModifyGraphWithDelegate(delegate.second.get()) !=
+ kTfLiteOk) {
+ LOG(FATAL) << "Failed to apply " << delegate.first << " delegate.";
+ } else {
+ LOG(INFO) << "Applied " << delegate.first << " delegate.";
+ }
+ }
+
if (interpreter->AllocateTensors() != kTfLiteOk) {
LOG(FATAL) << "Failed to allocate tensors!";
}
@@ -255,6 +286,7 @@ void display_usage() {
LOG(INFO)
<< "label_image\n"
<< "--accelerated, -a: [0|1], use Android NNAPI or not\n"
+ << "--old_accelerated, -d: [0|1], use olf Android NNAPI delegate or not\n"
<< "--allow_fp16, -f: [0|1], allow running fp32 models with fp16 not\n"
<< "--count, -c: loop interpreter->Invoke() for certain times\n"
<< "--input_mean, -b: input mean\n"
@@ -276,6 +308,7 @@ int Main(int argc, char** argv) {
while (1) {
static struct option long_options[] = {
{"accelerated", required_argument, nullptr, 'a'},
+ {"old_accelerated", required_argument, nullptr, 'd'},
{"allow_fp16", required_argument, nullptr, 'f'},
{"count", required_argument, nullptr, 'c'},
{"verbose", required_argument, nullptr, 'v'},
@@ -292,7 +325,7 @@ int Main(int argc, char** argv) {
/* getopt_long stores the option index here. */
int option_index = 0;
- c = getopt_long(argc, argv, "a:b:c:f:i:l:m:p:r:s:t:v:", long_options,
+ c = getopt_long(argc, argv, "a:b:c:d:f:i:l:m:p:r:s:t:v:", long_options,
&option_index);
/* Detect the end of the options. */
@@ -309,6 +342,9 @@ int Main(int argc, char** argv) {
s.loop_count =
strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
break;
+ case 'd':
+ s.old_accel =
+ strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
case 'f':
s.allow_fp16 =
strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn)
diff --git a/tensorflow/lite/examples/label_image/label_image.h b/tensorflow/lite/examples/label_image/label_image.h
index cc46e56b..ddfb09e2 100644
--- a/tensorflow/lite/examples/label_image/label_image.h
+++ b/tensorflow/lite/examples/label_image/label_image.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H_
#include "tensorflow/lite/string.h"
+#include "tensorflow/lite/model.h"
namespace tflite {
namespace label_image {
@@ -24,6 +25,7 @@ namespace label_image {
struct Settings {
bool verbose = false;
bool accel = false;
+ bool old_accel = false;
bool input_floating = false;
bool profiling = false;
bool allow_fp16 = false;
@@ -31,6 +33,7 @@ struct Settings {
float input_mean = 127.5f;
float input_std = 127.5f;
string model_name = "./mobilenet_quant_v1_224.tflite";
+ tflite::FlatBufferModel* model;
string input_bmp_name = "./grace_hopper.bmp";
string labels_file_name = "./labels.txt";
string input_layer_type = "uint8_t";
diff --git a/tensorflow/lite/tools/evaluation/utils.cc b/tensorflow/lite/tools/evaluation/utils.cc
index 1154953e..c832bbd1 100644
--- a/tensorflow/lite/tools/evaluation/utils.cc
+++ b/tensorflow/lite/tools/evaluation/utils.cc
@@ -80,14 +80,10 @@ TfLiteStatus GetSortedFileNames(const std::string& directory,
}
Interpreter::TfLiteDelegatePtr CreateNNAPIDelegate() {
-#if defined(__ANDROID__)
return Interpreter::TfLiteDelegatePtr(
NnApiDelegate(),
// NnApiDelegate() returns a singleton, so provide a no-op deleter.
[](TfLiteDelegate*) {});
-#else
- return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
-#endif // defined(__ANDROID__)
}
Interpreter::TfLiteDelegatePtr CreateGPUDelegate(