Skip to content

Commit

Permalink
Change recv and send to return optional types
Browse files Browse the repository at this point in the history
  • Loading branch information
gummif committed May 10, 2019
1 parent bbba565 commit 88cee88
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 73 deletions.
63 changes: 31 additions & 32 deletions tests/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ TEST_CASE("socket sends and receives const buffer", "[socket]")
const char* str = "Hi";

#ifdef ZMQ_CPP11
CHECK(2 == sender.send(zmq::buffer(str, 2)).size);
CHECK(2 == *sender.send(zmq::buffer(str, 2)));
char buf[2];
const auto res = receiver.recv(zmq::buffer(buf));
CHECK(!res.truncated());
CHECK(2 == res.size);
CHECK(res);
CHECK(!res->truncated());
CHECK(2 == res->size);
#else
CHECK(2 == sender.send(str, 2));
char buf[2];
Expand All @@ -109,11 +110,11 @@ TEST_CASE("socket send none sndmore", "[socket]")

std::vector<char> buf(4);
auto res = s.send(zmq::buffer(buf), zmq::send_flags::sndmore);
CHECK(res.size == buf.size());
CHECK(res.success);
CHECK(res);
CHECK(*res == buf.size());
res = s.send(zmq::buffer(buf));
CHECK(res.size == buf.size());
CHECK(res.success);
CHECK(res);
CHECK(*res == buf.size());
}

TEST_CASE("socket send dontwait", "[socket]")
Expand All @@ -124,17 +125,14 @@ TEST_CASE("socket send dontwait", "[socket]")

std::vector<char> buf(4);
auto res = s.send(zmq::buffer(buf), zmq::send_flags::dontwait);
CHECK(!res.success);
CHECK(res.size == 0);
CHECK(!res);
res = s.send(zmq::buffer(buf),
zmq::send_flags::dontwait | zmq::send_flags::sndmore);
CHECK(!res.success);
CHECK(res.size == 0);
CHECK(!res);

zmq::message_t msg;
auto resm = s.send(msg, zmq::send_flags::dontwait);
CHECK(!resm.success);
CHECK(resm.size == 0);
CHECK(!resm);
CHECK(msg.size() == 0);
}

Expand All @@ -158,23 +156,24 @@ TEST_CASE("socket recv none", "[socket]")

std::vector<char> sbuf(4);
const auto res_send = s2.send(zmq::buffer(sbuf));
CHECK(res_send.success);
CHECK(res_send);
CHECK(res_send.has_value());

std::vector<char> buf(2);
const auto res = s.recv(zmq::buffer(buf));
CHECK(res.success);
CHECK(res.truncated());
CHECK(res.untruncated_size == sbuf.size());
CHECK(res.size == buf.size());
CHECK(res.has_value());
CHECK(res->truncated());
CHECK(res->untruncated_size == sbuf.size());
CHECK(res->size == buf.size());

const auto res_send2 = s2.send(zmq::buffer(sbuf));
CHECK(res_send2.success);
CHECK(res_send2.has_value());
std::vector<char> buf2(10);
const auto res2 = s.recv(zmq::buffer(buf2));
CHECK(res2.success);
CHECK(!res2.truncated());
CHECK(res2.untruncated_size == sbuf.size());
CHECK(res2.size == sbuf.size());
CHECK(res2.has_value());
CHECK(!res2->truncated());
CHECK(res2->untruncated_size == sbuf.size());
CHECK(res2->size == sbuf.size());
}

TEST_CASE("socket send recv message_t", "[socket]")
Expand All @@ -187,15 +186,16 @@ TEST_CASE("socket send recv message_t", "[socket]")

zmq::message_t smsg(size_t{10});
const auto res_send = s2.send(smsg, zmq::send_flags::none);
CHECK(res_send.success);
CHECK(res_send.size == 10);
CHECK(res_send);
CHECK(*res_send == 10);
CHECK(smsg.size() == 0);

