Skip to content

Commit

Permalink
Problem: poller_t adds an abstraction layer on zmq_poller_*
Browse files Browse the repository at this point in the history
Solution: extract base_poller_t from poller_t, which provides a direct mapping of zmq_poller_* to C++ only
  • Loading branch information
sigiesec committed May 11, 2018
1 parent cdef8bc commit bf47be0
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 56 deletions.
22 changes: 12 additions & 10 deletions tests/poller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ TEST(poller, poll_basic)
message_received = true;
};
ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN, handler));
ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{-1}));
ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{-1}));
ASSERT_TRUE(message_received);
}

Expand Down Expand Up @@ -237,13 +237,13 @@ TEST(poller, client_server)
// client sends message
ASSERT_NO_THROW(s.client.send(send_msg));

ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{-1}));
ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{-1}));
ASSERT_EQ(events, ZMQ_POLLIN);

// Re-add server socket with pollout flag
ASSERT_NO_THROW(poller.remove(s.server));
ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN | ZMQ_POLLOUT, handler));
ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{-1}));
ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{-1}));
ASSERT_EQ(events, ZMQ_POLLOUT);
}

Expand Down Expand Up @@ -335,7 +335,7 @@ TEST(poller, poll_client_server)

// Modify server socket with pollout flag
ASSERT_NO_THROW(poller.modify(s.server, ZMQ_POLLIN | ZMQ_POLLOUT));
ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(s.events, ZMQ_POLLIN | ZMQ_POLLOUT);
}

Expand All @@ -356,8 +356,8 @@ TEST(poller, wait_one_return)
ASSERT_NO_THROW(s.client.send("Hi"));

// wait for message and verify events
int result = poller.wait(std::chrono::milliseconds{500});
ASSERT_EQ(count, result);
ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(1u, count);
}

TEST(poller, wait_on_move_constructed_poller)
Expand Down Expand Up @@ -401,14 +401,14 @@ TEST(poller, received_on_move_construced_poller)
// client sends message
ASSERT_NO_THROW(s.client.send("Hi"));
// wait for message and verify it is received
a.wait(std::chrono::milliseconds{500});
ASSERT_EQ(1, a.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(1u, count);
// Move construct poller b
zmq::poller_t b{std::move(a)};
// client sends message again
ASSERT_NO_THROW(s.client.send("Hi"));
// wait for message and verify it is received
b.wait(std::chrono::milliseconds{500});
ASSERT_EQ(1, b.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(2u, count);
}

Expand All @@ -424,12 +424,14 @@ TEST(poller, remove_from_handler)

// Setup poller
zmq::poller_t poller;
int count = 0;
for (auto i = 0; i < ITER_NO; ++i) {
ASSERT_NO_THROW(poller.add(setup_list[i].server, ZMQ_POLLIN, [&,i](short events) {
ASSERT_EQ(events, ZMQ_POLLIN);
poller.remove(setup_list[ITER_NO-i-1].server);
ASSERT_EQ(ITER_NO-i-1, poller.size());
}));
++count;
}
ASSERT_EQ(ITER_NO, poller.size());
// Clients send messages
Expand All @@ -444,8 +446,8 @@ TEST(poller, remove_from_handler)
}

// Fire all handlers in one wait
int count = poller.wait (std::chrono::milliseconds{-1});
ASSERT_EQ(count, ITER_NO);
ASSERT_EQ(ITER_NO, poller.wait (std::chrono::milliseconds{-1}));
ASSERT_EQ(ITER_NO, count);
}

