Skip to content

Commit

Permalink
Add startTls parameter to SslContextBuilder
Browse files Browse the repository at this point in the history
Motivation:

There is an incoherence in terms of API when one wants to use
startTls: without startTls one can use the SslContextBuilder's
method newHandler, but with startTls, the developper is forced
to call directly the SslHandler constructor.

Modifications:

Introduce startTls as a SslContextBuilder parameter as well as a
member in SslContext (and thus Jdk and OpenSsl implementations!).
Always use this information to call the SslHandler constructor.
Use false by default, in particular in deprecated constructors of
the SSL implementations.
The client Context use false by default

Results:

Fixes netty#5170 and more generally homogenise the API so that
everything can be done via SslContextBuilder.
  • Loading branch information
victornoel authored and normanmaurer committed Sep 6, 2016
1 parent b604a22 commit 8566fd1
Show file tree
Hide file tree
Showing 12 changed files with 67 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ public JdkSslClientContext(File trustCertCollectionFile, TrustManagerFactory tru
trustCertCollectionFile), trustManagerFactory,
toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword),
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout), true,
ciphers, cipherFilter, apn, ClientAuth.NONE);
ciphers, cipherFilter, apn, ClientAuth.NONE, false);
}

JdkSslClientContext(X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
Expand All @@ -260,7 +260,7 @@ public JdkSslClientContext(File trustCertCollectionFile, TrustManagerFactory tru
ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout) throws SSLException {
super(newSSLContext(trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, sessionCacheSize, sessionTimeout), true,
ciphers, cipherFilter, toNegotiator(apn, false), ClientAuth.NONE);
ciphers, cipherFilter, toNegotiator(apn, false), ClientAuth.NONE, false);
}

