Skip to content

Commit

Permalink
Make call canceling more reliable.
Browse files Browse the repository at this point in the history
We had a bug where the socket-being-connected wasn't being closed when the
application used Call.cancel(). The problem is that the SocketConnector model
assumes the Connection doesn't want a Socket instance until it's fully
connected.

This moves the SocketConnector code back into Connection, removes a lot of
nested try/catch blocks, and assigns a Socket instance as soon as its created.

This also likely fixes some bugs where sockets weren't being closed when
an IOException or RouteException was thrown during connection. Now we always
close at the top level of connect() unless the connection is successful.

square#1779
  • Loading branch information
squarejesse committed Aug 1, 2015
1 parent 1e5f3a9 commit b42e73f
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 334 deletions.
66 changes: 60 additions & 6 deletions okhttp-tests/src/test/java/com/squareup/okhttp/CallTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.squareup.okhttp.internal.RecordingOkAuthenticator;
import com.squareup.okhttp.internal.SingleInetAddressNetwork;
import com.squareup.okhttp.internal.SslContextBuilder;
import com.squareup.okhttp.internal.Util;
import com.squareup.okhttp.internal.Version;
import com.squareup.okhttp.internal.io.FileSystem;
import com.squareup.okhttp.internal.io.InMemoryFileSystem;
Expand All @@ -36,7 +37,10 @@
import java.net.CookieManager;
import java.net.HttpCookie;
import java.net.HttpURLConnection;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ProtocolException;
import java.net.ServerSocket;
import java.net.UnknownServiceException;
import java.security.cert.Certificate;
import java.util.ArrayList;
Expand All @@ -52,6 +56,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ServerSocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLPeerUnverifiedException;
Expand Down Expand Up @@ -89,19 +94,16 @@ public final class CallTest {
private OkHttpClient client = new OkHttpClient();
private RecordingCallback callback = new RecordingCallback();
private TestLogHandler logHandler = new TestLogHandler();
private Cache cache;
private Cache cache = new Cache(new File("/cache/"), Integer.MAX_VALUE, fileSystem);
private ServerSocket nullServer;

@Before public void setUp() throws Exception {
client = new OkHttpClient();
callback = new RecordingCallback();
logHandler = new TestLogHandler();

cache = new Cache(new File("/cache/"), Integer.MAX_VALUE, fileSystem);
logger.addHandler(logHandler);
}

@After public void tearDown() throws Exception {
cache.delete();
Util.closeQuietly(nullServer);
logger.removeHandler(logHandler);
}

Expand Down Expand Up @@ -1469,6 +1471,45 @@ private void postBodyRetransmittedAfterAuthorizationFail(String body) throws Exc
assertEquals(0, server.getRequestCount());
}

@Test public void cancelDuringHttpConnect() throws Exception {
cancelDuringConnect("http");
}

@Test public void cancelDuringHttpsConnect() throws Exception {
cancelDuringConnect("https");
}

/** Cancel a call that's waiting for connect to complete. */
private void cancelDuringConnect(String scheme) throws Exception {
InetSocketAddress socketAddress = startNullServer();

HttpUrl url = new HttpUrl.Builder()
.scheme(scheme)
.host(socketAddress.getHostName())
.port(socketAddress.getPort())
.build();

long cancelDelayMillis = 300L;
Call call = client.newCall(new Request.Builder().url(url).build());
cancelLater(call, cancelDelayMillis);

long startNanos = System.nanoTime();
try {
call.execute();
fail();
} catch (IOException expected) {
}
long elapsedNanos = System.nanoTime() - startNanos;
assertEquals(cancelDelayMillis, TimeUnit.NANOSECONDS.toMillis(elapsedNanos), 100f);
}

private InetSocketAddress startNullServer() throws IOException {
InetSocketAddress address = new InetSocketAddress(InetAddress.getByName("localhost"), 0);
nullServer = ServerSocketFactory.getDefault().createServerSocket();
nullServer.bind(address);
return new InetSocketAddress(address.getAddress(), nullServer.getLocalPort());
}

