Skip to content

Commit

Permalink
PollableFd: explicit sync_with_poll
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 71fa35a594816e84e372ebcfa9d0077a13f26a62
  • Loading branch information
arseny30 committed Jul 21, 2020
1 parent ceb49d0 commit 38ef3a7
Show file tree
Hide file tree
Showing 17 changed files with 65 additions and 44 deletions.
7 changes: 4 additions & 3 deletions benchmark/bench_http_server_cheat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,16 @@ class HelloWorld : public Actor {
}
}
Status do_loop() {
sync_with_poll(socket_fd_);
TRY_STATUS(read_loop());
TRY_STATUS(write_loop());
if (can_close(socket_fd_)) {
if (can_close_local(socket_fd_)) {
return Status::Error("CLOSE");
}
return Status::OK();
}
Status write_loop() {
while (can_write(socket_fd_) && write_pos_ < write_buf_.size()) {
while (can_write_local(socket_fd_) && write_pos_ < write_buf_.size()) {
TRY_RESULT(written, socket_fd_.write(Slice(write_buf_).substr(write_pos_)));
write_pos_ += written;
if (write_pos_ == write_buf_.size()) {
Expand All @@ -80,7 +81,7 @@ class HelloWorld : public Actor {
return Status::OK();
}
Status read_loop() {
while (can_read(socket_fd_)) {
while (can_read_local(socket_fd_)) {
TRY_RESULT(read_size, socket_fd_.read(MutableSlice(read_buf.data(), read_buf.size())));
for (size_t i = 0; i < read_size; i++) {
if (read_buf[i] == '\n') {
Expand Down
7 changes: 3 additions & 4 deletions benchmark/bench_http_server_fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,18 @@ class HttpEchoConnection : public Actor {
}

void loop() override {
sync_with_poll(fd_);
auto status = [&] {
TRY_STATUS(loop_read());
TRY_STATUS(loop_write());
return Status::OK();
}();
if (status.is_error() || can_close(fd_)) {
if (status.is_error() || can_close_local(fd_)) {
stop();
}
}
Status loop_read() {
if (can_read(fd_)) {
TRY_STATUS(fd_.flush_read());
}
TRY_STATUS(fd_.flush_read());
while (true) {
TRY_RESULT(need, reader_.read_next(&query_));
if (need == 0) {
Expand Down
3 changes: 2 additions & 1 deletion td/mtproto/RawConnection.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class RawConnection {
if (has_error_) {
return Status::Error("Connection has already failed");
}
sync_with_poll(socket_fd_);

// read/write
// EINVAL may be returned in linux kernel < 2.6.28. And on some new kernels too.
Expand All @@ -139,7 +140,7 @@ class RawConnection {
TRY_STATUS(flush_read(auth_key, callback));
TRY_STATUS(callback.before_write());
TRY_STATUS(flush_write());
if (can_close(socket_fd_)) {
if (can_close_local(socket_fd_)) {
return Status::Error("Connection closed");
}
return Status::OK();
Expand Down
9 changes: 5 additions & 4 deletions tdnet/td/net/HttpConnectionBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ void HttpConnectionBase::timeout_expired() {
stop();
}
void HttpConnectionBase::loop() {
if (can_read(fd_)) {
sync_with_poll(fd_);
if (can_read_local(fd_)) {
LOG(DEBUG) << "Can read from the connection";
auto r = fd_.flush_read();
if (r.is_error()) {
Expand Down Expand Up @@ -133,7 +134,7 @@ void HttpConnectionBase::loop() {

write_source_.wakeup();

if (can_write(fd_)) {
if (can_write_local(fd_)) {
LOG(DEBUG) << "Can write to the connection";
auto r = fd_.flush_write();
if (r.is_error()) {
Expand All @@ -146,7 +147,7 @@ void HttpConnectionBase::loop() {
}

Status pending_error;
if (fd_.get_poll_info().get_flags().has_pending_error()) {
if (fd_.get_poll_info().get_flags_local().has_pending_error()) {
pending_error = fd_.get_pending_error();
}
if (pending_error.is_ok() && write_sink_.status().is_error()) {
Expand All @@ -163,7 +164,7 @@ void HttpConnectionBase::loop() {
state_ = State::Close;
}

if (can_close(fd_)) {
if (can_close_local(fd_)) {
LOG(DEBUG) << "Can close the connection";
state_ = State::Close;
}
Expand Down
5 changes: 3 additions & 2 deletions tdnet/td/net/TcpListener.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ void TcpListener::loop() {
if (server_fd_.empty()) {
start_up();
}
while (can_read(server_fd_)) {
sync_with_poll(server_fd_);
while (can_read_local(server_fd_)) {
auto r_socket_fd = server_fd_.accept();
if (r_socket_fd.is_error()) {
if (r_socket_fd.error().code() != -1) {
Expand All @@ -51,7 +52,7 @@ void TcpListener::loop() {
send_closure(callback_, &Callback::accept, r_socket_fd.move_as_ok());
}

if (can_close(server_fd_)) {
if (can_close_local(server_fd_)) {
stop();
}
}
Expand Down
6 changes: 4 additions & 2 deletions tdnet/td/net/TransparentProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,14 @@ void TransparentProxy::start_up() {
VLOG(proxy) << "Begin to connect to proxy";
Scheduler::subscribe(fd_.get_poll_info().extract_pollable_fd(this));
set_timeout_in(10);
if (can_write(fd_)) {
sync_with_poll(fd_);
if (can_write_local(fd_)) {
loop();
}
}

void TransparentProxy::loop() {
sync_with_poll(fd_);
auto status = [&] {
TRY_STATUS(fd_.flush_read());
TRY_STATUS(loop_impl());
Expand All @@ -70,7 +72,7 @@ void TransparentProxy::loop() {
if (status.is_error()) {
on_error(std::move(status));
}
if (can_close(fd_)) {
if (can_close_local(fd_)) {
on_error(Status::Error("Connection closed"));
}
}
Expand Down
11 changes: 6 additions & 5 deletions tdutils/td/utils/BufferedFd.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ class BufferedFdBase : public FdT {
Result<size_t> flush_write() TD_WARN_UNUSED_RESULT;

bool need_flush_write(size_t at_least = 0) {
CHECK(write_);
write_->sync_with_writer();
return write_->size() > at_least;
return ready_for_flush_write() > at_least;
}
size_t ready_for_flush_write() {
CHECK(write_);
write_->sync_with_writer();
return write_->size();
}
void sync_with_poll() {
::td::sync_with_poll(*this);
}
void set_input_writer(ChainBufferWriter *read) {
read_ = read;
}
Expand Down Expand Up @@ -99,7 +100,7 @@ template <class FdT>
Result<size_t> BufferedFdBase<FdT>::flush_read(size_t max_read) {
CHECK(read_);
size_t result = 0;
while (::td::can_read(*this) && max_read) {
while (::td::can_read_local(*this) && max_read) {
MutableSlice slice = read_->prepare_append().truncate(max_read);
TRY_RESULT(x, FdT::read(slice));
slice.truncate(x);
Expand All @@ -115,7 +116,7 @@ Result<size_t> BufferedFdBase<FdT>::flush_write() {
// TODO: sync on demand
write_->sync_with_writer();
size_t result = 0;
while (!write_->empty() && ::td::can_write(*this)) {
while (!write_->empty() && ::td::can_write_local(*this)) {
constexpr size_t BUF_SIZE = 20;
IoSlice buf[BUF_SIZE];

Expand Down
7 changes: 5 additions & 2 deletions tdutils/td/utils/BufferedUdp.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,11 @@ class BufferedUdp : public UdpSocketFd {
}

#if TD_PORT_POSIX
void sync_with_poll() {
::td::sync_with_poll(*this);
}
Result<optional<UdpMessage>> receive() {
if (input_.empty() && can_read(*this)) {
if (input_.empty() && can_read_local(*this)) {
TRY_STATUS(flush_read_once());
}
if (input_.empty()) {
Expand All @@ -130,7 +133,7 @@ class BufferedUdp : public UdpSocketFd {

Status flush_send() {
Status status;
while (status.is_ok() && can_write(*this) && !output_.empty()) {
while (status.is_ok() && can_write_local(*this) && !output_.empty()) {
status = flush_send_once();
}
return status;
Expand Down
2 changes: 1 addition & 1 deletion tdutils/td/utils/port/ServerSocketFd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ class ServerSocketFdImpl {
}

Status get_pending_error() {
if (!get_poll_info().get_flags().has_pending_error()) {
if (!get_poll_info().get_flags_local().has_pending_error()) {
return Status::OK();
}
TRY_STATUS(detail::get_socket_pending_error(get_native_fd()));
Expand Down
6 changes: 3 additions & 3 deletions tdutils/td/utils/port/SocketFd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class SocketFdImpl : private Iocp::Callback {
}

Result<size_t> read(MutableSlice slice) {
if (get_poll_info().get_flags().has_pending_error()) {
if (get_poll_info().get_flags_local().has_pending_error()) {
TRY_STATUS(get_pending_error());
}
input_reader_.sync_with_writer();
Expand Down Expand Up @@ -435,7 +435,7 @@ class SocketFdImpl {
}
}
Result<size_t> read(MutableSlice slice) {
if (get_poll_info().get_flags().has_pending_error()) {
if (get_poll_info().get_flags_local().has_pending_error()) {
TRY_STATUS(get_pending_error());
}
int native_fd = get_native_fd().socket();
Expand Down Expand Up @@ -482,7 +482,7 @@ class SocketFdImpl {
}
}
Status get_pending_error() {
if (!get_poll_info().get_flags().has_pending_error()) {
if (!get_poll_info().get_flags_local().has_pending_error()) {
return Status::OK();
}
TRY_STATUS(detail::get_socket_pending_error(get_native_fd()));
Expand Down
6 changes: 4 additions & 2 deletions tdutils/td/utils/port/StdStreams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "td/utils/port/detail/Iocp.h"
#include "td/utils/port/detail/NativeFd.h"
#include "td/utils/port/PollFlags.h"
#include "td/utils/port/detail/PollableFd.h"
#include "td/utils/port/thread.h"
#include "td/utils/ScopeGuard.h"
#include "td/utils/Slice.h"
Expand Down Expand Up @@ -102,7 +103,7 @@ class BufferedStdinImpl : public Iocp::Callback {
}

Result<size_t> flush_read(size_t max_read = std::numeric_limits<size_t>::max()) TD_WARN_UNUSED_RESULT {
info_.get_flags();
info_.sync_with_poll();
info_.clear_flags(PollFlags::Read());
reader_.sync_with_writer();
return reader_.size();
Expand Down Expand Up @@ -196,7 +197,8 @@ class BufferedStdinImpl {

Result<size_t> flush_read(size_t max_read = std::numeric_limits<size_t>::max()) TD_WARN_UNUSED_RESULT {
size_t result = 0;
while (::td::can_read(*this) && max_read) {
::td::sync_with_poll(*this);
while (::td::can_read_local(*this) && max_read) {
MutableSlice slice = writer_.prepare_append().truncate(max_read);
TRY_RESULT(x, file_fd_.read(slice));
slice.truncate(x);
Expand Down
8 changes: 4 additions & 4 deletions tdutils/td/utils/port/UdpSocketFd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ class UdpSocketFdImpl {
return info_.native_fd();
}
Status get_pending_error() {
if (!get_poll_info().get_flags().has_pending_error()) {
if (!get_poll_info().get_flags_local().has_pending_error()) {
return Status::OK();
}
TRY_STATUS(detail::get_socket_pending_error(get_native_fd()));
Expand All @@ -487,7 +487,7 @@ class UdpSocketFdImpl {
Status receive_message(UdpSocketFd::InboundMessage &message, bool &is_received) {
is_received = false;
int flags = 0;
if (get_poll_info().get_flags().has_pending_error()) {
if (get_poll_info().get_flags_local().has_pending_error()) {
#ifdef MSG_ERRQUEUE
flags = MSG_ERRQUEUE;
#else
Expand Down Expand Up @@ -679,7 +679,7 @@ class UdpSocketFdImpl {
#endif
Status receive_messages_slow(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
cnt = 0;
while (cnt < messages.size() && get_poll_info().get_flags().can_read()) {
while (cnt < messages.size() && get_poll_info().get_flags_local().can_read()) {
auto &message = messages[cnt];
CHECK(!message.data.empty());
bool is_received;
Expand All @@ -694,7 +694,7 @@ class UdpSocketFdImpl {
Status receive_messages_fast(MutableSpan<UdpSocketFd::InboundMessage> messages, size_t &cnt) {
int flags = 0;
cnt = 0;
if (get_poll_info().get_flags().has_pending_error()) {
if (get_poll_info().get_flags_local().has_pending_error()) {
#ifdef MSG_ERRQUEUE
flags = MSG_ERRQUEUE;
#else
Expand Down
3 changes: 2 additions & 1 deletion tdutils/td/utils/port/detail/EventFdBsd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ void EventFdBsd::release() {
}

void EventFdBsd::acquire() {
sync_with_poll(out_);
out_.get_poll_info().add_flags(PollFlags::Read());
while (can_read(out_)) {
while (can_read_local(out_)) {
uint8 value[1024];
auto result = out_.read(MutableSlice(value, sizeof(value)));
if (result.is_error()) {
Expand Down
2 changes: 1 addition & 1 deletion tdutils/td/utils/port/detail/EventFdLinux.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ void EventFdLinux::release() {
}

void EventFdLinux::acquire() {
impl_->info.get_flags();
impl_->info.sync_with_poll();
SCOPE_EXIT {
// Clear flags without EAGAIN and EWOULDBLOCK
// Looks like it is safe thing to do with eventfd
Expand Down
22 changes: 15 additions & 7 deletions tdutils/td/utils/port/detail/PollableFd.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ class PollableFdInfo : private ListNode {
void clear_flags(PollFlags flags) {
flags_.clear_flags(flags);
}
PollFlags get_flags() const {
//PollFlags get_flags() const {
//return flags_.read_flags();
//}
PollFlags sync_with_poll() const {
return flags_.read_flags();
}
PollFlags get_flags_local() const {
Expand Down Expand Up @@ -208,18 +211,23 @@ inline const NativeFd &PollableFd::native_fd() const {
}

template <class FdT>
bool can_read(const FdT &fd) {
return fd.get_poll_info().get_flags().can_read() || fd.get_poll_info().get_flags().has_pending_error();
void sync_with_poll(const FdT &fd) {
fd.get_poll_info().sync_with_poll();
}

template <class FdT>
bool can_write(const FdT &fd) {
return fd.get_poll_info().get_flags().can_write();
bool can_read_local(const FdT &fd) {
return fd.get_poll_info().get_flags_local().can_read() || fd.get_poll_info().get_flags_local().has_pending_error();
}

template <class FdT>
bool can_close(const FdT &fd) {
return fd.get_poll_info().get_flags().can_close();
bool can_write_local(const FdT &fd) {
return fd.get_poll_info().get_flags_local().can_write();
}

template <class FdT>
bool can_close_local(const FdT &fd) {
return fd.get_poll_info().get_flags_local().can_close();
}

} // namespace td
3 changes: 2 additions & 1 deletion tdutils/test/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ static void test_to_double() {

TEST(Misc, to_double) {
test_to_double();
const char *locale_name = (std::setlocale(LC_ALL, "fr-FR") == nullptr ? "" : "fr-FR");
const char *locale_name = (std::setlocale(LC_ALL, "fr-FR") == nullptr ? "C" : "fr-FR");
LOG(ERROR) << locale_name;
std::locale new_locale(locale_name);
auto host_locale = std::locale::global(new_locale);
test_to_double();
Expand Down
2 changes: 1 addition & 1 deletion test/http.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ TEST(Http, aes_file_encryption) {
fd.set_input_writer(&input_writer);

fd.get_poll_info().add_flags(PollFlags::Read());
while (can_read(fd)) {
while (can_read_local(fd)) {
fd.flush_read(4096).ensure();
source.wakeup();
}
Expand Down

0 comments on commit 38ef3a7

Please sign in to comment.