private static SSLContext newSSLContext(X509Certificate[] trustCertCollection,
Expand Down
7 changes: 4 additions & 3 deletions handler/src/main/java/io/netty/handler/ssl/JdkSslContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ private static void addIfSupported(Set<String> supported, List<String> enabled,
public JdkSslContext(SSLContext sslContext, boolean isClient,
ClientAuth clientAuth) {
this(sslContext, isClient, null, IdentityCipherSuiteFilter.INSTANCE,
JdkDefaultApplicationProtocolNegotiator.INSTANCE, clientAuth);
JdkDefaultApplicationProtocolNegotiator.INSTANCE, clientAuth, false);
}

/**
Expand All @@ -169,11 +169,12 @@ public JdkSslContext(SSLContext sslContext, boolean isClient,
public JdkSslContext(SSLContext sslContext, boolean isClient, Iterable<String> ciphers,
CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
ClientAuth clientAuth) {
this(sslContext, isClient, ciphers, cipherFilter, toNegotiator(apn, !isClient), clientAuth);
this(sslContext, isClient, ciphers, cipherFilter, toNegotiator(apn, !isClient), clientAuth, false);
}

JdkSslContext(SSLContext sslContext, boolean isClient, Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
JdkApplicationProtocolNegotiator apn, ClientAuth clientAuth) {
JdkApplicationProtocolNegotiator apn, ClientAuth clientAuth, boolean startTls) {
super(startTls);
this.apn = checkNotNull(apn, "apn");
this.clientAuth = checkNotNull(clientAuth, "clientAuth");
cipherSuites = checkNotNull(cipherFilter, "cipherFilter").filterCipherSuites(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,17 +215,17 @@ public JdkSslServerContext(File trustCertCollectionFile, TrustManagerFactory tru
super(newSSLContext(toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory,
toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword),
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout), false,
ciphers, cipherFilter, apn, ClientAuth.NONE);
ciphers, cipherFilter, apn, ClientAuth.NONE, false);
}

JdkSslServerContext(X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword,
KeyManagerFactory keyManagerFactory, Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout,
ClientAuth clientAuth) throws SSLException {
ClientAuth clientAuth, boolean startTls) throws SSLException {
super(newSSLContext(trustCertCollection, trustManagerFactory, keyCertChain, key,
keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout), false,
ciphers, cipherFilter, toNegotiator(apn, true), clientAuth);
ciphers, cipherFilter, toNegotiator(apn, true), clientAuth, startTls);
}

private static SSLContext newSSLContext(X509Certificate[] trustCertCollection,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ public OpenSslClientContext(File trustCertCollectionFile, TrustManagerFactory tr
long sessionCacheSize, long sessionTimeout)
throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_CLIENT, keyCertChain,
ClientAuth.NONE);
ClientAuth.NONE, false);
boolean success = false;
try {
sessionContext = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,18 @@
public abstract class OpenSslContext extends ReferenceCountedOpenSslContext {
OpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apnCfg,
long sessionCacheSize, long sessionTimeout, int mode, Certificate[] keyCertChain,
ClientAuth clientAuth)
ClientAuth clientAuth, boolean startTls)
throws SSLException {
super(ciphers, cipherFilter, apnCfg, sessionCacheSize, sessionTimeout, mode, keyCertChain,
clientAuth, false);
clientAuth, startTls, false);
}

OpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
OpenSslApplicationProtocolNegotiator apn, long sessionCacheSize,
long sessionTimeout, int mode, Certificate[] keyCertChain,
ClientAuth clientAuth) throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, mode, keyCertChain, clientAuth, false);
ClientAuth clientAuth, boolean startTls) throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, mode, keyCertChain, clientAuth, startTls,
false);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,26 +323,26 @@ public OpenSslServerContext(
this(toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory,
toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword),
keyPassword, keyManagerFactory, ciphers, cipherFilter,
apn, sessionCacheSize, sessionTimeout, ClientAuth.NONE);
apn, sessionCacheSize, sessionTimeout, ClientAuth.NONE, false);
}

OpenSslServerContext(
X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth) throws SSLException {
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, boolean startTls) throws SSLException {
this(trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory, ciphers,
cipherFilter, toNegotiator(apn), sessionCacheSize, sessionTimeout, clientAuth);
cipherFilter, toNegotiator(apn), sessionCacheSize, sessionTimeout, clientAuth, startTls);
}

@SuppressWarnings("deprecation")
private OpenSslServerContext(
X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn,
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth) throws SSLException {
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, boolean startTls) throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_SERVER, keyCertChain,
clientAuth);
clientAuth, startTls);
// Create a new SSL_CTX and configure it.
boolean success = false;
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public final class ReferenceCountedOpenSslClientContext extends ReferenceCounted
long sessionCacheSize, long sessionTimeout)
throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_CLIENT, keyCertChain,
ClientAuth.NONE, true);
ClientAuth.NONE, false, true);
boolean success = false;
try {
sessionContext = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,18 @@ public String run() {

ReferenceCountedOpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
ApplicationProtocolConfig apnCfg, long sessionCacheSize, long sessionTimeout,
int mode, Certificate[] keyCertChain, ClientAuth clientAuth, boolean leakDetection)
throws SSLException {
int mode, Certificate[] keyCertChain, ClientAuth clientAuth, boolean startTls,
boolean leakDetection) throws SSLException {
this(ciphers, cipherFilter, toNegotiator(apnCfg), sessionCacheSize, sessionTimeout, mode, keyCertChain,
clientAuth, leakDetection);
clientAuth, startTls, leakDetection);
}

ReferenceCountedOpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
OpenSslApplicationProtocolNegotiator apn, long sessionCacheSize,
long sessionTimeout, int mode, Certificate[] keyCertChain,
ClientAuth clientAuth, boolean leakDetection) throws SSLException {
ClientAuth clientAuth, boolean startTls, boolean leakDetection) throws SSLException {
super(startTls);

OpenSsl.ensureAvailability();

if (mode != SSL.SSL_MODE_SERVER && mode != SSL.SSL_MODE_CLIENT) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,18 @@ public final class ReferenceCountedOpenSslServerContext extends ReferenceCounted
X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth) throws SSLException {
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, boolean startTls) throws SSLException {
this(trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory, ciphers,
cipherFilter, toNegotiator(apn), sessionCacheSize, sessionTimeout, clientAuth);
cipherFilter, toNegotiator(apn), sessionCacheSize, sessionTimeout, clientAuth, startTls);
}

private ReferenceCountedOpenSslServerContext(
X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn,
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth) throws SSLException {
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, boolean startTls) throws SSLException {
super(ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, SSL.SSL_MODE_SERVER, keyCertChain,
clientAuth, true);
clientAuth, startTls, true);
// Create a new SSL_CTX and configure it.
boolean success = false;
try {
Expand Down
12 changes: 7 additions & 5 deletions handler/src/main/java/io/netty/handler/ssl/SniHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -340,16 +340,18 @@ private void onSslContext(ChannelHandlerContext ctx, String hostname, SslContext
* It's also possible for the hostname argument to be {@code null}.
*/
protected void replaceHandler(ChannelHandlerContext ctx, String hostname, SslContext sslContext) throws Exception {
SSLEngine sslEngine = null;
SslHandler sslHandler = null;
try {
sslEngine = sslContext.newEngine(ctx.alloc());
ctx.pipeline().replace(this, SslHandler.class.getName(), SslContext.newHandler(sslEngine));
sslEngine = null;
sslHandler = sslContext.newHandler(ctx.alloc());
ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler);
sslHandler = null;
} finally {
// Since the SslHandler was not inserted into the pipeline the ownership of the SSLEngine was not
// transferred to the SslHandler.
// See https://github.com/netty/netty/issues/5678
ReferenceCountUtil.safeRelease(sslEngine);
if (sslHandler != null) {
ReferenceCountUtil.safeRelease(sslHandler.engine());
}
}
}

