Skip to content

Commit

Permalink
TdDb: reuse derived sqlcipher version
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 28a94a5dee34f47850deb5cf1ca41e92e24bf648
  • Loading branch information
arseny30 committed Aug 14, 2020
1 parent 45bfb1b commit 28596f1
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 15 deletions.
2 changes: 1 addition & 1 deletion td/telegram/TdDb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ Status TdDb::init_sqlite(int32 scheduler_id, const TdParameters &parameters, DbK

sqlite_path_ = sql_database_path;
TRY_RESULT(db_instance, SqliteDb::change_key(sqlite_path_, key, old_key));
sql_connection_ = std::make_shared<SqliteConnectionSafe>(sql_database_path, key);
sql_connection_ = std::make_shared<SqliteConnectionSafe>(sql_database_path, key, db_instance.get_cipher_version());
sql_connection_->set(std::move(db_instance));
auto &db = sql_connection_->get();

Expand Down
7 changes: 4 additions & 3 deletions tddb/td/db/SqliteConnectionSafe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@

namespace td {

SqliteConnectionSafe::SqliteConnectionSafe(string path, DbKey key)
: path_(std::move(path)), lsls_connection_([path = path_, key = std::move(key)] {
auto r_db = SqliteDb::open_with_key(path, key);
SqliteConnectionSafe::SqliteConnectionSafe(string path, DbKey key, optional<int32> cipher_version)
: path_(std::move(path))
, lsls_connection_([path = path_, key = std::move(key), cipher_version = std::move(cipher_version)] {
auto r_db = SqliteDb::open_with_key(path, key, cipher_version.copy());
if (r_db.is_error()) {
auto r_stat = stat(path);
if (r_stat.is_error()) {
Expand Down
3 changes: 2 additions & 1 deletion tddb/td/db/SqliteConnectionSafe.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
#include "td/db/SqliteDb.h"

#include "td/utils/common.h"
#include "td/utils/optional.h"

namespace td {

class SqliteConnectionSafe {
public:
SqliteConnectionSafe() = default;
explicit SqliteConnectionSafe(string path, DbKey key = DbKey::empty());
explicit SqliteConnectionSafe(string path, DbKey key = DbKey::empty(), optional<int32> cipher_version = {});

SqliteDb &get();
void set(SqliteDb &&db);
Expand Down
25 changes: 17 additions & 8 deletions tddb/td/db/SqliteDb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,15 @@ Status SqliteDb::check_encryption() {
return status;
}

Result<SqliteDb> SqliteDb::open_with_key(CSlice path, const DbKey &db_key) {
auto res = do_open_with_key(path, db_key, false);
if (res.is_error()) {
return do_open_with_key(path, db_key, true);
Result<SqliteDb> SqliteDb::open_with_key(CSlice path, const DbKey &db_key, optional<int32> cipher_version) {
auto res = do_open_with_key(path, db_key, cipher_version ? cipher_version.value() : 0);
if (res.is_error() && !cipher_version) {
return do_open_with_key(path, db_key, 3);
}
return res;
}

Result<SqliteDb> SqliteDb::do_open_with_key(CSlice path, const DbKey &db_key, bool with_cipher_migrate) {
Result<SqliteDb> SqliteDb::do_open_with_key(CSlice path, const DbKey &db_key, int32 cipher_version) {
SqliteDb db;
TRY_STATUS(db.init(path));
if (!db_key.is_empty()) {
Expand All @@ -188,15 +188,24 @@ Result<SqliteDb> SqliteDb::do_open_with_key(CSlice path, const DbKey &db_key, bo
}
auto key = db_key_to_sqlcipher_key(db_key);
TRY_STATUS(db.exec(PSLICE() << "PRAGMA key = " << key));
if (with_cipher_migrate) {
LOG(INFO) << "Try Sqlcipher compatibility mode";
TRY_STATUS(db.exec("PRAGMA cipher_compatibility = 3"));
if (cipher_version != 0) {
LOG(INFO) << "Try Sqlcipher compatibility mode with version=" << cipher_version;
TRY_STATUS(db.exec(PSLICE() << "PRAGMA cipher_compatibility = " << cipher_version));
}
db.set_cipher_version(cipher_version);
}
TRY_STATUS_PREFIX(db.check_encryption(), "Can't open database: ");
return std::move(db);
}

void SqliteDb::set_cipher_version(int32 cipher_version) {
raw_->set_cipher_version(cipher_version);
}

optional<int32> SqliteDb::get_cipher_version() {
return raw_->get_cipher_version();
}

Result<SqliteDb> SqliteDb::change_key(CSlice path, const DbKey &new_db_key, const DbKey &old_db_key) {
PerfWarningTimer perf("change key", 0.001);

Expand Down
8 changes: 6 additions & 2 deletions tddb/td/db/SqliteDb.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "td/db/detail/RawSqliteDb.h"

#include "td/utils/logging.h"
#include "td/utils/optional.h"
#include "td/utils/Slice.h"
#include "td/utils/Status.h"

Expand Down Expand Up @@ -61,7 +62,7 @@ class SqliteDb {
static Status destroy(Slice path) TD_WARN_UNUSED_RESULT;

// Anyway we can't change the key on the fly, so having static functions is more than enough
static Result<SqliteDb> open_with_key(CSlice path, const DbKey &db_key);
static Result<SqliteDb> open_with_key(CSlice path, const DbKey &db_key, optional<int32> cipher_version = {});
static Result<SqliteDb> change_key(CSlice path, const DbKey &new_db_key, const DbKey &old_db_key);

Status last_error();
Expand All @@ -77,14 +78,17 @@ class SqliteDb {
detail::RawSqliteDb::with_db_path(main_path, f);
}

optional<int32> get_cipher_version();

private:
explicit SqliteDb(std::shared_ptr<detail::RawSqliteDb> raw) : raw_(std::move(raw)) {
}
std::shared_ptr<detail::RawSqliteDb> raw_;
bool enable_logging_ = false;

Status check_encryption();
static Result<SqliteDb> do_open_with_key(CSlice path, const DbKey &db_key, bool with_cipher_migrate);
static Result<SqliteDb> do_open_with_key(CSlice path, const DbKey &db_key, int32 cipher_version);
void set_cipher_version(int32 cipher_version);
};

} // namespace td
10 changes: 10 additions & 0 deletions tddb/td/db/detail/RawSqliteDb.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#pragma once

#include "td/utils/logging.h"
#include "td/utils/optional.h"
#include "td/utils/Slice.h"
#include "td/utils/Status.h"

Expand Down Expand Up @@ -56,10 +57,19 @@ class RawSqliteDb {
return begin_cnt_ == 0;
}

void set_cipher_version(int32 cipher_version) {
cipher_version_ = cipher_version;
}

optional<int32> get_cipher_version() const {
return cipher_version_.copy();
}

private:
sqlite3 *db_;
std::string path_;
size_t begin_cnt_{0};
optional<int32> cipher_version_;
};

} // namespace detail
Expand Down

0 comments on commit 28596f1

Please sign in to comment.