blob: bc0b4496591e90220194d024a417bd24cd39f78c [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <base/bind.h>
#include <base/containers/flat_map.h>
#include <base/macros.h>
#include <base/optional.h>
#include <base/run_loop.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <mojo/public/cpp/bindings/remote.h>
#include <tensorflow/lite/model.h>

#include "ml/model_impl.h"
#include "ml/mojom/graph_executor.mojom.h"
#include "ml/mojom/model.mojom.h"
#include "ml/tensor_view.h"
#include "ml/test_utils.h"
namespace ml {
namespace {
using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult;
using ::chromeos::machine_learning::mojom::ExecuteResult;
using ::chromeos::machine_learning::mojom::GraphExecutor;
using ::chromeos::machine_learning::mojom::Model;
using ::chromeos::machine_learning::mojom::TensorPtr;
using ::testing::ElementsAre;
using ::testing::Eq;
// Fixture shared by the ModelImpl tests below.
//
// It describes the bundled example model: a simple graph that adds up two
// tensors, where inputs and outputs are 1x1 float tensors. The name -> int
// maps are handed to ModelImpl::Create; the ints are presumably the tensor
// indices inside the TF Lite graph (TODO confirm against the model file).
class ModelImplTest : public testing::Test {
 protected:
  // Path of the example .tflite file inside the test model directory.
  const std::string model_path_{GetTestModelDir() +
                                "mlservice-model-test_add-20180914.tflite"};
  // Graph input names ("x", "y") and their associated integers.
  const std::map<std::string, int> model_inputs_{{"x", 1}, {"y", 2}};
  // Graph output name ("z") and its associated integer.
  const std::map<std::string, int> model_outputs_{{"z", 0}};
};
// Tests that AlignedModelData yields 4-byte-aligned data even when the source
// string's .c_str() may be unaligned (as can happen for short strings).
TEST(AlignedModelData, MaybeUnalignedInput) {
  // Short strings can have unaligned .c_str() because they are stored directly
  // inside the string struct rather than on the heap.
  const std::string test_str = "short string";
  std::string maybe_unaligned_str = test_str;
  // Note: Whether `maybe_unaligned_str` *actually* has unaligned .c_str()
  // depends on the particular impl of std::string. At the time of writing, it
  // is indeed unaligned on e.g. amd64-generic.
  const AlignedModelData aligned_model_data(std::move(maybe_unaligned_str));

  // The .data() should now be aligned to 4 bytes. The remainder is unsigned,
  // so compare it against an unsigned literal rather than a signed int to
  // avoid a mixed-signedness comparison inside the matcher.
  const std::uintptr_t data_address =
      reinterpret_cast<std::uintptr_t>(aligned_model_data.data());
  EXPECT_THAT(data_address % 4, Eq(0u));

  // The contents agree with the original string.
  EXPECT_TRUE(
      std::equal(test_str.begin(), test_str.end(), aligned_model_data.data()));
}
// Tests that a ModelImpl created without a TF Lite model still binds its
// remote, but refuses to create a graph executor.
TEST_F(ModelImplTest, TestBadModel) {
  // Pass nullptr instead of a valid model.
  mojo::Remote<Model> model;
  ModelImpl::Create(model_inputs_, model_outputs_, nullptr /*model*/,
                    model.BindNewPipeAndPassReceiver(), "TestModel");
  ASSERT_TRUE(model.is_bound());

  // Ensure that creating a graph executor fails with an interpretation error.
  bool callback_done = false;
  mojo::Remote<GraphExecutor> graph_executor;
  model->CreateGraphExecutor(
      graph_executor.BindNewPipeAndPassReceiver(),
      // Mojo reply callbacks are single-shot (base::OnceCallback), so bind
      // with BindOnce instead of the deprecated repeating base::Bind.
      base::BindOnce(
          [](bool* callback_done, const CreateGraphExecutorResult result) {
            EXPECT_EQ(result,
                      CreateGraphExecutorResult::MODEL_INTERPRETATION_ERROR);
            *callback_done = true;
          },
          &callback_done));
  // Run pending tasks so the reply has been delivered before it is checked.
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(callback_done);
}
// Tests loading and executing the valid example model: feeding x = 0.5 and
// y = 0.25 must produce a single 1-element output z = 0.75.
TEST_F(ModelImplTest, TestExampleModel) {
  // Read the example TF model from disk.
  std::unique_ptr<tflite::FlatBufferModel> tflite_model =
      tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
  ASSERT_NE(tflite_model.get(), nullptr);

  // Create model object.
  mojo::Remote<Model> model;
  ModelImpl::Create(model_inputs_, model_outputs_, std::move(tflite_model),
                    model.BindNewPipeAndPassReceiver(), "TestModel");
  ASSERT_TRUE(model.is_bound());

  // Create a graph executor.
  bool cge_callback_done = false;
  mojo::Remote<GraphExecutor> graph_executor;
  model->CreateGraphExecutor(
      graph_executor.BindNewPipeAndPassReceiver(),
      base::BindOnce(
          [](bool* cge_callback_done, const CreateGraphExecutorResult result) {
            EXPECT_EQ(result, CreateGraphExecutorResult::OK);
            *cge_callback_done = true;
          },
          &cge_callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(cge_callback_done);

  // Construct input/output for graph execution.
  base::flat_map<std::string, TensorPtr> inputs;
  inputs.emplace("x", NewTensor<double>({1}, {0.5}));
  inputs.emplace("y", NewTensor<double>({1}, {0.25}));
  std::vector<std::string> outputs({"z"});

  // Execute graph.
  bool exe_callback_done = false;
  graph_executor->Execute(
      std::move(inputs), std::move(outputs),
      base::BindOnce(
          [](bool* exe_callback_done, const ExecuteResult result,
             base::Optional<std::vector<TensorPtr>> output_tensors) {
            // A failing ASSERT_* returns early from this lambda, leaving
            // *exe_callback_done false, so the outer ASSERT_TRUE fails too.

            // Check that the inference succeeded and gives the expected number
            // of outputs.
            EXPECT_EQ(result, ExecuteResult::OK);
            ASSERT_TRUE(output_tensors.has_value());
            // size() is unsigned; compare against an unsigned literal.
            ASSERT_EQ(output_tensors->size(), 1u);

            // Check that the output tensor has the right type and format.
            const TensorView<double> out_tensor((*output_tensors)[0]);
            EXPECT_TRUE(out_tensor.IsValidType());
            EXPECT_TRUE(out_tensor.IsValidFormat());

            // Check the output tensor has the expected shape and values.
            EXPECT_THAT(out_tensor.GetShape(), ElementsAre(1));
            EXPECT_THAT(out_tensor.GetValues(), ElementsAre(0.75));
            *exe_callback_done = true;
          },
          &exe_callback_done));
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(exe_callback_done);
}
// Tests that ModelImpl's graph-executor count grows with each created executor
// and shrinks again as each executor's remote is reset.
TEST_F(ModelImplTest, TestGraphExecutorCleanup) {
  // Read the example TF model from disk.
  std::unique_ptr<tflite::FlatBufferModel> tflite_model =
      tflite::FlatBufferModel::BuildFromFile(model_path_.c_str());
  ASSERT_NE(tflite_model.get(), nullptr);

  // Create model object. The test only uses the returned pointer to inspect
  // the number of live graph executors.
  mojo::Remote<Model> model;
  const ModelImpl* model_impl =
      ModelImpl::Create(model_inputs_, model_outputs_, std::move(tflite_model),
                        model.BindNewPipeAndPassReceiver(), "TestModel");
  ASSERT_TRUE(model.is_bound());

  // Binds `graph_executor` to a new executor created through `model`, runs
  // pending tasks, and returns whether the creation callback was invoked.
  const auto create_graph_executor =
      [&model](mojo::Remote<GraphExecutor>* graph_executor) {
        bool done = false;
        model->CreateGraphExecutor(
            graph_executor->BindNewPipeAndPassReceiver(),
            base::BindOnce(
                [](bool* callback_done,
                   const CreateGraphExecutorResult result) {
                  EXPECT_EQ(result, CreateGraphExecutorResult::OK);
                  *callback_done = true;
                },
                &done));
        base::RunLoop().RunUntilIdle();
        return done;
      };

  // Create one graph executor.
  mojo::Remote<GraphExecutor> graph_executor_1;
  ASSERT_TRUE(create_graph_executor(&graph_executor_1));
  ASSERT_TRUE(graph_executor_1.is_bound());
  ASSERT_EQ(model_impl->num_graph_executors_for_testing(), 1);

  // Create another graph executor.
  mojo::Remote<GraphExecutor> graph_executor_2;
  ASSERT_TRUE(create_graph_executor(&graph_executor_2));
  ASSERT_TRUE(graph_executor_2.is_bound());
  ASSERT_EQ(model_impl->num_graph_executors_for_testing(), 2);

  // Destroy one graph executor; the other must remain bound.
  graph_executor_1.reset();
  base::RunLoop().RunUntilIdle();
  ASSERT_TRUE(graph_executor_2.is_bound());
  ASSERT_EQ(model_impl->num_graph_executors_for_testing(), 1);

  // Destroy the other graph executor.
  graph_executor_2.reset();
  base::RunLoop().RunUntilIdle();
  ASSERT_EQ(model_impl->num_graph_executors_for_testing(), 0);
}
} // namespace
} // namespace ml