Skip to content

Commit

Permalink
OpenSslEngine wrap may generate bad data if multiple src buffers
Browse files Browse the repository at this point in the history
Motivation:
SSL_write requires a fixed amount of bytes for overhead related to the encryption process for each call. OpenSslEngine#wrap(..) will attempt to encrypt multiple input buffers until MAX_PLAINTEXT_LENGTH are consumed, but the size estimation provided by calculateOutNetBufSize may not leave enough room for each call to SSL_write. If SSL_write is not able to completely write results to the destination buffer it will keep state and attempt to write it later. Netty doesn't account for SSL_write keeping state and assumes all writes will complete synchronously (by attempting to allocate enough space to account for the overhead) and feeds the same data to SSL_write again later which results in corrupted data being generated.

Modifications:
- OpenSslEngine#wrap should only produce a single TLS packet according to the SSLEngine API specificaiton [1].
[1] https://docs.oracle.com/javase/8/docs/api/javax/net/ssl/SSLEngine.html#wrap-java.nio.ByteBuffer:A-int-int-java.nio.ByteBuffer-
- OpenSslEngine#wrap should only consider a single buffer when determining if there is enough space to write, because only a single buffer will ever be consumed.

Result:
OpenSslEngine#wrap will no longer produce corrupted data due to incorrect accounting of space required in the destination buffers.
  • Loading branch information
Scottmitch committed May 8, 2017
1 parent cd80b6c commit 1410899
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -610,9 +610,9 @@ public final SSLEngineResult wrap(
}
}

if (dst.remaining() < calculateOutNetBufSize(srcsLen, endOffset - offset)) {
// Can not hold the maximum packet so we need to tell the caller to use a bigger destination
// buffer.
// we will only produce a single TLS packet, and we don't aggregate src buffers,
// so we always fix the number of buffers to 1 when checking if the dst buffer is large enough.
if (dst.remaining() < calculateOutNetBufSize(srcsLen, 1)) {
return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0);
}

Expand All @@ -638,9 +638,7 @@ public final SSLEngineResult wrap(
bytesProduced += bioLengthBefore - pendingNow;
bioLengthBefore = pendingNow;

if (bytesConsumed == MAX_PLAINTEXT_LENGTH || bytesProduced == dst.remaining()) {
return newResultMayFinishHandshake(status, bytesConsumed, bytesProduced);
}
return newResultMayFinishHandshake(status, bytesConsumed, bytesProduced);
} else {
int sslError = SSL.getError(ssl, bytesWritten);
if (sslError == SSL.SSL_ERROR_ZERO_RETURN) {
Expand Down
10 changes: 5 additions & 5 deletions handler/src/main/java/io/netty/handler/ssl/SslHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len,
}

@Override
int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) {
int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents) {
return ReferenceCountedOpenSslEngine.calculateOutNetBufSize(pendingBytes, numComponents);
}
},
Expand Down Expand Up @@ -242,7 +242,7 @@ SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len,
}

@Override
int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) {
int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents) {
return ((ConscryptAlpnSslEngine) handler.engine).calculateOutNetBufSize(pendingBytes, numComponents);
}
},
Expand All @@ -258,7 +258,7 @@ SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len,
}

@Override
int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents) {
int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents) {
return handler.maxPacketBufferSize;
}
};
Expand All @@ -281,7 +281,7 @@ static SslEngineType forEngine(SSLEngine engine) {
abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out)
throws SSLException;

abstract int calculateOutNetBufSize(SslHandler handler, int pendingBytes, int numComponents);
abstract int calculateWrapBufferCapacity(SslHandler handler, int pendingBytes, int numComponents);

// BEGIN Platform-dependent flags

Expand Down Expand Up @@ -1719,7 +1719,7 @@ private ByteBuf allocate(ChannelHandlerContext ctx, int capacity) {
* the specified amount of pending bytes.
*/
private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes, int numComponents) {
return allocate(ctx, engineType.calculateOutNetBufSize(this, pendingBytes, numComponents));
return allocate(ctx, engineType.calculateWrapBufferCapacity(this, pendingBytes, numComponents));
}