#endif
133 changes: 87 additions & 46 deletions zmq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,6 @@ namespace zmq
class socket_t
{
friend class monitor_t;
friend class poller_t;
public:
inline socket_t(context_t& context_, int type_)
{
Expand Down Expand Up @@ -1019,6 +1018,67 @@ namespace zmq
};

#if defined(ZMQ_BUILD_DRAFT_API) && defined(ZMQ_CPP11) && defined(ZMQ_HAVE_POLLER)
template <typename T = void>
class base_poller_t
{
public:
void add (zmq::socket_t &socket, short events, T *user_data)
{
if (0 != zmq_poller_add (poller_ptr.get (), static_cast<void*>(socket), user_data, events))
{
throw error_t ();
}
}

void remove (zmq::socket_t &socket)
{
if (0 != zmq_poller_remove (poller_ptr.get (), static_cast<void*>(socket)))
{
throw error_t ();
}
}

void modify (zmq::socket_t &socket, short events)
{
if (0 != zmq_poller_modify (poller_ptr.get (), static_cast<void*>(socket), events))
{
throw error_t ();
}
}

int wait_all (std::vector<zmq_poller_event_t> &poller_events, const std::chrono::microseconds timeout)
{
int rc = zmq_poller_wait_all (poller_ptr.get (), poller_events.data (),
static_cast<int> (poller_events.size ()),
static_cast<long>(timeout.count ()));
if (rc > 0)
return rc;

#if ZMQ_VERSION >= ZMQ_MAKE_VERSION(4, 2, 3)
if (zmq_errno () == EAGAIN)
#else
if (zmq_errno () == ETIMEDOUT)
#endif
return 0;

throw error_t ();
}
private:
std::unique_ptr<void, std::function<void(void*)>> poller_ptr
{
[]() {
auto poller_new = zmq_poller_new ();
if (poller_new)
return poller_new;
throw error_t ();
}(),
[](void *ptr) {
int rc = zmq_poller_destroy (&ptr);
ZMQ_ASSERT (rc == 0);
}
};
};

class poller_t
{
public:
Expand All @@ -1035,33 +1095,35 @@ namespace zmq

void add (zmq::socket_t &socket, short events, handler_t handler)
{
auto it = std::end (handlers);
auto inserted = false;
std::tie(it, inserted) = handlers.emplace (socket.ptr, std::make_shared<handler_t> (std::move (handler)));
if (0 == zmq_poller_add (poller_ptr.get (), socket.ptr, inserted && *(it->second) ? it->second.get() : nullptr, events)) {
need_rebuild = true;
return;
auto it = decltype (handlers)::iterator {};
auto inserted = bool {};
std::tie(it, inserted) = handlers.emplace (static_cast<void*>(socket), std::make_shared<handler_t> (std::move (handler)));
try
{
base_poller.add (socket, events, inserted && *(it->second) ? it->second.get() : nullptr);
need_rebuild |= inserted;
}
catch (const zmq::error_t&)
{
// rollback
if (inserted)
{
handlers.erase (static_cast<void*>(socket));
}
throw;
}
// rollback
if (inserted)
handlers.erase (socket.ptr);
throw error_t ();
}

void remove (zmq::socket_t &socket)
{
if (0 == zmq_poller_remove (poller_ptr.get (), socket.ptr)) {
handlers.erase (socket.ptr);
need_rebuild = true;
return;
}
throw error_t ();
base_poller.remove (socket);
handlers.erase (static_cast<void*>(socket));
need_rebuild = true;
}

void modify (zmq::socket_t &socket, short events)
{
if (0 != zmq_poller_modify (poller_ptr.get (), socket.ptr, events))
throw error_t ();
base_poller.modify (socket, events);
}

int wait (std::chrono::milliseconds timeout)
Expand All @@ -1077,25 +1139,15 @@ namespace zmq
}
need_rebuild = false;
}
int rc = zmq_poller_wait_all (poller_ptr.get (), poller_events.data (),
static_cast<int> (poller_events.size ()),
static_cast<long>(timeout.count ()));
if (rc > 0) {
std::for_each (poller_events.begin (), poller_events.begin () + rc,
const int count = base_poller.wait_all (poller_events, timeout);
if (count != 0) {
std::for_each (poller_events.begin (), poller_events.begin () + count,
[](zmq_poller_event_t& event) {
if (event.user_data != NULL)
(*reinterpret_cast<handler_t*> (event.user_data)) (event.events);
});
return rc;
}
#if ZMQ_VERSION >= ZMQ_MAKE_VERSION(4, 2, 3)
if (zmq_errno () == EAGAIN)
#else
if (zmq_errno () == ETIMEDOUT)
#endif
return 0;

throw error_t ();
return count;
}

bool empty () const
Expand All @@ -1109,20 +1161,9 @@ namespace zmq
}

private:
std::unique_ptr<void, std::function<void(void*)>> poller_ptr
{
[]() {
auto poller_new = zmq_poller_new ();
if (poller_new)
return poller_new;
throw error_t ();
}(),
[](void *ptr) {
int rc = zmq_poller_destroy (&ptr);
ZMQ_ASSERT (rc == 0);
}
};
bool need_rebuild {false};

base_poller_t<handler_t> base_poller {};
std::unordered_map<void*, std::shared_ptr<handler_t>> handlers {};
std::vector<zmq_poller_event_t> poller_events {};
std::vector<std::shared_ptr<handler_t>> poller_handlers {};
Expand Down

0 comments on commit bf47be0

Please sign in to comment.