Skip to content

Commit

Permalink
Add cache of secure values.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 54fcf89a89f28086705e12869e9dc777c2a86233
  • Loading branch information
levlam committed Aug 12, 2018
1 parent 3728c89 commit b07fc66
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 51 deletions.
2 changes: 1 addition & 1 deletion td/telegram/DialogDb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class DialogDbAsync : public DialogDbAsyncInterface {
do_flush();
sync_db_safe_.reset();
sync_db_ = nullptr;
promise.set_result(Unit());
promise.set_value(Unit());
stop();
}

Expand Down
113 changes: 68 additions & 45 deletions td/telegram/SecureManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,15 @@
#include "td/utils/optional.h"
#include "td/utils/Slice.h"

#include <mutex>

namespace td {

class GetSecureValue : public NetQueryCallback {
public:
GetSecureValue(ActorShared<> parent, std::string password, SecureValueType type,
GetSecureValue(ActorShared<SecureManager> parent, std::string password, SecureValueType type,
Promise<SecureValueWithCredentials> promise);

private:
ActorShared<> parent_;
ActorShared<SecureManager> parent_;
string password_;
SecureValueType type_;
Promise<SecureValueWithCredentials> promise_;
Expand All @@ -49,10 +47,10 @@ class GetSecureValue : public NetQueryCallback {

class GetAllSecureValues : public NetQueryCallback {
public:
GetAllSecureValues(ActorShared<> parent, std::string password, Promise<TdApiSecureValues> promise);
GetAllSecureValues(ActorShared<SecureManager> parent, std::string password, Promise<TdApiSecureValues> promise);

private:
ActorShared<> parent_;
ActorShared<SecureManager> parent_;
string password_;
Promise<TdApiSecureValues> promise_;
optional<vector<EncryptedSecureValue>> encrypted_secure_values_;
Expand All @@ -68,11 +66,11 @@ class GetAllSecureValues : public NetQueryCallback {

class SetSecureValue : public NetQueryCallback {
public:
SetSecureValue(ActorShared<> parent, string password, SecureValue secure_value,
SetSecureValue(ActorShared<SecureManager> parent, string password, SecureValue secure_value,
Promise<SecureValueWithCredentials> promise);

private:
ActorShared<> parent_;
ActorShared<SecureManager> parent_;
string password_;
SecureValue secure_value_;
Promise<SecureValueWithCredentials> promise_;
Expand Down Expand Up @@ -153,7 +151,7 @@ class SetSecureValueErrorsQuery : public Td::ResultHandler {
}
};

GetSecureValue::GetSecureValue(ActorShared<> parent, std::string password, SecureValueType type,
GetSecureValue::GetSecureValue(ActorShared<SecureManager> parent, std::string password, SecureValueType type,
Promise<SecureValueWithCredentials> promise)
: parent_(std::move(parent)), password_(std::move(password)), type_(type), promise_(std::move(promise)) {
}
Expand Down Expand Up @@ -188,7 +186,10 @@ void GetSecureValue::loop() {
if (r_secure_value.is_error()) {
return on_error(r_secure_value.move_as_error());
}
promise_.set_result(r_secure_value.move_as_ok());

send_closure(parent_, &SecureManager::on_get_secure_value, r_secure_value.ok());

promise_.set_value(r_secure_value.move_as_ok());
stop();
}

Expand Down Expand Up @@ -226,7 +227,8 @@ void GetSecureValue::on_result(NetQueryPtr query) {
loop();
}

GetAllSecureValues::GetAllSecureValues(ActorShared<> parent, std::string password, Promise<TdApiSecureValues> promise)
GetAllSecureValues::GetAllSecureValues(ActorShared<SecureManager> parent, std::string password,
Promise<TdApiSecureValues> promise)
: parent_(std::move(parent)), password_(std::move(password)), promise_(std::move(promise)) {
}

Expand Down Expand Up @@ -260,9 +262,14 @@ void GetAllSecureValues::loop() {
if (r_secure_values.is_error()) {
return on_error(r_secure_values.move_as_error());
}

for (auto &secure_value : r_secure_values.ok()) {
send_closure(parent_, &SecureManager::on_get_secure_value, secure_value);
}

auto secure_values = transform(r_secure_values.move_as_ok(),
[](SecureValueWithCredentials &&value) { return std::move(value.value); });
promise_.set_result(get_passport_elements_object(file_manager, std::move(secure_values)));
promise_.set_value(get_passport_elements_object(file_manager, std::move(secure_values)));
stop();
}

Expand All @@ -287,7 +294,7 @@ void GetAllSecureValues::on_result(NetQueryPtr query) {
loop();
}

SetSecureValue::SetSecureValue(ActorShared<> parent, string password, SecureValue secure_value,
SetSecureValue::SetSecureValue(ActorShared<SecureManager> parent, string password, SecureValue secure_value,
Promise<SecureValueWithCredentials> promise)
: parent_(std::move(parent))
, password_(std::move(password))
Expand Down Expand Up @@ -561,7 +568,10 @@ void SetSecureValue::on_result(NetQueryPtr query) {
if (r_secure_value.is_error()) {
return on_error(r_secure_value.move_as_error());
}
promise_.set_result(r_secure_value.move_as_ok());

send_closure(parent_, &SecureManager::on_get_secure_value, r_secure_value.ok());

promise_.set_value(r_secure_value.move_as_ok());
stop();
}

Expand All @@ -579,12 +589,12 @@ void SetSecureValue::merge(FileManager *file_manager, FileId file_id, EncryptedS

class DeleteSecureValue : public NetQueryCallback {
public:
DeleteSecureValue(ActorShared<> parent, SecureValueType type, Promise<Unit> promise)
DeleteSecureValue(ActorShared<SecureManager> parent, SecureValueType type, Promise<Unit> promise)
: parent_(std::move(parent)), type_(std::move(type)), promise_(std::move(promise)) {
}

private:
ActorShared<> parent_;
ActorShared<SecureManager> parent_;
SecureValueType type_;
Promise<Unit> promise_;

Expand All @@ -609,8 +619,9 @@ class DeleteSecureValue : public NetQueryCallback {

class GetPassportAuthorizationForm : public NetQueryCallback {
public:
GetPassportAuthorizationForm(ActorShared<> parent, string password, int32 authorization_form_id, UserId bot_user_id,
string scope, string public_key, Promise<TdApiAuthorizationForm> promise)
GetPassportAuthorizationForm(ActorShared<SecureManager> parent, string password, int32 authorization_form_id,
UserId bot_user_id, string scope, string public_key,
Promise<TdApiAuthorizationForm> promise)
: parent_(std::move(parent))
, password_(std::move(password))
, authorization_form_id_(authorization_form_id)
Expand All @@ -621,7 +632,7 @@ class GetPassportAuthorizationForm : public NetQueryCallback {
}

private:
ActorShared<> parent_;
ActorShared<SecureManager> parent_;
string password_;
int32 authorization_form_id_;
UserId bot_user_id_;
Expand Down Expand Up @@ -844,27 +855,41 @@ void SecureManager::get_secure_value(std::string password, SecureValueType type,
if (r_secure_value.is_error()) {
return promise.set_error(r_secure_value.move_as_error());
}

auto *file_manager = G()->td().get_actor_unsafe()->file_manager_.get();
if (file_manager == nullptr) {
return promise.set_value(nullptr);
}
auto r_passport_element = get_passport_element_object(file_manager, r_secure_value.move_as_ok().value);
if (r_passport_element.is_error()) {
LOG(ERROR) << "Failed to get passport element object: " << r_passport_element.error();
return promise.set_value(nullptr);
}
promise.set_value(r_passport_element.move_as_ok());
});
do_get_secure_value(std::move(password), type, std::move(new_promise));
do_get_secure_value(std::move(password), type, false, std::move(new_promise));
}

void SecureManager::do_get_secure_value(std::string password, SecureValueType type,
void SecureManager::do_get_secure_value(std::string password, SecureValueType type, bool allow_from_cache,
Promise<SecureValueWithCredentials> promise) {
if (allow_from_cache && secure_value_cache_.count(type)) {
// TODO check password?
return promise.set_value(SecureValueWithCredentials(secure_value_cache_[type]));
}

refcnt_++;
create_actor<GetSecureValue>("GetSecureValue", actor_shared(), std::move(password), type, std::move(promise))
create_actor<GetSecureValue>("GetSecureValue", actor_shared(this), std::move(password), type, std::move(promise))
.release();
}

void SecureManager::on_get_secure_value(SecureValueWithCredentials value) {
auto type = value.value.type;
secure_value_cache_[type] = std::move(value);
}

void SecureManager::get_all_secure_values(std::string password, Promise<TdApiSecureValues> promise) {
refcnt_++;
create_actor<GetAllSecureValues>("GetAllSecureValues", actor_shared(), std::move(password), std::move(promise))
create_actor<GetAllSecureValues>("GetAllSecureValues", actor_shared(this), std::move(password), std::move(promise))
.release();
}

Expand All @@ -884,8 +909,8 @@ void SecureManager::set_secure_value(string password, SecureValue secure_value,
}
promise.set_value(r_passport_element.move_as_ok());
});
set_secure_value_queries_[type] = create_actor<SetSecureValue>("SetSecureValue", actor_shared(), std::move(password),
std::move(secure_value), std::move(new_promise));
set_secure_value_queries_[type] = create_actor<SetSecureValue>(
"SetSecureValue", actor_shared(this), std::move(password), std::move(secure_value), std::move(new_promise));
}

void SecureManager::delete_secure_value(SecureValueType type, Promise<Unit> promise) {
Expand All @@ -894,14 +919,15 @@ void SecureManager::delete_secure_value(SecureValueType type, Promise<Unit> prom
[actor_id = actor_id(this), type, promise = std::move(promise)](Result<Unit> result) mutable {
send_closure(actor_id, &SecureManager::on_delete_secure_value, type, std::move(promise), std::move(result));
});
create_actor<DeleteSecureValue>("DeleteSecureValue", actor_shared(), type, std::move(new_promise)).release();
create_actor<DeleteSecureValue>("DeleteSecureValue", actor_shared(this), type, std::move(new_promise)).release();
}

void SecureManager::on_delete_secure_value(SecureValueType type, Promise<Unit> promise, Result<Unit> result) {
if (result.is_error()) {
return promise.set_error(result.move_as_error());
}

secure_value_cache_.erase(type);
promise.set_value(Unit());
}

Expand Down Expand Up @@ -1005,15 +1031,15 @@ void SecureManager::get_passport_authorization_form(string password, UserId bot_
string public_key, string payload,
Promise<TdApiAuthorizationForm> promise) {
refcnt_++;
auto authorization_form_id = ++authorization_form_id_;
auto authorization_form_id = ++max_authorization_form_id_;
authorization_forms_[authorization_form_id] = AuthorizationForm{bot_user_id, scope, public_key, payload, false};
auto new_promise =
PromiseCreator::lambda([actor_id = actor_id(this), authorization_form_id, promise = std::move(promise)](
Result<TdApiAuthorizationForm> r_authorization_form) mutable {
send_closure(actor_id, &SecureManager::on_get_passport_authorization_form, authorization_form_id,
std::move(promise), std::move(r_authorization_form));
});
create_actor<GetPassportAuthorizationForm>("GetPassportAuthorizationForm", actor_shared(), std::move(password),
create_actor<GetPassportAuthorizationForm>("GetPassportAuthorizationForm", actor_shared(this), std::move(password),
authorization_form_id, bot_user_id, std::move(scope),
std::move(public_key), std::move(new_promise))
.release();
Expand Down Expand Up @@ -1050,31 +1076,28 @@ void SecureManager::send_passport_authorization_form(string password, int32 auth
}

struct JoinPromise {
std::mutex mutex_;
Promise<std::vector<SecureValueCredentials>> promise_;
std::vector<SecureValueCredentials> credentials_;
int wait_cnt_{0};
};

auto join = std::make_shared<JoinPromise>();
std::lock_guard<std::mutex> guard(join->mutex_);
for (auto type : types) {
join->wait_cnt_++;
do_get_secure_value(password, type,
PromiseCreator::lambda([join](Result<SecureValueWithCredentials> r_secure_value) {
std::lock_guard<std::mutex> guard(join->mutex_);
if (!join->promise_) {
return;
}
if (r_secure_value.is_error()) {
return join->promise_.set_error(r_secure_value.move_as_error());
}
join->credentials_.push_back(r_secure_value.move_as_ok().credentials);
join->wait_cnt_--;
if (join->wait_cnt_ == 0) {
join->promise_.set_value(std::move(join->credentials_));
}
}));
send_closure_later(actor_id(this), &SecureManager::do_get_secure_value, password, type, true,
PromiseCreator::lambda([join](Result<SecureValueWithCredentials> r_secure_value) {
if (!join->promise_) {
return;
}
if (r_secure_value.is_error()) {
return join->promise_.set_error(r_secure_value.move_as_error());
}
join->credentials_.push_back(r_secure_value.move_as_ok().credentials);
join->wait_cnt_--;
if (join->wait_cnt_ == 0) {
join->promise_.set_value(std::move(join->credentials_));
}
}));
}
join->promise_ =
PromiseCreator::lambda([promise = std::move(promise), actor_id = actor_id(this),
Expand Down
14 changes: 9 additions & 5 deletions td/telegram/SecureManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
#include "td/utils/Container.h"
#include "td/utils/Status.h"

#include <map>
#include <memory>
#include <unordered_map>

namespace td {

Expand All @@ -41,6 +41,8 @@ class SecureManager : public NetQueryCallback {
void set_secure_value_errors(Td *td, tl_object_ptr<telegram_api::InputUser> input_user,
vector<tl_object_ptr<td_api::inputPassportElementError>> errors, Promise<Unit> promise);

void on_get_secure_value(SecureValueWithCredentials value);

void get_passport_authorization_form(string password, UserId bot_user_id, string scope, string public_key,
string payload, Promise<TdApiAuthorizationForm> promise);
void send_passport_authorization_form(string password, int32 authorization_form_id,
Expand All @@ -49,7 +51,8 @@ class SecureManager : public NetQueryCallback {
private:
ActorShared<> parent_;
int32 refcnt_{1};
std::map<SecureValueType, ActorOwn<>> set_secure_value_queries_;
std::unordered_map<SecureValueType, ActorOwn<>> set_secure_value_queries_;
std::unordered_map<SecureValueType, SecureValueWithCredentials> secure_value_cache_;

struct AuthorizationForm {
UserId bot_user_id;
Expand All @@ -59,13 +62,14 @@ class SecureManager : public NetQueryCallback {
bool is_received;
};

std::map<int32, AuthorizationForm> authorization_forms_;
int32 authorization_form_id_{0};
std::unordered_map<int32, AuthorizationForm> authorization_forms_;
int32 max_authorization_form_id_{0};

void hangup() override;
void hangup_shared() override;
void dec_refcnt();
void do_get_secure_value(std::string password, SecureValueType type, Promise<SecureValueWithCredentials> promise);
void do_get_secure_value(std::string password, SecureValueType type, bool allow_from_cache,
Promise<SecureValueWithCredentials> promise);
void on_delete_secure_value(SecureValueType type, Promise<Unit> promise, Result<Unit> result);
void on_get_passport_authorization_form(int32 authorization_form_id, Promise<TdApiAuthorizationForm> promise,
Result<TdApiAuthorizationForm> r_authorization_form);
Expand Down

0 comments on commit b07fc66

Please sign in to comment.