@Test public void cancelTagImmediatelyAfterEnqueue() throws Exception {
Call call = client.newCall(new Request.Builder()
.url(server.url("/a"))
Expand Down Expand Up @@ -1806,6 +1847,19 @@ private Buffer gzip(String data) throws IOException {
return result;
}

private void cancelLater(final Call call, final long delay) {
new Thread("canceler") {
@Override public void run() {
try {
Thread.sleep(delay);
} catch (InterruptedException e) {
throw new AssertionError();
}
call.cancel();
}
}.start();
}

private static class RecordingSSLSocketFactory extends DelegatingSSLSocketFactory {

private List<SSLSocket> socketsCreated = new ArrayList<>();
Expand Down
236 changes: 207 additions & 29 deletions okhttp/src/main/java/com/squareup/okhttp/Connection.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,38 @@
*/
package com.squareup.okhttp;

import com.squareup.okhttp.internal.ConnectionSpecSelector;
import com.squareup.okhttp.internal.Platform;
import com.squareup.okhttp.internal.Util;
import com.squareup.okhttp.internal.framed.FramedConnection;
import com.squareup.okhttp.internal.http.FramedTransport;
import com.squareup.okhttp.internal.http.HttpConnection;
import com.squareup.okhttp.internal.http.HttpEngine;
import com.squareup.okhttp.internal.http.HttpTransport;
import com.squareup.okhttp.internal.http.OkHeaders;
import com.squareup.okhttp.internal.http.RouteException;
import com.squareup.okhttp.internal.http.SocketConnector;
import com.squareup.okhttp.internal.http.FramedTransport;
import com.squareup.okhttp.internal.http.Transport;
import com.squareup.okhttp.internal.framed.FramedConnection;
import com.squareup.okhttp.internal.tls.OkHostnameVerifier;
import java.io.IOException;
import java.net.Proxy;
import java.net.Socket;
import java.net.URL;
import java.net.UnknownServiceException;
import java.security.cert.X509Certificate;
import java.util.List;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import okio.BufferedSink;
import okio.BufferedSource;
import okio.Source;

import static com.squareup.okhttp.internal.Util.closeQuietly;
import static com.squareup.okhttp.internal.Util.getDefaultPort;
import static com.squareup.okhttp.internal.Util.getEffectivePort;
import static java.net.HttpURLConnection.HTTP_OK;
import static java.net.HttpURLConnection.HTTP_PROXY_AUTH;

/**
* The sockets and streams of an HTTP, HTTPS, or HTTPS+SPDY connection. May be
Expand Down Expand Up @@ -136,40 +154,200 @@ void connect(int connectTimeout, int readTimeout, int writeTimeout, Request requ
List<ConnectionSpec> connectionSpecs, boolean connectionRetryEnabled) throws RouteException {
if (connected) throw new IllegalStateException("already connected");

SocketConnector socketConnector = new SocketConnector(this, pool);
SocketConnector.ConnectedSocket connectedSocket;
RouteException routeException = null;
ConnectionSpecSelector connectionSpecSelector = new ConnectionSpecSelector(connectionSpecs);
Proxy proxy = route.getProxy();
Address address = route.getAddress();

if (route.address.getSslSocketFactory() == null
&& !connectionSpecs.contains(ConnectionSpec.CLEARTEXT)) {
throw new RouteException(new UnknownServiceException(
"CLEARTEXT communication not supported: " + connectionSpecs));
}

while (!connected) {
try {
socket = proxy.type() == Proxy.Type.DIRECT || proxy.type() == Proxy.Type.HTTP
? address.getSocketFactory().createSocket()
: new Socket(proxy);
connectSocket(connectTimeout, readTimeout, writeTimeout, request,
connectionSpecSelector);
connected = true; // Success!
} catch (IOException e) {
Util.closeQuietly(socket);
socket = null;

if (routeException == null) {
routeException = new RouteException(e);
} else {
routeException.addConnectException(e);
}

if (!connectionRetryEnabled || !connectionSpecSelector.connectionFailed(e)) {
throw routeException;
}
}
}
}

/** Does all the work necessary to build a full HTTP or HTTPS connection on a raw socket. */
private void connectSocket(int connectTimeout, int readTimeout, int writeTimeout,
Request request, ConnectionSpecSelector connectionSpecSelector) throws IOException {
socket.setSoTimeout(readTimeout);
Platform.get().connectSocket(socket, route.getSocketAddress(), connectTimeout);

if (route.address.getSslSocketFactory() != null) {
// https:// communication
connectedSocket = socketConnector.connectTls(connectTimeout, readTimeout, writeTimeout,
request, route, connectionSpecs, connectionRetryEnabled);
connectTls(readTimeout, writeTimeout, request, connectionSpecSelector);
}

if (protocol == Protocol.SPDY_3 || protocol == Protocol.HTTP_2) {
socket.setSoTimeout(0); // Framed connection timeouts are set per-stream.
framedConnection = new FramedConnection.Builder(route.address.uriHost, true, socket)
.protocol(protocol).build();
framedConnection.sendConnectionPreface();
} else {
// http:// communication.
if (!connectionSpecs.contains(ConnectionSpec.CLEARTEXT)) {
throw new RouteException(
new UnknownServiceException(
"CLEARTEXT communication not supported: " + connectionSpecs));
}
connectedSocket = socketConnector.connectCleartext(connectTimeout, readTimeout, route);
httpConnection = new HttpConnection(pool, this, socket);
}
}

socket = connectedSocket.socket;
handshake = connectedSocket.handshake;
protocol = connectedSocket.alpnProtocol == null
? Protocol.HTTP_1_1 : connectedSocket.alpnProtocol;
private void connectTls(int readTimeout, int writeTimeout, Request request,
ConnectionSpecSelector connectionSpecSelector) throws IOException {
if (route.requiresTunnel()) {
createTunnel(readTimeout, writeTimeout, request);
}

Address address = route.getAddress();
SSLSocketFactory sslSocketFactory = address.getSslSocketFactory();
boolean success = false;
SSLSocket sslSocket = null;
try {
if (protocol == Protocol.SPDY_3 || protocol == Protocol.HTTP_2) {
socket.setSoTimeout(0); // Framed connection timeouts are set per-stream.
framedConnection = new FramedConnection.Builder(route.address.uriHost, true, socket)
.protocol(protocol).build();
framedConnection.sendConnectionPreface();
} else {
httpConnection = new HttpConnection(pool, this, socket);
// Create the wrapper over the connected socket.
sslSocket = (SSLSocket) sslSocketFactory.createSocket(
socket, address.getUriHost(), address.getUriPort(), true /* autoClose */);

// Configure the socket's ciphers, TLS versions, and extensions.
ConnectionSpec connectionSpec = connectionSpecSelector.configureSecureSocket(sslSocket);
if (connectionSpec.supportsTlsExtensions()) {
Platform.get().configureTlsExtensions(
sslSocket, address.getUriHost(), address.getProtocols());
}
} catch (IOException e) {
throw new RouteException(e);

// Force handshake. This can throw!
sslSocket.startHandshake();
Handshake unverifiedHandshake = Handshake.get(sslSocket.getSession());

// Verify that the socket's certificates are acceptable for the target host.
if (!address.getHostnameVerifier().verify(address.getUriHost(), sslSocket.getSession())) {
X509Certificate cert = (X509Certificate) unverifiedHandshake.peerCertificates().get(0);
throw new SSLPeerUnverifiedException("Hostname " + address.getUriHost() + " not verified:"
+ "\n certificate: " + CertificatePinner.pin(cert)
+ "\n DN: " + cert.getSubjectDN().getName()
+ "\n subjectAltNames: " + OkHostnameVerifier.allSubjectAltNames(cert));
}

// Check that the certificate pinner is satisfied by the certificates presented.
address.getCertificatePinner().check(address.getUriHost(),
unverifiedHandshake.peerCertificates());

// Success! Save the handshake and the ALPN protocol.
String maybeProtocol = connectionSpec.supportsTlsExtensions()
? Platform.get().getSelectedProtocol(sslSocket)
: null;
protocol = maybeProtocol != null
? Protocol.get(maybeProtocol)
: Protocol.HTTP_1_1;
handshake = unverifiedHandshake;
socket = sslSocket;
success = true;
} finally {
if (sslSocket != null) {
Platform.get().afterHandshake(sslSocket);
}
if (!success) {
closeQuietly(sslSocket);
}
}
}

/**
* To make an HTTPS connection over an HTTP proxy, send an unencrypted
* CONNECT request to create the proxy connection. This may need to be
* retried if the proxy requires authorization.
*/
private void createTunnel(int readTimeout, int writeTimeout, Request request) throws IOException {
// Make an SSL Tunnel on the first message pair of each SSL + proxy connection.
Request tunnelRequest = createTunnelRequest(request);
HttpConnection tunnelConnection = new HttpConnection(pool, this, socket);
tunnelConnection.setTimeouts(readTimeout, writeTimeout);
URL url = tunnelRequest.url();
String requestLine = "CONNECT " + url.getHost() + ":" + getEffectivePort(url) + " HTTP/1.1";
while (true) {
tunnelConnection.writeRequest(tunnelRequest.headers(), requestLine);
tunnelConnection.flush();
Response response = tunnelConnection.readResponse().request(tunnelRequest).build();
// The response body from a CONNECT should be empty, but if it is not then we should consume
// it before proceeding.
long contentLength = OkHeaders.contentLength(response);
if (contentLength == -1L) {
contentLength = 0L;
}
Source body = tunnelConnection.newFixedLengthSource(contentLength);
Util.skipAll(body, Integer.MAX_VALUE, TimeUnit.MILLISECONDS);
body.close();

switch (response.code()) {
case HTTP_OK:
// Assume the server won't send a TLS ServerHello until we send a TLS ClientHello. If
// that happens, then we will have buffered bytes that are needed by the SSLSocket!
// This check is imperfect: it doesn't tell us whether a handshake will succeed, just
// that it will almost certainly fail because the proxy has sent unexpected data.
if (tunnelConnection.bufferSize() > 0) {
throw new IOException("TLS tunnel buffered too many bytes!");
}
return;

case HTTP_PROXY_AUTH:
tunnelRequest = OkHeaders.processAuthHeader(
route.getAddress().getAuthenticator(), response, route.getProxy());
if (tunnelRequest != null) continue;
throw new IOException("Failed to authenticate with proxy");

default:
throw new IOException(
"Unexpected response code for CONNECT: " + response.code());
}
}
}

/**
* Returns a request that creates a TLS tunnel via an HTTP proxy, or null if
* no tunnel is necessary. Everything in the tunnel request is sent
* unencrypted to the proxy server, so tunnels include only the minimum set of
* headers. This avoids sending potentially sensitive data like HTTP cookies
* to the proxy unencrypted.
*/
private Request createTunnelRequest(Request request) throws IOException {
String host = request.url().getHost();
int port = getEffectivePort(request.url());
String authority = (port == getDefaultPort("https")) ? host : (host + ":" + port);
Request.Builder result = new Request.Builder()
.url(new URL("https", host, port, "/"))
.header("Host", authority)
.header("Proxy-Connection", "Keep-Alive"); // For HTTP/1.0 proxies like Squid.

// Copy over the User-Agent header if it exists.
String userAgent = request.header("User-Agent");
if (userAgent != null) {
result.header("User-Agent", userAgent);
}
connected = true;

// Copy over the Proxy-Authorization header if it exists.
String proxyAuthorization = request.header("Proxy-Authorization");
if (proxyAuthorization != null) {
result.header("Proxy-Authorization", proxyAuthorization);
}

return result.build();
}

/**
Expand Down
Loading

0 comments on commit b42e73f

Please sign in to comment.