Skip to content

Commit

Permalink
Add allow_creation parameter to SqliteDb::init.
Browse files Browse the repository at this point in the history
  • Loading branch information
levlam committed Sep 22, 2021
1 parent f073c79 commit a0cc1be
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 55 deletions.
5 changes: 2 additions & 3 deletions benchmark/bench_db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,9 @@ class SqliteKVBench final : public td::Benchmark {
td::string path = "testdb.sqlite";
td::SqliteDb::destroy(path).ignore();
if (is_encrypted) {
td::SqliteDb::change_key(path, td::DbKey::password("cucumber"), td::DbKey::empty()).ensure();
db = td::SqliteDb::open_with_key(path, td::DbKey::password("cucumber")).move_as_ok();
db = td::SqliteDb::change_key(path, true, td::DbKey::password("cucumber"), td::DbKey::empty()).move_as_ok();
} else {
db = td::SqliteDb::open_with_key(path, td::DbKey::empty()).move_as_ok();
db = td::SqliteDb::open_with_key(path, true, td::DbKey::empty()).move_as_ok();
}
db.exec("PRAGMA encoding=\"UTF-8\"").ensure();
db.exec("PRAGMA synchronous=NORMAL").ensure();
Expand Down
2 changes: 1 addition & 1 deletion td/telegram/LanguagePackManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ bool LanguagePackManager::is_custom_language_code(Slice language_code) {
}

static Result<SqliteDb> open_database(const string &path) {
TRY_RESULT(database, SqliteDb::open_with_key(path, DbKey::empty()));
TRY_RESULT(database, SqliteDb::open_with_key(path, true, DbKey::empty()));
TRY_STATUS(database.exec("PRAGMA synchronous=NORMAL"));
TRY_STATUS(database.exec("PRAGMA temp_store=MEMORY"));
TRY_STATUS(database.exec("PRAGMA encoding=\"UTF-8\""));
Expand Down
2 changes: 1 addition & 1 deletion td/telegram/TdDb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ Status TdDb::init_sqlite(int32 scheduler_id, const TdParameters &parameters, DbK
}

sqlite_path_ = sql_database_path;
TRY_RESULT(db_instance, SqliteDb::change_key(sqlite_path_, key, old_key));
TRY_RESULT(db_instance, SqliteDb::change_key(sqlite_path_, true, key, old_key));
sql_connection_ = std::make_shared<SqliteConnectionSafe>(sql_database_path, key, db_instance.get_cipher_version());
sql_connection_->set(std::move(db_instance));
auto &db = sql_connection_->get();
Expand Down
9 changes: 2 additions & 7 deletions tddb/td/db/SqliteConnectionSafe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,9 @@ namespace td {
SqliteConnectionSafe::SqliteConnectionSafe(string path, DbKey key, optional<int32> cipher_version)
: path_(std::move(path))
, lsls_connection_([path = path_, key = std::move(key), cipher_version = std::move(cipher_version)] {
auto r_db = SqliteDb::open_with_key(path, key, cipher_version.copy());
auto r_db = SqliteDb::open_with_key(path, false, key, cipher_version.copy());
if (r_db.is_error()) {
auto r_stat = stat(path);
if (r_stat.is_error()) {
LOG(FATAL) << "Can't open database (" << r_stat.error() << "): " << r_db.error().message();
} else {
LOG(FATAL) << "Can't open database of size " << r_stat.ok().size_ << ": " << r_db.error().message();
}
LOG(FATAL) << "Can't open database: " << r_db.error().message();
}
auto db = r_db.move_as_ok();
db.exec("PRAGMA synchronous=NORMAL").ensure();
Expand Down
35 changes: 22 additions & 13 deletions tddb/td/db/SqliteDb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,19 @@ string db_key_to_sqlcipher_key(const DbKey &db_key) {

SqliteDb::~SqliteDb() = default;

Status SqliteDb::init(CSlice path) {
Status SqliteDb::init(CSlice path, bool allow_creation) {
// if database does not exist, delete all other files which could have been left from the old database
bool is_db_exists = stat(path).is_ok();
if (!is_db_exists) {
if (!allow_creation) {
LOG(FATAL) << "Database was deleted during execution and can't be recreated";
}
TRY_STATUS(destroy(path));
}

sqlite3 *db;
CHECK(sqlite3_threadsafe() != 0);
int rc = sqlite3_open_v2(path.c_str(), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE /*| SQLITE_OPEN_SHAREDCACHE*/,
nullptr);
int rc = sqlite3_open_v2(path.c_str(), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr);
if (rc != SQLITE_OK) {
auto res = detail::RawSqliteDb::last_error(db, path);
sqlite3_close(db);
Expand All @@ -95,6 +97,7 @@ Status SqliteDb::init(CSlice path) {
static void trace_callback(void *ptr, const char *query) {
LOG(ERROR) << query;
}

static int trace_v2_callback(unsigned code, void *ctx, void *p_raw, void *x_raw) {
CHECK(code == SQLITE_TRACE_STMT);
auto x = static_cast<const char *>(x_raw);
Expand All @@ -106,6 +109,7 @@ static int trace_v2_callback(unsigned code, void *ctx, void *p_raw, void *x_raw)

return 0;
}

void SqliteDb::trace(bool flag) {
sqlite3_trace_v2(raw_->db(), SQLITE_TRACE_STMT, flag ? trace_v2_callback : nullptr, nullptr);
}
Expand Down Expand Up @@ -136,6 +140,7 @@ Result<bool> SqliteDb::has_table(Slice table) {
auto cnt = stmt.view_int32(0);
return cnt == 1;
}

Result<string> SqliteDb::get_pragma(Slice name) {
TRY_RESULT(stmt, get_statement(PSLICE() << "PRAGMA " << name));
TRY_STATUS(stmt.step());
Expand All @@ -145,6 +150,7 @@ Result<string> SqliteDb::get_pragma(Slice name) {
CHECK(!stmt.can_step());
return std::move(res);
}

Result<string> SqliteDb::get_pragma_string(Slice name) {
TRY_RESULT(stmt, get_statement(PSLICE() << "PRAGMA " << name));
TRY_STATUS(stmt.step());
Expand Down Expand Up @@ -191,17 +197,19 @@ Status SqliteDb::check_encryption() {
return status;
}

Result<SqliteDb> SqliteDb::open_with_key(CSlice path, const DbKey &db_key, optional<int32> cipher_version) {
auto res = do_open_with_key(path, db_key, cipher_version ? cipher_version.value() : 0);
Result<SqliteDb> SqliteDb::open_with_key(CSlice path, bool allow_creation, const DbKey &db_key,
optional<int32> cipher_version) {
auto res = do_open_with_key(path, allow_creation, db_key, cipher_version ? cipher_version.value() : 0);
if (res.is_error() && !cipher_version && !db_key.is_empty()) {
return do_open_with_key(path, db_key, 3);
return do_open_with_key(path, false, db_key, 3);
}
return res;
}

Result<SqliteDb> SqliteDb::do_open_with_key(CSlice path, const DbKey &db_key, int32 cipher_version) {
Result<SqliteDb> SqliteDb::do_open_with_key(CSlice path, bool allow_creation, const DbKey &db_key,
int32 cipher_version) {
SqliteDb db;
TRY_STATUS(db.init(path));
TRY_STATUS(db.init(path, allow_creation));
if (!db_key.is_empty()) {
if (db.check_encryption().is_ok()) {
return Status::Error(PSLICE() << "No key is needed for database \"" << path << '"');
Expand All @@ -226,18 +234,19 @@ optional<int32> SqliteDb::get_cipher_version() const {
return raw_->get_cipher_version();
}

Result<SqliteDb> SqliteDb::change_key(CSlice path, const DbKey &new_db_key, const DbKey &old_db_key) {
Result<SqliteDb> SqliteDb::change_key(CSlice path, bool allow_creation, const DbKey &new_db_key,
const DbKey &old_db_key) {
PerfWarningTimer perf("change key", 0.001);

// fast path
{
auto r_db = open_with_key(path, new_db_key);
auto r_db = open_with_key(path, allow_creation, new_db_key);
if (r_db.is_ok()) {
return r_db;
}
}

TRY_RESULT(db, open_with_key(path, old_db_key));
TRY_RESULT(db, open_with_key(path, false, old_db_key));
TRY_RESULT(user_version, db.user_version());
auto new_key = db_key_to_sqlcipher_key(new_db_key);
if (old_db_key.is_empty() && !new_db_key.is_empty()) {
Expand Down Expand Up @@ -272,8 +281,8 @@ Result<SqliteDb> SqliteDb::change_key(CSlice path, const DbKey &new_db_key, cons
TRY_STATUS(db.exec(PSLICE() << "PRAGMA rekey = " << new_key));
}

TRY_RESULT(new_db, open_with_key(path, new_db_key));
LOG_CHECK(new_db.user_version().ok() == user_version) << new_db.user_version().ok() << " " << user_version;
TRY_RESULT(new_db, open_with_key(path, false, new_db_key));
CHECK(new_db.user_version().ok() == user_version);
return std::move(new_db);
}
Status SqliteDb::destroy(Slice path) {
Expand Down
10 changes: 6 additions & 4 deletions tddb/td/db/SqliteDb.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ class SqliteDb {
static Status destroy(Slice path) TD_WARN_UNUSED_RESULT;

// we can't change the key on the fly, so static functions are more than enough
static Result<SqliteDb> open_with_key(CSlice path, const DbKey &db_key, optional<int32> cipher_version = {});
static Result<SqliteDb> change_key(CSlice path, const DbKey &new_db_key, const DbKey &old_db_key);
static Result<SqliteDb> open_with_key(CSlice path, bool allow_creation, const DbKey &db_key,
optional<int32> cipher_version = {});
static Result<SqliteDb> change_key(CSlice path, bool allow_creation, const DbKey &new_db_key,
const DbKey &old_db_key);

Status last_error();

Expand All @@ -80,10 +82,10 @@ class SqliteDb {
std::shared_ptr<detail::RawSqliteDb> raw_;
bool enable_logging_ = false;

Status init(CSlice path) TD_WARN_UNUSED_RESULT;
Status init(CSlice path, bool allow_creation) TD_WARN_UNUSED_RESULT;

Status check_encryption();
static Result<SqliteDb> do_open_with_key(CSlice path, const DbKey &db_key, int32 cipher_version);
static Result<SqliteDb> do_open_with_key(CSlice path, bool allow_creation, const DbKey &db_key, int32 cipher_version);
void set_cipher_version(int32 cipher_version);
};

Expand Down
56 changes: 30 additions & 26 deletions test/db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,10 @@ TEST(DB, binlog_encryption) {
TEST(DB, sqlite_lfs) {
string path = "test_sqlite_db";
SqliteDb::destroy(path).ignore();
auto db = SqliteDb::open_with_key(path, DbKey::empty()).move_as_ok();
auto db = SqliteDb::open_with_key(path, true, DbKey::empty()).move_as_ok();
db.exec("PRAGMA journal_mode=WAL").ensure();
db.exec("PRAGMA user_version").ensure();
SqliteDb::destroy(path).ignore();
}

TEST(DB, sqlite_encryption) {
Expand All @@ -151,48 +152,49 @@ TEST(DB, sqlite_encryption) {
auto tomato = DbKey::raw_key(string(32, 'a'));

{
auto db = SqliteDb::open_with_key(path, empty).move_as_ok();
auto db = SqliteDb::open_with_key(path, true, empty).move_as_ok();
db.set_user_version(123).ensure();
auto kv = SqliteKeyValue();
kv.init_with_connection(db.clone(), "kv").ensure();
kv.set("a", "b");
}
SqliteDb::open_with_key(path, cucumber).ensure_error(); // key was set...
SqliteDb::open_with_key(path, false, cucumber).ensure_error();

SqliteDb::change_key(path, cucumber, empty).ensure();
SqliteDb::change_key(path, false, cucumber, empty).ensure();

SqliteDb::open_with_key(path, tomato).ensure_error();
SqliteDb::open_with_key(path, false, tomato).ensure_error();
{
auto db = SqliteDb::open_with_key(path, cucumber).move_as_ok();
auto db = SqliteDb::open_with_key(path, false, cucumber).move_as_ok();
auto kv = SqliteKeyValue();
kv.init_with_connection(db.clone(), "kv").ensure();
CHECK(kv.get("a") == "b");
CHECK(db.user_version().ok() == 123);
}

SqliteDb::change_key(path, tomato, cucumber).ensure();
SqliteDb::change_key(path, tomato, cucumber).ensure();
SqliteDb::change_key(path, false, tomato, cucumber).ensure();
SqliteDb::change_key(path, false, tomato, cucumber).ensure();

SqliteDb::open_with_key(path, cucumber).ensure_error();
SqliteDb::open_with_key(path, false, cucumber).ensure_error();
{
auto db = SqliteDb::open_with_key(path, tomato).move_as_ok();
auto db = SqliteDb::open_with_key(path, false, tomato).move_as_ok();
auto kv = SqliteKeyValue();
kv.init_with_connection(db.clone(), "kv").ensure();
CHECK(kv.get("a") == "b");
CHECK(db.user_version().ok() == 123);
}

SqliteDb::change_key(path, empty, tomato).ensure();
SqliteDb::change_key(path, empty, tomato).ensure();
SqliteDb::change_key(path, false, empty, tomato).ensure();
SqliteDb::change_key(path, false, empty, tomato).ensure();

{
auto db = SqliteDb::open_with_key(path, empty).move_as_ok();
auto db = SqliteDb::open_with_key(path, false, empty).move_as_ok();
auto kv = SqliteKeyValue();
kv.init_with_connection(db.clone(), "kv").ensure();
CHECK(kv.get("a") == "b");
CHECK(db.user_version().ok() == 123);
}
SqliteDb::open_with_key(path, cucumber).ensure_error();
SqliteDb::open_with_key(path, false, cucumber).ensure_error();
SqliteDb::destroy(path).ignore();
}

TEST(DB, sqlite_encryption_migrate_v3) {
Expand All @@ -203,8 +205,7 @@ TEST(DB, sqlite_encryption_migrate_v3) {
if (false) {
// sqlite_sample_db was generated by the following code using SQLCipher based on SQLite 3.15.2
{
SqliteDb::change_key(path, cucumber, empty).ensure();
auto db = SqliteDb::open_with_key(path, cucumber).move_as_ok();
auto db = SqliteDb::change_key(path, true, cucumber, empty).move_as_ok();
db.set_user_version(123).ensure();
auto kv = SqliteKeyValue();
kv.init_with_connection(db.clone(), "kv").ensure();
Expand All @@ -214,12 +215,13 @@ TEST(DB, sqlite_encryption_migrate_v3) {
}
write_file(path, base64_decode(Slice(sqlite_sample_db_v3, sqlite_sample_db_v3_size)).move_as_ok()).ensure();
{
auto db = SqliteDb::open_with_key(path, cucumber).move_as_ok();
auto db = SqliteDb::open_with_key(path, true, cucumber).move_as_ok();
auto kv = SqliteKeyValue();
kv.init_with_connection(db.clone(), "kv").ensure();
CHECK(kv.get("hello") == "world");
CHECK(db.user_version().ok() == 123);
}
SqliteDb::destroy(path).ignore();
}

TEST(DB, sqlite_encryption_migrate_v4) {
Expand All @@ -230,8 +232,7 @@ TEST(DB, sqlite_encryption_migrate_v4) {
if (false) {
// sqlite_sample_db was generated by the following code using SQLCipher 4.4.0
{
SqliteDb::change_key(path, cucumber, empty).ensure();
auto db = SqliteDb::open_with_key(path, cucumber).move_as_ok();
auto db = SqliteDb::change_key(path, true, cucumber, empty).move_as_ok();
db.set_user_version(123).ensure();
auto kv = SqliteKeyValue();
kv.init_with_connection(db.clone(), "kv").ensure();
Expand All @@ -241,7 +242,7 @@ TEST(DB, sqlite_encryption_migrate_v4) {
}
write_file(path, base64_decode(Slice(sqlite_sample_db_v4, sqlite_sample_db_v4_size)).move_as_ok()).ensure();
{
auto r_db = SqliteDb::open_with_key(path, cucumber);
auto r_db = SqliteDb::open_with_key(path, true, cucumber);
if (r_db.is_error()) {
LOG(ERROR) << r_db.error();
return;
Expand All @@ -256,6 +257,7 @@ TEST(DB, sqlite_encryption_migrate_v4) {
CHECK(db.user_version().ok() == 123);
}
}
SqliteDb::destroy(path).ignore();
}

using SeqNo = uint64;
Expand Down Expand Up @@ -377,9 +379,9 @@ TEST(DB, key_value) {
new_kv.impl().init(new_kv_name.str()).ensure();

QueryHandler<SqliteKeyValue> sqlite_kv;
CSlice name = "test_sqlite_kv";
SqliteDb::destroy(name).ignore();
auto db = SqliteDb::open_with_key(name, DbKey::empty()).move_as_ok();
CSlice path = "test_sqlite_kv";
SqliteDb::destroy(path).ignore();
auto db = SqliteDb::open_with_key(path, true, DbKey::empty()).move_as_ok();
sqlite_kv.impl().init_with_connection(std::move(db), "KV").ensure();

int cnt = 0;
Expand All @@ -402,6 +404,7 @@ TEST(DB, key_value) {
new_kv.impl().init(new_kv_name.str()).ensure();
}
}
SqliteDb::destroy(path).ignore();
}

#if !TD_THREAD_UNSUPPORTED
Expand Down Expand Up @@ -516,9 +519,9 @@ TEST(DB, persistent_key_value) {
SET_VERBOSITY_LEVEL(VERBOSITY_NAME(ERROR));
std::vector<std::string> keys;
std::vector<std::string> values;
CSlice name = "test_pmc";
Binlog::destroy(name).ignore();
SqliteDb::destroy(name).ignore();
CSlice path = "test_pmc";
Binlog::destroy(path).ignore();
SqliteDb::destroy(path).ignore();

for (int i = 0; i < 100; i++) {
keys.push_back(rand_string('a', 'b', Random::fast(1, 10)));
Expand Down Expand Up @@ -675,4 +678,5 @@ TEST(DB, persistent_key_value) {
pos[best]++;
}
}
SqliteDb::destroy(path).ignore();
}

0 comments on commit a0cc1be

Please sign in to comment.