Skip to content

Commit

Permalink
Fix TransparentProxy::Callback.
Browse files Browse the repository at this point in the history
  • Loading branch information
levlam committed Oct 22, 2021
1 parent 03c6d53 commit 0c1469f
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 42 deletions.
11 changes: 6 additions & 5 deletions td/mtproto/RawConnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "td/net/DarwinHttp.h"
#endif

#include "td/utils/BufferedFd.h"
#include "td/utils/format.h"
#include "td/utils/logging.h"
#include "td/utils/misc.h"
Expand All @@ -35,8 +34,9 @@ namespace mtproto {

class RawConnectionDefault final : public RawConnection {
public:
RawConnectionDefault(SocketFd socket_fd, TransportType transport_type, unique_ptr<StatsCallback> stats_callback)
: socket_fd_(std::move(socket_fd))
RawConnectionDefault(BufferedFd<SocketFd> buffered_socket_fd, TransportType transport_type,
unique_ptr<StatsCallback> stats_callback)
: socket_fd_(std::move(buffered_socket_fd))
, transport_(create_transport(std::move(transport_type)))
, stats_callback_(std::move(stats_callback)) {
transport_->init(&socket_fd_.input_buffer(), &socket_fd_.output_buffer());
Expand Down Expand Up @@ -450,12 +450,13 @@ class RawConnectionHttp final : public RawConnection {
};
#endif

unique_ptr<RawConnection> RawConnection::create(IPAddress ip_address, SocketFd socket_fd, TransportType transport_type,
unique_ptr<RawConnection> RawConnection::create(IPAddress ip_address, BufferedFd<SocketFd> buffered_socket_fd,
TransportType transport_type,
unique_ptr<StatsCallback> stats_callback) {
#if TD_DARWIN_WATCH_OS
return td::make_unique<RawConnectionHttp>(std::move(ip_address), std::move(stats_callback));
#else
return td::make_unique<RawConnectionDefault>(std::move(socket_fd), std::move(transport_type),
return td::make_unique<RawConnectionDefault>(std::move(buffered_socket_fd), std::move(transport_type),
std::move(stats_callback));
#endif
}
Expand Down
5 changes: 3 additions & 2 deletions td/mtproto/RawConnection.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "td/mtproto/TransportType.h"

#include "td/utils/buffer.h"
#include "td/utils/BufferedFd.h"
#include "td/utils/common.h"
#include "td/utils/port/detail/PollableFd.h"
#include "td/utils/port/IPAddress.h"
Expand Down Expand Up @@ -40,8 +41,8 @@ class RawConnection {
RawConnection &operator=(const RawConnection &) = delete;
virtual ~RawConnection() = default;

static unique_ptr<RawConnection> create(IPAddress ip_address, SocketFd socket_fd, TransportType transport_type,
unique_ptr<StatsCallback> stats_callback);
static unique_ptr<RawConnection> create(IPAddress ip_address, BufferedFd<SocketFd> buffered_socket_fd,
TransportType transport_type, unique_ptr<StatsCallback> stats_callback);

virtual void set_connection_token(ConnectionManager::ConnectionToken connection_token) = 0;

Expand Down
2 changes: 1 addition & 1 deletion td/telegram/Td.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ class TestProxyRequest final : public RequestOnceActor {
auto handshake = make_unique<mtproto::AuthKeyHandshake>(dc_id_, 3600);
auto data = r_data.move_as_ok();
auto raw_connection =
mtproto::RawConnection::create(data.ip_address, std::move(data.socket_fd), get_transport(), nullptr);
mtproto::RawConnection::create(data.ip_address, std::move(data.buffered_socket_fd), get_transport(), nullptr);
child_ = create_actor<mtproto::HandshakeActor>(
"HandshakeActor", std::move(handshake), std::move(raw_connection), make_unique<HandshakeContext>(), 10.0,
PromiseCreator::lambda([actor_id = actor_id(this)](Result<unique_ptr<mtproto::RawConnection>> raw_connection) {
Expand Down
44 changes: 23 additions & 21 deletions td/telegram/net/ConnectionCreator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,12 +328,12 @@ void ConnectionCreator::ping_proxy(int32 proxy_id, Promise<double> promise) {
continue;
}

ping_proxy_socket_fd(std::move(ip_address), r_socket_fd.move_as_ok(), r_transport_type.move_as_ok(),
PSTRING() << info.option->get_ip_address(),
PromiseCreator::lambda([actor_id = actor_id(this), token](Result<double> result) {
send_closure(actor_id, &ConnectionCreator::on_ping_main_dc_result, token,
std::move(result));
}));
ping_proxy_buffered_socket_fd(std::move(ip_address), BufferedFd<SocketFd>(r_socket_fd.move_as_ok()),
r_transport_type.move_as_ok(), PSTRING() << info.option->get_ip_address(),
PromiseCreator::lambda([actor_id = actor_id(this), token](Result<double> result) {
send_closure(actor_id, &ConnectionCreator::on_ping_main_dc_result, token,
std::move(result));
}));
}
return;
}
Expand Down Expand Up @@ -375,8 +375,9 @@ void ConnectionCreator::ping_proxy_resolved(int32 proxy_id, IPAddress ip_address
if (r_connection_data.is_error()) {
return promise.set_error(Status::Error(400, r_connection_data.error().public_message()));
}
send_closure(actor_id, &ConnectionCreator::ping_proxy_socket_fd, ip_address,
r_connection_data.move_as_ok().socket_fd, std::move(transport_type), std::move(debug_str),
auto connection_data = r_connection_data.move_as_ok();
send_closure(actor_id, &ConnectionCreator::ping_proxy_buffered_socket_fd, ip_address,
std::move(connection_data.buffered_socket_fd), std::move(transport_type), std::move(debug_str),
std::move(promise));
});
CHECK(proxy.use_proxy());
Expand All @@ -389,12 +390,12 @@ void ConnectionCreator::ping_proxy_resolved(int32 proxy_id, IPAddress ip_address
}
}

void ConnectionCreator::ping_proxy_socket_fd(IPAddress ip_address, SocketFd socket_fd,
mtproto::TransportType transport_type, string debug_str,
Promise<double> promise) {
void ConnectionCreator::ping_proxy_buffered_socket_fd(IPAddress ip_address, BufferedFd<SocketFd> buffered_socket_fd,
mtproto::TransportType transport_type, string debug_str,
Promise<double> promise) {
auto token = next_token();
auto raw_connection =
mtproto::RawConnection::create(ip_address, std::move(socket_fd), std::move(transport_type), nullptr);
mtproto::RawConnection::create(ip_address, std::move(buffered_socket_fd), std::move(transport_type), nullptr);
children_[token] = {
false, create_ping_actor(debug_str, std::move(raw_connection), nullptr,
PromiseCreator::lambda([promise = std::move(promise)](
Expand Down Expand Up @@ -651,8 +652,9 @@ void ConnectionCreator::request_raw_connection_by_ip(IPAddress ip_address, mtpro
if (r_connection_data.is_error()) {
return promise.set_error(Status::Error(400, r_connection_data.error().public_message()));
}
auto raw_connection =
mtproto::RawConnection::create(ip_address, r_connection_data.move_as_ok().socket_fd, transport_type, nullptr);
auto connection_data = r_connection_data.move_as_ok();
auto raw_connection = mtproto::RawConnection::create(ip_address, std::move(connection_data.buffered_socket_fd),
transport_type, nullptr);
raw_connection->extra().extra = network_generation;
promise.set_value(std::move(raw_connection));
});
Expand Down Expand Up @@ -754,19 +756,19 @@ ActorOwn<> ConnectionCreator::prepare_connection(IPAddress ip_address, SocketFd
, use_connection_token_(use_connection_token)
, was_connected_(was_connected) {
}
void set_result(Result<SocketFd> result) final {
if (result.is_error()) {
void set_result(Result<BufferedFd<SocketFd>> r_buffered_socket_fd) final {
if (r_buffered_socket_fd.is_error()) {
if (use_connection_token_) {
connection_token_ = mtproto::ConnectionManager::ConnectionToken();
}
if (was_connected_ && stats_callback_) {
stats_callback_->on_error();
}
promise_.set_error(Status::Error(400, result.error().public_message()));
promise_.set_error(Status::Error(400, r_buffered_socket_fd.error().public_message()));
} else {
ConnectionData data;
data.ip_address = ip_address_;
data.socket_fd = result.move_as_ok();
data.buffered_socket_fd = r_buffered_socket_fd.move_as_ok();
data.connection_token = std::move(connection_token_);
data.stats_callback = std::move(stats_callback_);
promise_.set_value(std::move(data));
Expand All @@ -785,7 +787,7 @@ ActorOwn<> ConnectionCreator::prepare_connection(IPAddress ip_address, SocketFd
mtproto::ConnectionManager::ConnectionToken connection_token_;
IPAddress ip_address_;
unique_ptr<mtproto::RawConnection::StatsCallback> stats_callback_;
bool use_connection_token_;
bool use_connection_token_{false};
bool was_connected_{false};
};
VLOG(connections) << "Start "
Expand Down Expand Up @@ -814,7 +816,7 @@ ActorOwn<> ConnectionCreator::prepare_connection(IPAddress ip_address, SocketFd

ConnectionData data;
data.ip_address = ip_address;
data.socket_fd = std::move(socket_fd);
data.buffered_socket_fd = BufferedFd<SocketFd>(std::move(socket_fd));
data.stats_callback = std::move(stats_callback);
promise.set_result(std::move(data));
return {};
Expand Down Expand Up @@ -991,7 +993,7 @@ void ConnectionCreator::client_create_raw_connection(Result<ConnectionData> r_co

auto connection_data = r_connection_data.move_as_ok();
auto raw_connection =
mtproto::RawConnection::create(connection_data.ip_address, std::move(connection_data.socket_fd),
mtproto::RawConnection::create(connection_data.ip_address, std::move(connection_data.buffered_socket_fd),
std::move(transport_type), std::move(connection_data.stats_callback));
raw_connection->set_connection_token(std::move(connection_data.connection_token));

Expand Down
7 changes: 4 additions & 3 deletions td/telegram/net/ConnectionCreator.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "td/actor/PromiseFuture.h"
#include "td/actor/SignalSlot.h"

#include "td/utils/BufferedFd.h"
#include "td/utils/common.h"
#include "td/utils/FloodControlStrict.h"
#include "td/utils/logging.h"
Expand Down Expand Up @@ -81,7 +82,7 @@ class ConnectionCreator final : public NetQueryCallback {

struct ConnectionData {
IPAddress ip_address;
SocketFd socket_fd;
BufferedFd<SocketFd> buffered_socket_fd;
mtproto::ConnectionManager::ConnectionToken connection_token;
unique_ptr<mtproto::RawConnection::StatsCallback> stats_callback;
};
Expand Down Expand Up @@ -246,8 +247,8 @@ class ConnectionCreator final : public NetQueryCallback {

void ping_proxy_resolved(int32 proxy_id, IPAddress ip_address, Promise<double> promise);

void ping_proxy_socket_fd(IPAddress ip_address, SocketFd socket_fd, mtproto::TransportType transport_type,
string debug_str, Promise<double> promise);
void ping_proxy_buffered_socket_fd(IPAddress ip_address, BufferedFd<SocketFd> buffered_socket_fd,
mtproto::TransportType transport_type, string debug_str, Promise<double> promise);

void on_ping_main_dc_result(uint64 token, Result<double> result);
};
Expand Down
2 changes: 1 addition & 1 deletion tdnet/td/net/TransparentProxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class TransparentProxy : public Actor {
Callback &operator=(const Callback &) = delete;
virtual ~Callback() = default;

virtual void set_result(Result<SocketFd>) = 0;
virtual void set_result(Result<BufferedFd<SocketFd>> r_buffered_socket_fd) = 0;
virtual void on_connected() = 0;
};

Expand Down
18 changes: 9 additions & 9 deletions test/mtproto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ class TestPingActor final : public Actor {
}

ping_connection_ = mtproto::PingConnection::create_req_pq(
mtproto::RawConnection::create(ip_address_, r_socket.move_as_ok(),
mtproto::RawConnection::create(ip_address_, BufferedFd<SocketFd>(r_socket.move_as_ok()),
mtproto::TransportType{mtproto::TransportType::Tcp, 0, mtproto::ProxySecret()},
nullptr),
3);
Expand Down Expand Up @@ -339,7 +339,7 @@ class HandshakeTestActor final : public Actor {
}

raw_connection_ = mtproto::RawConnection::create(
ip_address, r_socket.move_as_ok(),
ip_address, BufferedFd<SocketFd>(r_socket.move_as_ok()),
mtproto::TransportType{mtproto::TransportType::Tcp, 0, mtproto::ProxySecret()}, nullptr);
}
if (!wait_for_handshake_ && !handshake_) {
Expand Down Expand Up @@ -438,22 +438,22 @@ RegisterTest<Mtproto_handshake> mtproto_handshake("Mtproto_handshake");
class Socks5TestActor final : public Actor {
public:
void start_up() final {
auto promise = PromiseCreator::lambda([actor_id = actor_id(this)](Result<SocketFd> res) {
auto promise = PromiseCreator::lambda([actor_id = actor_id(this)](Result<BufferedFd<SocketFd>> res) {
send_closure(actor_id, &Socks5TestActor::on_result, std::move(res), false);
});

class Callback final : public TransparentProxy::Callback {
public:
explicit Callback(Promise<SocketFd> promise) : promise_(std::move(promise)) {
explicit Callback(Promise<BufferedFd<SocketFd>> promise) : promise_(std::move(promise)) {
}
void set_result(Result<SocketFd> result) final {
void set_result(Result<BufferedFd<SocketFd>> result) final {
promise_.set_result(std::move(result));
}
void on_connected() final {
}

private:
Promise<SocketFd> promise_;
Promise<BufferedFd<SocketFd>> promise_;
};

IPAddress socks5_ip;
Expand All @@ -470,7 +470,7 @@ class Socks5TestActor final : public Actor {
}

private:
void on_result(Result<SocketFd> res, bool dummy) {
void on_result(Result<BufferedFd<SocketFd>> res, bool dummy) {
res.ensure();
Scheduler::instance()->finish();
}
Expand Down Expand Up @@ -545,7 +545,7 @@ class FastPingTestActor final : public Actor {
}

auto raw_connection = mtproto::RawConnection::create(
ip_address, r_socket.move_as_ok(),
ip_address, BufferedFd<SocketFd>(r_socket.move_as_ok()),
mtproto::TransportType{mtproto::TransportType::Tcp, 0, mtproto::ProxySecret()}, nullptr);
auto handshake = make_unique<mtproto::AuthKeyHandshake>(get_default_dc_id(), 60 * 100 /*temp*/);
create_actor<mtproto::HandshakeActor>(
Expand Down Expand Up @@ -676,7 +676,7 @@ TEST(Mtproto, TlsTransport) {
void start_up() final {
class Callback final : public TransparentProxy::Callback {
public:
void set_result(Result<SocketFd> result) final {
void set_result(Result<BufferedFd<SocketFd>> result) final {
if (result.is_ok()) {
LOG(ERROR) << "Unexpectedly succeeded to connect to MTProto proxy";
} else if (result.error().message() != "Response hash mismatch") {
Expand Down

0 comments on commit 0c1469f

Please sign in to comment.