zmq::message_t rmsg;
const auto res = s.recv(rmsg);
CHECK(res.success);
CHECK(res.size == 10);
CHECK(rmsg.size() == res.size);
CHECK(res);
CHECK(*res == 10);
CHECK(res.value() == 10);
CHECK(rmsg.size() == *res);
}

TEST_CASE("socket recv dontwait", "[socket]")
Expand All @@ -207,13 +207,12 @@ TEST_CASE("socket recv dontwait", "[socket]")
std::vector<char> buf(4);
constexpr auto flags = zmq::recv_flags::none | zmq::recv_flags::dontwait;
auto res = s.recv(zmq::buffer(buf), flags);
CHECK(!res.success);
CHECK(res.size == 0);
CHECK(!res);

zmq::message_t msg;
auto resm = s.recv(msg, flags);
CHECK(!resm.success);
CHECK(resm.size == 0);
CHECK(!resm);
CHECK_THROWS_AS(resm.value(), const std::exception &);
CHECK(msg.size() == 0);
}

Expand Down
132 changes: 93 additions & 39 deletions zmq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,21 @@
#include <cassert>
#include <cstring>

#ifdef ZMQ_CPP11
#include <array>
#endif
#include <algorithm>
#include <exception>
#include <iomanip>
#include <iterator>
#include <sstream>
#include <string>
#include <vector>
#ifdef ZMQ_CPP11
#include <array>
#include <chrono>
#include <tuple>
#include <memory>
#endif
#ifdef ZMQ_CPP17
#include <optional>
#endif

/* Version macros for compile-time API version detection */
#define CPPZMQ_VERSION_MAJOR 4
Expand All @@ -92,12 +97,6 @@
ZMQ_MAKE_VERSION(CPPZMQ_VERSION_MAJOR, CPPZMQ_VERSION_MINOR, \
CPPZMQ_VERSION_PATCH)

#ifdef ZMQ_CPP11
#include <chrono>
#include <tuple>
#include <memory>
#endif

// Detect whether the compiler supports C++11 rvalue references.
#if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) \
&& defined(__GXX_EXPERIMENTAL_CXX0X__))
Expand Down Expand Up @@ -621,23 +620,11 @@ inline void swap(context_t &a, context_t &b) ZMQ_NOTHROW {
}

#ifdef ZMQ_CPP11
struct send_result
{
size_t size; // message size in bytes
bool success;
};

struct recv_result
{
size_t size; // message size in bytes
bool success;
};

