Skip to content

Commit

Permalink
Correctly handle ipv6 mapped ipv4 addresses when using recvmmsg (nett…
Browse files Browse the repository at this point in the history
…y#9541)


Motivation:

394a1b3 introduced the possibility to use recvmmsg(...) but did not correctly handle ipv6 mapped ip4 addresses to make it consistent with other transports.

Modifications:

- Correctly handle ipv6 mapped ipv4 addresses by only copy over the relevant bytes
- Small improvement on how to detect ipv6 mapped ipv4 addresses by using memcmp and not byte by byte compare
- Adjust test to cover this bug

Result:

Correctly handle ipv6 mapped ipv4 addresses
  • Loading branch information
normanmaurer authored Sep 6, 2019
1 parent 768a825 commit 6fc7c58
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.junit.Test;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.channels.NotYetConnectedException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -157,14 +158,19 @@ public void channelRead0(ChannelHandlerContext ctx, Object msgs) throws Exceptio
}
});

final CountDownLatch latch = new CountDownLatch(count);
sc = setupServerChannel(sb, bytes, latch, false);
final SocketAddress sender;
if (bindClient) {
cc = cb.bind(newSocketAddress()).sync().channel();
sender = cc.localAddress();
} else {
cb.option(ChannelOption.DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION, true);
cc = cb.register().sync().channel();
sender = null;
}

final CountDownLatch latch = new CountDownLatch(count);
sc = setupServerChannel(sb, bytes, sender, latch, false);