private final class LazyChannelPromise extends DefaultPromise<Channel> {
Expand Down
110 changes: 64 additions & 46 deletions handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,56 +16,39 @@

package io.netty.handler.ssl;

import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.*;
import static org.junit.Assume.assumeTrue;

import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLProtocolException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.CodecException;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.UnsupportedMessageTypeException;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.handler.ssl.util.SimpleTrustManagerFactory;
import io.netty.util.IllegalReferenceCountException;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.PromiseNotifier;
import io.netty.util.internal.EmptyArrays;
import org.junit.Test;

import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.UnsupportedMessageTypeException;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;

import java.io.File;
import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException;
Expand All @@ -76,6 +59,23 @@
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLProtocolException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;

import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.Assume.assumeTrue;

public class SslHandlerTest {

Expand Down Expand Up @@ -566,7 +566,7 @@ public void operationComplete(Future<Channel> future) {
}
}

@Test(timeout = 30000)
@Test(timeout = 300000)
public void testCompositeBufSizeEstimationGuaranteesSynchronousWrite()
throws CertificateException, SSLException, ExecutionException, InterruptedException {
SslProvider[] providers = SslProvider.values();
Expand All @@ -576,15 +576,23 @@ public void testCompositeBufSizeEstimationGuaranteesSynchronousWrite()
for (int j = 0; j < providers.length; ++j) {
SslProvider clientProvider = providers[j];
if (isSupported(clientProvider)) {
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider);
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider,
true, true);
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider,
true, false);
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider,
false, true);
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider,
false, false);
}
}
}
}
}

private static void compositeBufSizeEstimationGuaranteesSynchronousWrite(
SslProvider serverProvider, SslProvider clientProvider)
SslProvider serverProvider, SslProvider clientProvider,
final boolean letHandlerCreateServerEngine, final boolean letHandlerCreateClientEngine)
throws CertificateException, SSLException, ExecutionException, InterruptedException {
SelfSignedCertificate ssc = new SelfSignedCertificate();

Expand All @@ -601,33 +609,39 @@ private static void compositeBufSizeEstimationGuaranteesSynchronousWrite(
Channel cc = null;
try {
final Promise<Void> donePromise = group.next().newPromise();
final int expectedBytes = 469 + 1024 + 1024;
// The goal is to provide the SSLEngine with many ByteBuf components to ensure that the overhead for wrap
// is correctly accounted for on each component.
final int numComponents = 150;
// This is the TLS packet size. The goal is to divide the maximum amount of application data that can fit
// into a single TLS packet into many components to ensure the overhead is correctly taken into account.
final int desiredBytes = 16384;
final int singleComponentSize = desiredBytes / numComponents;
final int expectedBytes = numComponents * singleComponentSize;

sc = new ServerBootstrap()
.group(group)
.channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc()));
if (letHandlerCreateServerEngine) {
ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc()));
} else {
ch.pipeline().addLast(new SslHandler(sslServerCtx.newEngine(ch.alloc())));
}
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SslHandshakeCompletionEvent) {
SslHandshakeCompletionEvent sslEvt = (SslHandshakeCompletionEvent) evt;
if (sslEvt.isSuccess()) {
final ByteBuf input = ctx.alloc().buffer();
input.writeBytes(new byte[expectedBytes]);
CompositeByteBuf content = ctx.alloc().compositeBuffer();
content.addComponent(true, input.readRetainedSlice(469));
content.addComponent(true, input.readRetainedSlice(1024));
content.addComponent(true, input.readRetainedSlice(1024));
ctx.writeAndFlush(content).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
input.release();
}
});
CompositeByteBuf content = ctx.alloc().compositeDirectBuffer(numComponents);
for (int i = 0; i < numComponents; ++i) {
ByteBuf buf = ctx.alloc().directBuffer(singleComponentSize);
buf.writerIndex(buf.writerIndex() + singleComponentSize);
content.addComponent(true, buf);
}
ctx.writeAndFlush(content);
} else {
donePromise.tryFailure(sslEvt.cause());
}
Expand All @@ -654,7 +668,11 @@ public void channelInactive(ChannelHandlerContext ctx) {
.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc()));
if (letHandlerCreateClientEngine) {
ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc()));
} else {
ch.pipeline().addLast(new SslHandler(sslClientCtx.newEngine(ch.alloc())));
}
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
private int bytesSeen;
@Override
Expand Down

0 comments on commit 1410899

Please sign in to comment.