Expand Down
31 changes: 19 additions & 12 deletions handler/src/main/java/io/netty/handler/ssl/SslContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ public abstract class SslContext {
}
}

private final boolean startTls;

/**
* Returns the default server-side implementation provider currently in use.
*
Expand Down Expand Up @@ -382,7 +384,7 @@ public static SslContext newServerContext(
toX509Certificates(keyCertChainFile),
toPrivateKey(keyFile, keyPassword),
keyPassword, keyManagerFactory, ciphers, cipherFilter, apn,
sessionCacheSize, sessionTimeout, ClientAuth.NONE);
sessionCacheSize, sessionTimeout, ClientAuth.NONE, false);
} catch (Exception e) {
if (e instanceof SSLException) {
throw (SSLException) e;
Expand All @@ -396,7 +398,7 @@ static SslContext newServerContextInternal(
X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth) throws SSLException {
long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, boolean startTls) throws SSLException {

if (provider == null) {
provider = defaultServerProvider();
Expand All @@ -407,17 +409,17 @@ static SslContext newServerContextInternal(
return new JdkSslServerContext(
trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout,
clientAuth);
clientAuth, startTls);
case OPENSSL:
return new OpenSslServerContext(
trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout,
clientAuth);
clientAuth, startTls);
case OPENSSL_REFCNT:
return new ReferenceCountedOpenSslServerContext(
trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword,
keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout,
clientAuth);
clientAuth, startTls);
default:
throw new Error(provider.toString());
}
Expand Down Expand Up @@ -774,10 +776,19 @@ static ApplicationProtocolConfig toApplicationProtocolConfig(Iterable<String> ne
return apn;
}

/**
* Creates a new instance (startTls set to <code>false</code>).
*/
protected SslContext() {
this(false);
}

/**
* Creates a new instance.
*/
protected SslContext() { }
protected SslContext(boolean startTls) {
this.startTls = startTls;
}

/**
* Returns {@code true} if and only if this context is for server-side.
Expand Down Expand Up @@ -852,7 +863,7 @@ public final List<String> nextProtocols() {
* @return a new {@link SslHandler}
*/
public final SslHandler newHandler(ByteBufAllocator alloc) {
return newHandler(newEngine(alloc));
return new SslHandler(newEngine(alloc), startTls);
}

/**
Expand All @@ -866,11 +877,7 @@ public final SslHandler newHandler(ByteBufAllocator alloc) {
* @return a new {@link SslHandler}
*/
public final SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort) {
return newHandler(newEngine(alloc, peerHost, peerPort));
}

static SslHandler newHandler(SSLEngine engine) {
return new SslHandler(engine);
return new SslHandler(newEngine(alloc, peerHost, peerPort), startTls);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ public static SslContextBuilder forServer(KeyManagerFactory keyManagerFactory) {
private long sessionCacheSize;
private long sessionTimeout;
private ClientAuth clientAuth = ClientAuth.NONE;
private boolean startTls;

private SslContextBuilder(boolean forServer) {
this.forServer = forServer;
Expand Down Expand Up @@ -383,6 +384,14 @@ public SslContextBuilder clientAuth(ClientAuth clientAuth) {
return this;
}

/**
* {@code true} if the first write request shouldn't be encrypted.
*/
public SslContextBuilder startTls(boolean startTls) {
this.startTls = startTls;
return this;
}

/**
* Create new {@code SslContext} instance with configured settings.
* <p>If {@link #sslProvider(SslProvider)} is set to {@link SslProvider#OPENSSL_REFCNT} then the caller is
Expand All @@ -392,7 +401,7 @@ public SslContext build() throws SSLException {
if (forServer) {
return SslContext.newServerContextInternal(provider, trustCertCollection,
trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory,
ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, clientAuth);
ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, clientAuth, startTls);
} else {
return SslContext.newClientContextInternal(provider, trustCertCollection,
trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory,
Expand Down

0 comments on commit 8566fd1

Please sign in to comment.