InetSocketAddress addr = (InetSocketAddress) sc.localAddress();
for (int i = 0; i < count; i++) {
switch (wrapType) {
Expand Down Expand Up @@ -231,8 +237,10 @@ public void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) throws E
DatagramChannel cc = null;
try {
final CountDownLatch latch = new CountDownLatch(count);
sc = setupServerChannel(sb, bytes, latch, true);
cc = (DatagramChannel) cb.connect(sc.localAddress()).sync().channel();
cc = (DatagramChannel) cb.bind(newSocketAddress()).sync().channel();
sc = setupServerChannel(sb, bytes, cc.localAddress(), latch, true);

cc.connect(sc.localAddress()).syncUninterruptibly();

for (int i = 0; i < count; i++) {
switch (wrapType) {
Expand Down Expand Up @@ -275,14 +283,21 @@ public void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) throws E
}

@SuppressWarnings("deprecation")
private Channel setupServerChannel(Bootstrap sb, final byte[] bytes, final CountDownLatch latch, final boolean echo)
private Channel setupServerChannel(Bootstrap sb, final byte[] bytes, final SocketAddress sender,
final CountDownLatch latch, final boolean echo)
throws Throwable {
sb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(new SimpleChannelInboundHandler<DatagramPacket>() {
@Override
public void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) throws Exception {
if (sender == null) {
assertNotNull(msg.sender());
} else {
assertEquals(sender, msg.sender());
}

ByteBuf buf = msg.content();
assertEquals(bytes.length, buf.readableBytes());
for (int i = 0; i < bytes.length; i++) {
Expand Down
18 changes: 17 additions & 1 deletion transport-native-epoll/src/main/c/netty_epoll_native.c
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ struct mmsghdr {

// Those are initialized in the init(...) method and cached for performance reasons
static jfieldID packetAddrFieldId = NULL;
static jfieldID packetAddrLenFieldId = NULL;
static jfieldID packetScopeIdFieldId = NULL;
static jfieldID packetPortFieldId = NULL;
static jfieldID packetMemoryAddressFieldId = NULL;
Expand Down Expand Up @@ -386,12 +387,20 @@ static jint netty_epoll_native_recvmmsg0(JNIEnv* env, jclass clazz, jint fd, jbo
struct sockaddr_in* ipaddr = (struct sockaddr_in*) addr;

(*env)->SetByteArrayRegion(env, address, 0, 4, (jbyte*) &ipaddr->sin_addr.s_addr);
(*env)->SetIntField(env, packet, packetAddrLenFieldId, 4);
(*env)->SetIntField(env, packet, packetScopeIdFieldId, 0);
(*env)->SetIntField(env, packet, packetPortFieldId, ntohs(ipaddr->sin_port));
} else {
int addrLen = netty_unix_socket_ipAddressLength(addr);
struct sockaddr_in6* ip6addr = (struct sockaddr_in6*) addr;

(*env)->SetByteArrayRegion(env, address, 0, 16, (jbyte*) &ip6addr->sin6_addr.s6_addr);
if (addrLen == 4) {
// IPV4 mapped IPV6 address
(*env)->SetByteArrayRegion(env, address, 12, 4, (jbyte*) &ip6addr->sin6_addr.s6_addr);
} else {
(*env)->SetByteArrayRegion(env, address, 0, 16, (jbyte*) &ip6addr->sin6_addr.s6_addr);
}
(*env)->SetIntField(env, packet, packetAddrLenFieldId, addrLen);
(*env)->SetIntField(env, packet, packetScopeIdFieldId, ip6addr->sin6_scope_id);
(*env)->SetIntField(env, packet, packetPortFieldId, ntohs(ip6addr->sin6_port));
}
Expand Down Expand Up @@ -631,6 +640,11 @@ static jint netty_epoll_native_JNI_OnLoad(JNIEnv* env, const char* packagePrefix
netty_unix_errors_throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.addr");
goto error;
}
packetAddrLenFieldId = (*env)->GetFieldID(env, nativeDatagramPacketCls, "addrLen", "I");
if (packetAddrLenFieldId == NULL) {
netty_unix_errors_throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.addrLen");
goto error;
}
packetScopeIdFieldId = (*env)->GetFieldID(env, nativeDatagramPacketCls, "scopeId", "I");
if (packetScopeIdFieldId == NULL) {
netty_unix_errors_throwRuntimeException(env, "failed to get field ID: NativeDatagramPacket.scopeId");
Expand Down Expand Up @@ -675,6 +689,7 @@ static jint netty_epoll_native_JNI_OnLoad(JNIEnv* env, const char* packagePrefix
netty_epoll_linuxsocket_JNI_OnUnLoad(env);
}
packetAddrFieldId = NULL;
packetAddrLenFieldId = NULL;
packetScopeIdFieldId = NULL;
packetPortFieldId = NULL;
packetMemoryAddressFieldId = NULL;
Expand All @@ -692,6 +707,7 @@ static void netty_epoll_native_JNI_OnUnLoad(JNIEnv* env) {
netty_epoll_linuxsocket_JNI_OnUnLoad(env);

packetAddrFieldId = NULL;
packetAddrLenFieldId = NULL;
packetScopeIdFieldId = NULL;
packetPortFieldId = NULL;
packetMemoryAddressFieldId = NULL;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ final class NativeDatagramPacketArray implements ChannelOutboundBuffer.MessagePr
// We share one IovArray for all NativeDatagramPackets to reduce memory overhead. This will allow us to write
// up to IOV_MAX iovec across all messages in one sendmmsg(...) call.
private final IovArray iovArray = new IovArray();

// temporary array to copy the ipv4 part of ipv6-mapped-ipv4 addresses and then create a Inet4Address out of it.
private final byte[] ipv4Bytes = new byte[4];

private int count;

NativeDatagramPacketArray() {
Expand Down Expand Up @@ -110,13 +114,14 @@ void release() {
* Used to pass needed data to JNI.
*/
@SuppressWarnings("unused")
static final class NativeDatagramPacket {
final class NativeDatagramPacket {

// This is the actual struct iovec*
private long memoryAddress;
private int count;

private final byte[] addr = new byte[16];

private int addrLen;
private int scopeId;
private int port;
Expand Down Expand Up @@ -145,10 +150,11 @@ private void init(long memoryAddress, int count, InetSocketAddress recipient) {

DatagramPacket newDatagramPacket(ByteBuf buffer, InetSocketAddress localAddress) throws UnknownHostException {
final InetAddress address;
if (scopeId != 0) {
address = Inet6Address.getByAddress(null, addr, scopeId);
if (addrLen == ipv4Bytes.length) {
System.arraycopy(addr, 0, ipv4Bytes, 0, addrLen);
address = InetAddress.getByAddress(ipv4Bytes);
} else {
address = InetAddress.getByAddress(addr);
address = Inet6Address.getByAddress(null, addr, scopeId);
}
return new DatagramPacket(buffer.writerIndex(count),
localAddress, new InetSocketAddress(address, port));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.junit.Test;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
Expand Down Expand Up @@ -82,6 +83,9 @@ public void channelRead0(ChannelHandlerContext ctx, Object msgs) throws Exceptio
// Nothing will be sent.
}
});
cc = cb.bind(newSocketAddress()).sync().channel();
final SocketAddress ccAddress = cc.localAddress();

final AtomicReference<Throwable> errorRef = new AtomicReference<Throwable>();
final byte[] bytes = new byte[packetSize];
PlatformDependent.threadLocalRandom().nextBytes(bytes);
Expand All @@ -98,6 +102,8 @@ public void channelReadComplete(ChannelHandlerContext ctx) {

@Override
protected void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) {
assertEquals(ccAddress, msg.sender());

assertEquals(bytes.length, msg.content().readableBytes());
byte[] receivedBytes = new byte[bytes.length];
msg.content().readBytes(receivedBytes);
Expand All @@ -115,7 +121,6 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {

sb.option(ChannelOption.AUTO_READ, false);
sc = sb.bind(newSocketAddress()).sync().channel();
cc = cb.bind(newSocketAddress()).sync().channel();

InetSocketAddress addr = (InetSocketAddress) sc.localAddress();

Expand Down
13 changes: 6 additions & 7 deletions transport-native-unix-common/src/main/c/netty_unix_socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ static jclass inetSocketAddressClass = NULL;
static int socketType = AF_INET;
static const unsigned char wildcardAddress[] = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 };
static const unsigned char ipv4MappedWildcardAddress[] = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00 };
static const unsigned char ipv4MappedAddress[] = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff };

// Optional external methods
extern int accept4(int sockFd, struct sockaddr* addr, socklen_t* addrlen, int flags) __attribute__((weak)) __attribute__((weak_import));
Expand All @@ -67,15 +68,13 @@ static int nettyNonBlockingSocket(int domain, int type, int protocol) {
#endif
}

static int ipAddressLength(const struct sockaddr_storage* addr) {
int netty_unix_socket_ipAddressLength(const struct sockaddr_storage* addr) {
if (addr->ss_family == AF_INET) {
return 4;
}
struct sockaddr_in6* s = (struct sockaddr_in6*) addr;
if (s->sin6_addr.s6_addr[11] == 0xff && s->sin6_addr.s6_addr[10] == 0xff &&
s->sin6_addr.s6_addr[9] == 0x00 && s->sin6_addr.s6_addr[8] == 0x00 && s->sin6_addr.s6_addr[7] == 0x00 && s->sin6_addr.s6_addr[6] == 0x00 && s->sin6_addr.s6_addr[5] == 0x00 &&
s->sin6_addr.s6_addr[4] == 0x00 && s->sin6_addr.s6_addr[3] == 0x00 && s->sin6_addr.s6_addr[2] == 0x00 && s->sin6_addr.s6_addr[1] == 0x00 && s->sin6_addr.s6_addr[0] == 0x00) {
// IPv4-mapped-on-IPv6
if (memcmp(s->sin6_addr.s6_addr, ipv4MappedAddress, 12) == 0) {
// IPv4-mapped-on-IPv6
return 4;
}
return 16;
Expand All @@ -84,7 +83,7 @@ static int ipAddressLength(const struct sockaddr_storage* addr) {
static jobject createDatagramSocketAddress(JNIEnv* env, const struct sockaddr_storage* addr, int len, jobject local) {
int port;
int scopeId;
int ipLength = ipAddressLength(addr);
int ipLength = netty_unix_socket_ipAddressLength(addr);
jbyteArray addressBytes = (*env)->NewByteArray(env, ipLength);

if (addr->ss_family == AF_INET) {
Expand Down Expand Up @@ -114,7 +113,7 @@ static jobject createDatagramSocketAddress(JNIEnv* env, const struct sockaddr_st
}

static jsize addressLength(const struct sockaddr_storage* addr) {
int len = ipAddressLength(addr);
int len = netty_unix_socket_ipAddressLength(addr);
if (len == 4) {
// Only encode port into it
return len + 4;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
int netty_unix_socket_initSockaddr(JNIEnv* env, jboolean ipv6, jbyteArray address, jint scopeId, jint jport, const struct sockaddr_storage* addr, socklen_t* addrSize);
int netty_unix_socket_getOption(JNIEnv* env, jint fd, int level, int optname, void* optval, socklen_t optlen);
int netty_unix_socket_setOption(JNIEnv* env, jint fd, int level, int optname, const void* optval, socklen_t len);
int netty_unix_socket_ipAddressLength(const struct sockaddr_storage* addr);

// These method is sometimes needed if you want to special handle some errno value before throwing an exception.
int netty_unix_socket_getOption0(jint fd, int level, int optname, void* optval, socklen_t optlen);
Expand Down

0 comments on commit 6fc7c58

Please sign in to comment.