blob: 9d33b03802461baa31eb1ede77d363b6d3bb69df [file] [log] [blame]
// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <memory>
#include <string>
#include <utility>
#include <base/files/file.h>
#include <base/files/file_path.h>
#include <base/files/file_util.h>
#include <base/files/scoped_temp_dir.h>
#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>
#include "ml/benchmark.h"
#include "ml/benchmark.pb.h"
#include "ml/mojom/model.mojom.h"
#include "proto/benchmark_config.pb.h"
namespace ml {
using ::chrome::ml_benchmark::BenchmarkResults;
using ::chrome::ml_benchmark::BenchmarkReturnStatus;
using ::chrome::ml_benchmark::CrOSBenchmarkConfig;
using ::google::protobuf::TextFormat;
// Path of the TFLite model file that the fixture reads and embeds into the
// model spec it writes for the benchmark.
// NOTE(review): the identifier mentions "SmartDim20181115", but the path
// points at the "test_add-20180914" model; the name looks stale.
constexpr char kSmartDim20181115ModelFile[] =
"/opt/google/chrome/ml_models/mlservice-model-test_add-20180914.tflite";
// Text-format model spec; the fixture parses it into a
// FlatBufferModelSpecProto. It declares two required inputs, "x" (index 1)
// and "y" (index 2), and one required output, "z" (index 0), each with
// dims [1].
constexpr char kModelProtoText[] = R"(
required_inputs: {
key: "x"
value: {
index: 1
dims: [1]
}
}
required_inputs: {
key: "y"
value: {
index: 2
dims: [1]
}
}
required_outputs: {
key: "z"
value: {
index: 0
dims: [1]
}
}
)";
// Text-format input/expected-output pair; the fixture parses it into an
// ExpectedInputOutput. The inputs are x = 0.5 and y = 0.25, and the expected
// output is z = 0.75. The fixture overwrites the expected "z" value before
// writing the message to disk.
constexpr char kInputOutputText[] = R"(
input: {
features: {
feature: {
key: "x"
value: {
float_list: { value:[ 0.5 ] }
}
}
feature: {
key: "y"
value: {
float_list: { value:[ 0.25 ] }
}
}
}
}
expected_output:{
features: {
feature: {
key: "z"
value: {
float_list: { value: [ 0.75 ] }
}
}
}
}
)";
class MlBenchmarkTest : public ::testing::Test {
public:
MlBenchmarkTest() {
// Set benchmark_config_;
CHECK(temp_dir_.CreateUniqueTempDir());
const base::FilePath tflite_model_filepath =
temp_dir_.GetPath().Append("model.pb");
input_output_filepath_ = temp_dir_.GetPath().Append("input_output.pb");
TfliteBenchmarkConfig tflite_config;
tflite_config.set_tflite_model_filepath(tflite_model_filepath.value());
tflite_config.set_input_output_filepath(input_output_filepath_.value());
tflite_config.set_num_runs(100);
TextFormat::PrintToString(tflite_config,
benchmark_config_.mutable_driver_config());
// Set FlatBufferModelSpecProto;
FlatBufferModelSpecProto model_proto;
CHECK(TextFormat::ParseFromString(kModelProtoText, &model_proto));
base::ReadFileToString(base::FilePath(kSmartDim20181115ModelFile),
model_proto.mutable_model_string());
const std::string model_content = model_proto.SerializeAsString();
base::WriteFile(tflite_model_filepath, model_content.data(),
model_content.size());
// Set ExpectedInputOutput.
SetExpectedValue(0.75f);
}
MlBenchmarkTest(const MlBenchmarkTest&) = delete;
MlBenchmarkTest& operator=(const MlBenchmarkTest&) = delete;
// Write the output with given expected value.
void SetExpectedValue(const float val) {
ExpectedInputOutput input_output;
CHECK(TextFormat::ParseFromString(kInputOutputText, &input_output));
(*(*input_output.mutable_expected_output()
->mutable_features()
->mutable_feature())["z"]
.mutable_float_list()
->mutable_value())[0] = val;
const std::string input_content = input_output.SerializeAsString();
base::WriteFile(input_output_filepath_, input_content.data(),
input_content.size());
}
protected:
// Temporary directory containing a file used for the file mechanism.
base::ScopedTempDir temp_dir_;
base::FilePath input_output_filepath_;
CrOSBenchmarkConfig benchmark_config_;
};
// Runs the benchmark with the default fixture, whose expected output (0.75)
// equals x + y for the inputs in kInputOutputText, and verifies that the
// reported total_accuracy is approximately zero.
TEST_F(MlBenchmarkTest, TfliteModelMatchedValueTest) {
  // Step 1 run benchmark_start.
  const std::string config = benchmark_config_.SerializeAsString();
  void* results_data = nullptr;
  int results_size = 0;
  EXPECT_EQ(benchmark_start(config.c_str(), config.size(), &results_data,
                            &results_size),
            BenchmarkReturnStatus::OK);

  // Step 2 check results.
  BenchmarkResults results;
  // Parse first, then release the buffer, and only then assert: a parse
  // failure is reported as a test failure (not a process abort) and the
  // buffer is freed on every path.
  const bool parsed = results.ParseFromArray(results_data, results_size);
  free_benchmark_results(results_data);
  ASSERT_TRUE(parsed);

  EXPECT_EQ(results.status(), BenchmarkReturnStatus::OK);
  // Two-sided check so that an unexpected negative value is also rejected.
  EXPECT_NEAR(results.total_accuracy(), 0.0, 1e-5);
}
// Runs the benchmark with an expected output of 0.76 while the inputs in
// kInputOutputText sum to 0.75, and verifies that the reported total_accuracy
// is approximately 0.01.
TEST_F(MlBenchmarkTest, TfliteModelUnmachedValueTest) {
  SetExpectedValue(0.76f);

  // Step 1 run benchmark_start.
  const std::string config = benchmark_config_.SerializeAsString();
  void* results_data = nullptr;
  int results_size = 0;
  EXPECT_EQ(benchmark_start(config.c_str(), config.size(), &results_data,
                            &results_size),
            BenchmarkReturnStatus::OK);

  // Step 2 check results.
  BenchmarkResults results;
  // Parse first, then release the buffer, and only then assert: a parse
  // failure is reported as a test failure (not a process abort) and the
  // buffer is freed on every path.
  const bool parsed = results.ParseFromArray(results_data, results_size);
  free_benchmark_results(results_data);
  ASSERT_TRUE(parsed);

  EXPECT_EQ(results.status(), BenchmarkReturnStatus::OK);
  // Two-sided check: the previous `total_accuracy - 0.01 < 1e-5` form also
  // passed when the reported value was 0, i.e. when the mismatch was not
  // detected at all.
  EXPECT_NEAR(results.total_accuracy(), 0.01, 1e-5);
}
} // namespace ml