struct recv_buffer_result
struct recv_buffer_size
{
size_t size; // number of bytes written to buffer
size_t untruncated_size; // untruncated message size in bytes
bool success;

ZMQ_NODISCARD bool truncated() const noexcept
{
Expand All @@ -647,6 +634,71 @@ struct recv_buffer_result

namespace detail
{

#ifdef ZMQ_CPP17
using send_result_t = std::optional<size_t>;
using recv_result_t = std::optional<size_t>;
using recv_buffer_result_t = std::optional<recv_buffer_size>;
#else
// A C++11 type emulating the most basic
// operations of std::optional for trivial types
template<class T> class trivial_optional
{
public:
static_assert(std::is_trivial<T>::value, "T must be trivial");
using value_type = T;

trivial_optional() = default;
trivial_optional(T value) noexcept : _value(value), _has_value(true) {}

const T *operator->() const noexcept
{
assert(_has_value);
return &_value;
}
T *operator->() noexcept
{
assert(_has_value);
return &_value;
}

const T &operator*() const noexcept
{
assert(_has_value);
return _value;
}
T &operator*() noexcept
{
assert(_has_value);
return _value;
}

T &value()
{
if (!_has_value)
throw std::exception();
return _value;
}
const T &value() const
{
if (!_has_value)
throw std::exception();
return _value;
}

explicit operator bool() const noexcept { return _has_value; }
bool has_value() const noexcept { return _has_value; }

private:
T _value{};
bool _has_value{false};
};

using send_result_t = trivial_optional<size_t>;
using recv_result_t = trivial_optional<size_t>;
using recv_buffer_result_t = trivial_optional<recv_buffer_size>;
#endif

template<class T>
constexpr T enum_bit_or(T a, T b) noexcept
{
Expand Down Expand Up @@ -1111,36 +1163,36 @@ class socket_base
int flags_ = 0) // default until removed
{
#ifdef ZMQ_CPP11
return send(msg_, static_cast<send_flags>(flags_)).success;
return send(msg_, static_cast<send_flags>(flags_)).has_value();
#else
return send(msg_, flags_);
#endif
}
#endif

#ifdef ZMQ_CPP11
send_result send(const_buffer buf, send_flags flags = send_flags::none)
detail::send_result_t send(const_buffer buf, send_flags flags = send_flags::none)
{
const int nbytes =
zmq_send(_handle, buf.data(), buf.size(), static_cast<int>(flags));
if (nbytes >= 0)
return {static_cast<size_t>(nbytes), true};
return static_cast<size_t>(nbytes);
if (zmq_errno() == EAGAIN)
return {size_t{0}, false};
return {};
throw error_t();
}

send_result send(message_t &msg, send_flags flags)
detail::send_result_t send(message_t &msg, send_flags flags)
{
int nbytes = zmq_msg_send(msg.handle(), _handle, static_cast<int>(flags));
if (nbytes >= 0)
return {static_cast<size_t>(nbytes), true};
return static_cast<size_t>(nbytes);
if (zmq_errno() == EAGAIN)
return {size_t{0}, false};
return {};
throw error_t();
}

send_result send(message_t &&msg, send_flags flags)
detail::send_result_t send(message_t &&msg, send_flags flags)
{
return send(msg, flags);
}
Expand Down Expand Up @@ -1177,27 +1229,29 @@ class socket_base
}

#ifdef ZMQ_CPP11
recv_buffer_result recv(mutable_buffer buf, recv_flags flags = recv_flags::none)
detail::recv_buffer_result_t recv(mutable_buffer buf,
recv_flags flags = recv_flags::none)
{
const int nbytes =
zmq_recv(_handle, buf.data(), buf.size(), static_cast<int>(flags));
if (nbytes >= 0)
return {(std::min)(static_cast<size_t>(nbytes), buf.size()),
static_cast<size_t>(nbytes), true};
if (nbytes >= 0) {
return recv_buffer_size{(std::min)(static_cast<size_t>(nbytes), buf.size()),
static_cast<size_t>(nbytes)};
}
if (zmq_errno() == EAGAIN)
return {size_t{0}, size_t{0}, false};
return {};
throw error_t();
}

recv_result recv(message_t &msg, recv_flags flags = recv_flags::none)
detail::recv_result_t recv(message_t &msg, recv_flags flags = recv_flags::none)
{
const int nbytes = zmq_msg_recv(msg.handle(), _handle, static_cast<int>(flags));
if (nbytes >= 0) {
assert(msg.size() == static_cast<size_t>(nbytes));
return {static_cast<size_t>(nbytes), true};
return static_cast<size_t>(nbytes);
}
if (zmq_errno() == EAGAIN)
return {size_t{0}, false};
return {};
throw error_t();
}
#endif
Expand Down
4 changes: 2 additions & 2 deletions zmq_addon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class multipart_t
while (more) {
message_t message;
#ifdef ZMQ_CPP11
if (!socket.recv(message, static_cast<recv_flags>(flags)).success)
if (!socket.recv(message, static_cast<recv_flags>(flags)))
return false;
#else
if (!socket.recv(&message, flags))
Expand All @@ -153,7 +153,7 @@ class multipart_t
more = size() > 0;
#ifdef ZMQ_CPP11
if (!socket.send(message,
static_cast<send_flags>((more ? ZMQ_SNDMORE : 0) | flags)).success)
static_cast<send_flags>((more ? ZMQ_SNDMORE : 0) | flags)))
return false;
#else
if (!socket.send(message, (more ? ZMQ_SNDMORE : 0) | flags))
Expand Down

0 comments on commit 88cee88

Please sign in to comment.