blob: 9248e2b216e6c49c08f419df6a40c95f55078666 [file] [log] [blame]
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "federated/example_database.h"
#include <string>
#include <unordered_set>
#include <base/files/file_path.h>
#include <base/files/file_util.h>
#include <base/files/scoped_temp_dir.h>
#include <base/strings/stringprintf.h>
#include <base/time/time.h>
#include <base/optional.h>
#include <gtest/gtest.h>
#include <sqlite3.h>
#include "federated/example_database_test_utils.h"
#include "federated/utils.h"
namespace federated {
namespace {
const std::unordered_set<std::string> kTestClients = {"test_client_1",
"test_client_2"};
} // namespace
class ExampleDatabaseTest : public testing::Test {
public:
ExampleDatabaseTest(const ExampleDatabaseTest&) = delete;
ExampleDatabaseTest& operator=(const ExampleDatabaseTest&) = delete;
ExampleDatabase* get_db() const { return db_; }
const base::FilePath& temp_path() const { return temp_dir_.GetPath(); }
// Prepares a database, table test_client_1 has 100 examples (id from 1
// to 100), table test_client_2 is created by db_->Init() and is empty.
bool CreateExampleDatabaseAndInitialize() {
const base::FilePath db_path =
temp_dir_.GetPath().Append(kDatabaseFileName);
if (CreateDatabaseForTesting(db_path) != SQLITE_OK) {
LOG(ERROR) << "Failed to create database file.";
return false;
}
db_ = new ExampleDatabase(db_path, kTestClients);
if (!db_->Init() || !db_->IsOpen() || !db_->CheckIntegrity()) {
LOG(ERROR) << "Failed to initialize or check integrity of db_.";
return false;
}
return true;
}
protected:
ExampleDatabaseTest() = default;
void SetUp() override { ASSERT_TRUE(temp_dir_.CreateUniqueTempDir()); }
void TearDown() override {
ASSERT_TRUE(temp_dir_.Delete());
if (db_)
delete db_;
}
private:
base::ScopedTempDir temp_dir_;
ExampleDatabase* db_ = nullptr;
};
// This test runs the same steps as CreateExampleDatabaseAndInitialize, but
// checks step by step.
TEST_F(ExampleDatabaseTest, CreateDatabase) {
// Prepares a database file.
base::FilePath db_path = temp_path().Append(kDatabaseFileName);
ASSERT_EQ(CreateDatabaseForTesting(db_path), SQLITE_OK);
EXPECT_TRUE(base::PathExists(db_path));
// Creates the db instance.
ExampleDatabase db(db_path, kTestClients);
// Initializes the db and checks integrity.
EXPECT_TRUE(db.Init());
EXPECT_TRUE(db.IsOpen());
EXPECT_TRUE(db.CheckIntegrity());
// Closes it.
EXPECT_TRUE(db.Close());
}
TEST_F(ExampleDatabaseTest, DatabaseQueryFromNonEmptyTable) {
ASSERT_TRUE(CreateExampleDatabaseAndInitialize());
ExampleDatabase& db = *get_db();
std::string client_name = "test_client_1";
// Table test_client_1 has 100 records, limit=50 can return 50 examples.
int limit = 50;
EXPECT_TRUE(db.PrepareStreamingForClient(client_name, limit));
int64_t count = 0;
while (true) { // id from 1 to 50.
auto maybe_record = db.GetNextStreamedRecord();
if (maybe_record == base::nullopt)
break;
EXPECT_EQ(maybe_record.value().id, count + 1);
EXPECT_EQ(maybe_record.value().serialized_example,
base::StringPrintf("example_%zu", count + 1));
count++;
}
EXPECT_EQ(count, 50);
db.CloseStreaming();
// Limit=150 only returns 100 examples, that's all in the table.
limit = 150;
EXPECT_TRUE(db.PrepareStreamingForClient(client_name, limit));
count = 0;
while (true) { // id from 1 to 100.
auto maybe_record = db.GetNextStreamedRecord();
if (maybe_record == base::nullopt)
break;
EXPECT_EQ(maybe_record.value().id, count + 1);
EXPECT_EQ(maybe_record.value().serialized_example,
base::StringPrintf("example_%zu", count + 1));
count++;
}
EXPECT_EQ(count, 100);
db.CloseStreaming();
EXPECT_TRUE(db.Close());
}
TEST_F(ExampleDatabaseTest, PrepareStreamingFailure) {
ASSERT_TRUE(CreateExampleDatabaseAndInitialize());
ExampleDatabase& db = *get_db();
int limit = 100;
// Table test_client_2 is empty, returns false on PrepareStreamingForClient
std::string client_name = "test_client_2";
EXPECT_FALSE(db.PrepareStreamingForClient(client_name, limit));
// Table test_client_3 doesn't exist, returns false on
// PrepareStreamingForClient.
client_name = "test_client_3";
EXPECT_FALSE(db.PrepareStreamingForClient(client_name, limit));
EXPECT_TRUE(db.Close());
}
// Test that example_database can handle some illegal query operations.
TEST_F(ExampleDatabaseTest, UnexpectedQuery) {
ASSERT_TRUE(CreateExampleDatabaseAndInitialize());
ExampleDatabase& db = *get_db();
// Insert an example to table test_client_2.
ExampleRecord record = {-1 /*placeholder for id*/, "test_client_2",
"manually_inserted_example", base::Time::Now()};
EXPECT_TRUE(db.InsertExample(record));
std::string client_name = "test_client_1";
// Calls GetNextStreamedRecord without PrepareStreamingForClient, gets
// base::nullopt.
EXPECT_EQ(db.GetNextStreamedRecord(), base::nullopt);
// Calls GetNextStreamedRecord after CloseStreaming, gets base::nullopt.
EXPECT_TRUE(db.PrepareStreamingForClient(client_name, 10));
db.CloseStreaming();
EXPECT_EQ(db.GetNextStreamedRecord(), base::nullopt);
// Calls GetNextStreamedRecord after it already returned base::nullopt.
EXPECT_TRUE(db.PrepareStreamingForClient(client_name, 10));
int count = 0;
while (db.GetNextStreamedRecord() != base::nullopt)
count++;
EXPECT_EQ(count, 10);
EXPECT_EQ(db.GetNextStreamedRecord(), base::nullopt);
db.CloseStreaming();
// A subsequent PrepareStreamingForClient call before CloseStreaming() will
// fail and have no influnce on the existing streaming.
EXPECT_TRUE(db.PrepareStreamingForClient("test_client_1", 10));
count = 0;
for (size_t i = 0; i < 5; i++) {
EXPECT_NE(db.GetNextStreamedRecord(), base::nullopt);
count++;
}
EXPECT_FALSE(db.PrepareStreamingForClient("test_client_2", 50));
while (db.GetNextStreamedRecord() != base::nullopt)
count++;
EXPECT_EQ(count, 10);
// Subsequent PrepareStreamingForClient fails as long as the previous
// streaming is not closed by CloseStreaming, even if it already hit the end.
EXPECT_FALSE(db.PrepareStreamingForClient("test_client_2", 50));
db.CloseStreaming();
EXPECT_TRUE(db.PrepareStreamingForClient("test_client_2", 50));
}
TEST_F(ExampleDatabaseTest, InsertExample) {
ASSERT_TRUE(CreateExampleDatabaseAndInitialize());
ExampleDatabase& db = *get_db();
ExampleRecord record;
record.serialized_example = "manually_inserted_example";
record.timestamp = base::Time::Now();
// Before inserting, table test_client_2 is empty, returns false on
// PrepareStreamingForClient.
std::string client_name = "test_client_2";
record.client_name = client_name;
EXPECT_FALSE(db.PrepareStreamingForClient(client_name, 100));
EXPECT_TRUE(db.InsertExample(record));
// After inserting, GetNextStreamedRecord will succeed for 1 time.
EXPECT_TRUE(db.PrepareStreamingForClient(client_name, 100));
int64_t count = 0;
while (db.GetNextStreamedRecord() != base::nullopt)
count++;
EXPECT_EQ(count, 1);
db.CloseStreaming();
// Fails to insert into a non-existing table;
client_name = "test_client_3";
record.client_name = client_name;
EXPECT_FALSE(db.InsertExample(record));
EXPECT_TRUE(db.Close());
}
TEST_F(ExampleDatabaseTest, DeleteExamples) {
ASSERT_TRUE(CreateExampleDatabaseAndInitialize());
ExampleDatabase& db = *get_db();
std::string client_name = "test_client_1";
// Delete examples with id <= 30 from table test_client_1;
EXPECT_TRUE(db.DeleteExamplesWithSmallerIdForClient(client_name, 30));
EXPECT_TRUE(db.PrepareStreamingForClient(client_name, 100));
int64_t count = 0;
while (true) { // id from 31 to 100.
auto maybe_record = db.GetNextStreamedRecord();
if (maybe_record == base::nullopt)
break;
EXPECT_EQ(maybe_record.value().id, count + 31);
EXPECT_EQ(maybe_record.value().serialized_example,
base::StringPrintf("example_%zu", count + 31));
count++;
}
EXPECT_EQ(count, 70);
db.CloseStreaming();
// No examples with id <= 20 now, returns false;
EXPECT_FALSE(db.DeleteExamplesWithSmallerIdForClient(client_name, 20));
client_name = "test_client_2";
// Delete examples from an empty table, returns false;
EXPECT_FALSE(db.DeleteExamplesWithSmallerIdForClient(client_name, 100));
client_name = "test_client_3";
// Delete examples from a non-existing table, returns false;
EXPECT_FALSE(db.DeleteExamplesWithSmallerIdForClient(client_name, 100));
EXPECT_TRUE(db.Close());
}
} // namespace federated