| // Copyright 2021 The ChromiumOS Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "federated/example_database.h" |
| |
| #include <cinttypes> |
| #include <string> |
| #include <unordered_set> |
| |
| #include <base/files/file_path.h> |
| #include <base/logging.h> |
| #include <base/strings/string_number_conversions.h> |
| #include <base/strings/string_util.h> |
| #include <base/strings/stringprintf.h> |
| #include <bits/stdint-intn.h> |
| #include <sqlite3.h> |
| |
| #include "federated/utils.h" |
| |
| namespace federated { |
| |
| namespace { |
| |
| // Populates SQL where clause with given start and end time if they are not zero |
| // values, returns an empty string (i.e. no-op) otherwise. |
| std::string MaybeWhereClause(const base::Time& start_time, |
| const base::Time& end_time) { |
| if (start_time == base::Time() && end_time == base::Time()) |
| return std::string(); |
| |
| DCHECK(start_time <= end_time) |
| << "Invalid time range: start_time must <= end_time"; |
| DCHECK(start_time > base::Time::UnixEpoch()) |
| << "Invalid time range: start_time must > UnixEpoch()"; |
| return base::StringPrintf("WHERE timestamp>=%" PRId64 |
| " AND timestamp<=%" PRId64, |
| start_time.ToJavaTime(), end_time.ToJavaTime()); |
| } |
| |
| // Used in ExampleCount to return number of examples. |
| int ExampleCountCallback(void* const /* int* const */ data, |
| const int col_count, |
| char** const cols, |
| char** const /* names */) { |
| CHECK(data != nullptr); |
| CHECK(cols != nullptr); |
| |
| int example_count = 0; |
| if (col_count != 1 || cols[0] == nullptr || |
| !base::StringToInt(cols[0], &example_count)) { |
| LOG(ERROR) << "Invalid example count results"; |
| return SQLITE_ERROR; |
| } |
| |
| auto* const output = static_cast<int*>(data); |
| *output = example_count; |
| return SQLITE_OK; |
| } |
| |
| // Used in CheckIntegrity to extract state code and result string from SQL |
| // exec. |
| int IntegrityCheckCallback(void* const /* std::string* const */ data, |
| int const col_count, |
| char** const cols, |
| char** const /* names */) { |
| CHECK(data != nullptr); |
| CHECK(cols != nullptr); |
| |
| if (col_count != 1 || cols[0] == nullptr) { |
| LOG(ERROR) << "Invalid integrity check results"; |
| return SQLITE_ERROR; |
| } |
| |
| auto* const integrity_result = static_cast<std::string*>(data); |
| integrity_result->assign(cols[0]); |
| return SQLITE_OK; |
| } |
| |
| // Used in ClientTableExists to extract state code and table_count from SQL |
| // exec. |
| int ClientTableExistsCallback(void* const /* int* const */ data, |
| const int col_count, |
| char** const cols, |
| char** const /* names */) { |
| CHECK(data != nullptr); |
| CHECK(cols != nullptr); |
| |
| auto* const table_count = static_cast<int*>(data); |
| if (col_count != 1 || cols[0] == nullptr || |
| !base::StringToInt(cols[0], table_count)) { |
| LOG(ERROR) << "Table existence check failed"; |
| return SQLITE_ERROR; |
| } |
| return SQLITE_OK; |
| } |
| |
| } // namespace |
| |
| ExampleDatabase::Iterator::Iterator() : stmt_(nullptr) {} |
| |
| ExampleDatabase::Iterator::Iterator(sqlite3* const db, |
| const std::string& client_name, |
| const base::Time& start_time, |
| const base::Time& end_time) { |
| if (db == nullptr) { |
| stmt_ = nullptr; |
| return; |
| } |
| |
| const std::string sql_code = base::StringPrintf( |
| "SELECT id, example, timestamp FROM %s %s ORDER BY id;", |
| client_name.c_str(), MaybeWhereClause(start_time, end_time).c_str()); |
| const int result = |
| sqlite3_prepare_v2(db, sql_code.c_str(), -1, &stmt_, nullptr); |
| |
| if (result != SQLITE_OK) { |
| LOG(ERROR) << "Couldn't compile iteration statement: " |
| << sqlite3_errmsg(db); |
| Close(); |
| } |
| } |
| |
| ExampleDatabase::Iterator::Iterator(sqlite3* const db, |
| const std::string& client_name) |
| : ExampleDatabase::Iterator::Iterator( |
| db, client_name, base::Time(), base::Time()) {} |
| |
| ExampleDatabase::Iterator::Iterator(ExampleDatabase::Iterator&& o) |
| : stmt_(o.stmt_) { |
| o.stmt_ = nullptr; |
| } |
| |
| ExampleDatabase::Iterator& ExampleDatabase::Iterator::operator=( |
| Iterator&& other) { |
| if (stmt_ != nullptr) { |
| Close(); |
| } |
| stmt_ = other.stmt_; |
| other.stmt_ = nullptr; |
| |
| return *this; |
| } |
| |
| ExampleDatabase::Iterator::~Iterator() { |
| Close(); |
| } |
| |
| absl::StatusOr<ExampleRecord> ExampleDatabase::Iterator::Next() { |
| if (stmt_ == nullptr) { |
| return absl::InvalidArgumentError("Invalid sqlite3 statment"); |
| } |
| |
| // Execute retrieval step. |
| const int code = sqlite3_step(stmt_); |
| if (code == SQLITE_DONE) { |
| Close(); |
| return absl::OutOfRangeError("End of iterator reached"); |
| } |
| if (code != SQLITE_ROW) { |
| Close(); |
| return absl::InvalidArgumentError("Couldn't retrieve next example"); |
| } |
| |
| // Extract step results. |
| const int64_t id = sqlite3_column_int64(stmt_, 0); |
| const unsigned char* const example_buffer = |
| static_cast<const unsigned char*>(sqlite3_column_blob(stmt_, 1)); |
| const int example_buffer_len = sqlite3_column_bytes(stmt_, 1); |
| const int64_t java_ts = sqlite3_column_int64(stmt_, 2); |
| if (id <= 0 || example_buffer == nullptr || example_buffer_len <= 0 || |
| java_ts < 0) { |
| Close(); |
| return absl::InvalidArgumentError("Failed to extract example"); |
| } |
| |
| // Populate output struct. |
| ExampleRecord example_record; |
| example_record.id = id; |
| example_record.serialized_example = |
| std::string(example_buffer, example_buffer + example_buffer_len); |
| example_record.timestamp = base::Time::FromJavaTime(java_ts); |
| return example_record; |
| } |
| |
| void ExampleDatabase::Iterator::Close() { |
| sqlite3_finalize(stmt_); |
| stmt_ = nullptr; |
| } |
| |
| ExampleDatabase::ExampleDatabase(const base::FilePath& db_path) |
| : db_path_(db_path), db_(nullptr, nullptr) { |
| // Checks that sqlite3 is compiled with threading mode = Serialized, so that |
| // the db connection can be used in multiple threads. See |
| // https://www.sqlite.org/threadsafe.html for more information. |
| DCHECK_EQ(1, sqlite3_threadsafe()); |
| } |
| |
| ExampleDatabase::~ExampleDatabase() { |
| Close(); |
| } |
| |
| bool ExampleDatabase::Init(const std::unordered_set<std::string>& clients) { |
| // SQLITE_OPEN_FULLMUTEX means sqlite3 threadding mode = serialized so that |
| // it's safe to access the same database connection in multiple threads. |
| // This is the default when `sqlite3_threadsafe() == 1` but no harm to make a |
| // double insurance with this flag. |
| sqlite3* db_ptr; |
| const int result = sqlite3_open_v2( |
| db_path_.MaybeAsASCII().c_str(), &db_ptr, |
| SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_FULLMUTEX, NULL); |
| db_ = std::unique_ptr<sqlite3, decltype(&sqlite3_close)>(db_ptr, |
| &sqlite3_close); |
| if (result != SQLITE_OK) { |
| LOG(ERROR) << "Failed to connect to database: " |
| << sqlite3_errmsg(db_.get()); |
| db_ = nullptr; |
| return false; |
| } |
| |
| for (const auto& client : clients) { |
| if ((!ClientTableExists(client) && !CreateClientTable(client))) { |
| LOG(ERROR) << "Failed to prepare table for client " << client; |
| Close(); |
| |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| bool ExampleDatabase::IsOpen() const { |
| return db_.get() != nullptr; |
| } |
| |
| bool ExampleDatabase::Close() { |
| if (!IsOpen()) { |
| return true; |
| } |
| |
| // If the database is successfully closed, db_ pointer must be released. |
| // Otherwise sqlite3_close will be called again on already released db_ |
| // pointer by the destructor, which will result in undefined behavior. |
| int result = sqlite3_close(db_.get()); |
| if (result != SQLITE_OK) { |
| // This should never happen |
| LOG(ERROR) << "Failed to close database: " << sqlite3_errmsg(db_.get()); |
| return false; |
| } |
| |
| db_.release(); |
| return true; |
| } |
| |
| bool ExampleDatabase::CheckIntegrity() const { |
| if (!IsOpen()) { |
| LOG(ERROR) << "Trying to check integrity of a closed database"; |
| return false; |
| } |
| |
| // Integrity_check(N) returns a single row and a single column with string |
| // "ok" if there is no error. Otherwise a maximum of N rows are returned |
| // with each row representing a single error. |
| std::string integrity_result; |
| ExecResult result = ExecSql("PRAGMA integrity_check(1)", |
| IntegrityCheckCallback, &integrity_result); |
| if (result.code != SQLITE_OK) { |
| LOG(ERROR) << "Failed to check integrity: " << result.error_msg; |
| return false; |
| } |
| |
| return integrity_result == "ok"; |
| } |
| |
| ExampleDatabase::Iterator ExampleDatabase::GetIterator( |
| const std::string& client_name, |
| const base::Time& start_time, |
| const base::Time& end_time) const { |
| return Iterator(db_.get(), client_name, start_time, end_time); |
| } |
| |
| ExampleDatabase::Iterator ExampleDatabase::GetIteratorForTesting( |
| const std::string& client_name) const { |
| return Iterator(db_.get(), client_name); |
| } |
| |
| bool ExampleDatabase::InsertExample(const std::string& client_name, |
| const ExampleRecord& example_record) { |
| if (!IsOpen()) { |
| LOG(ERROR) << "Trying to insert example into a closed database"; |
| return false; |
| } |
| |
| // Compile the insertion statement. |
| sqlite3_stmt* stmt = nullptr; |
| const std::string sql_code = |
| base::StringPrintf("INSERT INTO %s (example, timestamp) VALUES (?, ?);", |
| client_name.c_str()); |
| const int result = |
| sqlite3_prepare_v2(db_.get(), sql_code.c_str(), -1, &stmt, nullptr); |
| if (result != SQLITE_OK) { |
| LOG(ERROR) << "Couldn't compile insertion statement: " |
| << sqlite3_errmsg(db_.get()); |
| return false; |
| } |
| |
| // Run the insertion statement. |
| const bool ok = |
| sqlite3_bind_blob(stmt, 1, example_record.serialized_example.c_str(), |
| example_record.serialized_example.length(), |
| nullptr) == SQLITE_OK && |
| sqlite3_bind_int64(stmt, 2, example_record.timestamp.ToJavaTime()) == |
| SQLITE_OK && |
| sqlite3_step(stmt) == SQLITE_DONE; |
| sqlite3_finalize(stmt); |
| |
| if (!ok) { |
| LOG(ERROR) << "Failed to insert example: " << sqlite3_errmsg(db_.get()); |
| } |
| DVLOG(1) << "Insert example for client " << client_name; |
| return ok; |
| } |
| |
| void ExampleDatabase::DeleteAllExamples(const std::string& client_name) { |
| if (!IsOpen()) { |
| LOG(ERROR) << "Trying to delete from a closed database"; |
| return; |
| } |
| |
| const ExecResult result = |
| ExecSql(base::StringPrintf("DELETE FROM %s;", client_name.c_str())); |
| if (result.code != SQLITE_OK) { |
| LOG(ERROR) << "Failed to delete examples: " << result.error_msg; |
| } |
| } |
| |
| bool ExampleDatabase::ClientTableExists(const std::string& client_name) const { |
| if (!IsOpen()) { |
| LOG(ERROR) << "Trying to query table of a closed database"; |
| return false; |
| } |
| |
| int table_count = 0; |
| const std::string sql_code = base::StringPrintf( |
| "SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name = " |
| "'%s';", |
| client_name.c_str()); |
| ExecResult result = |
| ExecSql(sql_code, ClientTableExistsCallback, &table_count); |
| |
| if (result.code != SQLITE_OK) { |
| LOG(ERROR) << "Failed to query table existence: " << result.error_msg; |
| return false; |
| } |
| |
| if (table_count <= 0) |
| return false; |
| |
| DCHECK(table_count == 1) << "There should be only one table with name '" |
| << client_name << "'"; |
| |
| return true; |
| } |
| |
| bool ExampleDatabase::CreateClientTable(const std::string& client_name) { |
| if (!IsOpen()) { |
| LOG(ERROR) << "Trying to create table in a closed database"; |
| return false; |
| } |
| |
| const std::string sql = base::StringPrintf(R"( |
| CREATE TABLE %s ( |
| id INTEGER PRIMARY KEY AUTOINCREMENT |
| NOT NULL, |
| example BLOB NOT NULL, |
| timestamp INTEGER NOT NULL |
| ))", |
| client_name.c_str()); |
| const ExecResult result = ExecSql(sql); |
| if (result.code != SQLITE_OK) { |
| LOG(ERROR) << "Failed to create table: " << result.error_msg; |
| return false; |
| } |
| return true; |
| } |
| |
| int ExampleDatabase::ExampleCount(const std::string& client_name, |
| const base::Time& start_time, |
| const base::Time& end_time) const { |
| DCHECK(start_time != base::Time() && end_time != base::Time()) |
| << "start_time and end_time cannot be zero values"; |
| return ExampleCountInternal(client_name, |
| MaybeWhereClause(start_time, end_time)); |
| } |
| |
| int ExampleDatabase::ExampleCountForTesting( |
| const std::string& client_name) const { |
| return ExampleCountInternal(client_name, /* where_clause = */ std::string()); |
| } |
| |
| int ExampleDatabase::ExampleCountInternal( |
| const std::string& client_name, const std::string& where_clause) const { |
| if (!IsOpen()) { |
| LOG(ERROR) << "Trying to count examples in a closed database"; |
| return 0; |
| } |
| |
| int count = 0; |
| const ExecResult result = |
| ExecSql(base::StringPrintf("SELECT COUNT(*) FROM %s %s;", |
| client_name.c_str(), where_clause.c_str()), |
| &ExampleCountCallback, &count); |
| |
| if (result.code != SQLITE_OK) { |
| LOG(ERROR) << "Failed to count examples: " << result.error_msg; |
| return 0; |
| } |
| return count; |
| } |
| |
| ExampleDatabase::ExecResult ExampleDatabase::ExecSql( |
| const std::string& sql) const { |
| return ExecSql(sql, nullptr, nullptr); |
| } |
| |
| ExampleDatabase::ExecResult ExampleDatabase::ExecSql(const std::string& sql, |
| SqliteCallback callback, |
| void* const data) const { |
| char* error_msg = nullptr; |
| const int result = |
| sqlite3_exec(db_.get(), sql.c_str(), callback, data, &error_msg); |
| // According to sqlite3_exec() documentation, error_msg points to memory |
| // allocated by sqlite3_malloc(), which must be freed by sqlite3_free(). |
| std::string error_msg_str; |
| if (error_msg) { |
| error_msg_str.assign(error_msg); |
| sqlite3_free(error_msg); |
| } |
| return {result, error_msg_str}; |
| } |
| |
| } // namespace federated |