Skip to content

Commit

Permalink
HTTP/2 Prevent modification of activeStreams while iterating
Browse files Browse the repository at this point in the history
Motivation:
The Http2Connection interface exposes an activeStreams() method which allows direct iteration over the underlying collection. There are a few places that make copies of this collection to avoid modification while iterating, and a few places that do not make copies. The copy operation can be expensive on hot code paths and also we are not consistently iterating over the activeStreams collection.

Modifications:
- The Http2Connection interface should reduce the exposure of the underlying collection and just expose what is necessary for the interface to function.  This is just a means to iterate over the collection.
- The DefaultHttp2Connection should use this new interface and protect it's internal state while iteration is occurring.

Result:
Reduction in surface area of the Http2Connection interface.  Consistent iteration of the set of active streams.  Concurrent modification exceptions are handled in 1 encapsulated spot.
  • Loading branch information
Scottmitch committed Apr 8, 2015
1 parent d5d932a commit 83ce8a9
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,39 +33,39 @@
import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL;
import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_REMOTE;
import static io.netty.util.internal.ObjectUtil.checkNotNull;

import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.http2.Http2StreamRemovalPolicy.Action;
import io.netty.util.collection.IntObjectHashMap;
import io.netty.util.collection.IntObjectMap;
import io.netty.util.collection.PrimitiveCollections;
import io.netty.util.internal.PlatformDependent;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

/**
* Simple implementation of {@link Http2Connection}.
*/
public class DefaultHttp2Connection implements Http2Connection {

// Fields accessed by inner classes
final IntObjectMap<Http2Stream> streamMap = new IntObjectHashMap<Http2Stream>();
final ConnectionStream connectionStream = new ConnectionStream();
final DefaultEndpoint<Http2LocalFlowController> localEndpoint;
final DefaultEndpoint<Http2RemoteFlowController> remoteEndpoint;
/**
* We chose a {@link List} over a {@link Set} to avoid allocating an {@link Iterator} objects when iterating over
* the listeners.
*/
private final List<Listener> listeners = new ArrayList<Listener>(4);
private final IntObjectMap<Http2Stream> streamMap = new IntObjectHashMap<Http2Stream>();
private final ConnectionStream connectionStream = new ConnectionStream();
private final Set<Http2Stream> activeStreams = new LinkedHashSet<Http2Stream>();
private final DefaultEndpoint<Http2LocalFlowController> localEndpoint;
private final DefaultEndpoint<Http2RemoteFlowController> remoteEndpoint;
private final Http2StreamRemovalPolicy removalPolicy;
final List<Listener> listeners = new ArrayList<Listener>(4);
final ActiveStreams activeStreams;

/**
* Creates a connection with an immediate stream removal policy.
Expand All @@ -86,7 +86,7 @@ public DefaultHttp2Connection(boolean server) {
* the policy to be used for removal of closed stream.
*/
public DefaultHttp2Connection(boolean server, Http2StreamRemovalPolicy removalPolicy) {
this.removalPolicy = checkNotNull(removalPolicy, "removalPolicy");
activeStreams = new ActiveStreams(listeners, checkNotNull(removalPolicy, "removalPolicy"));
localEndpoint = new DefaultEndpoint<Http2LocalFlowController>(server);
remoteEndpoint = new DefaultEndpoint<Http2RemoteFlowController>(!server);

Expand Down Expand Up @@ -142,8 +142,8 @@ public int numActiveStreams() {
}

@Override
public Set<Http2Stream> activeStreams() {
return Collections.unmodifiableSet(activeStreams);
public Http2Stream forEachActiveStream(StreamVisitor visitor) throws Http2Exception {
return activeStreams.forEachActiveStream(visitor);
}

@Override
Expand All @@ -162,17 +162,24 @@ public boolean goAwayReceived() {
}

@Override
public void goAwayReceived(int lastKnownStream, long errorCode, ByteBuf debugData) {
public void goAwayReceived(final int lastKnownStream, long errorCode, ByteBuf debugData) {
localEndpoint.lastKnownStream(lastKnownStream);
for (Listener listener : listeners) {
listener.onGoAwayReceived(lastKnownStream, errorCode, debugData);
}

Http2Stream[] streams = new Http2Stream[numActiveStreams()];
for (Http2Stream stream : activeStreams().toArray(streams)) {
if (stream.id() > lastKnownStream && localEndpoint.createdStreamId(stream.id())) {
stream.close();
}
try {
forEachActiveStream(new StreamVisitor() {
@Override
public boolean visit(Http2Stream stream) {
if (stream.id() > lastKnownStream && localEndpoint.createdStreamId(stream.id())) {
stream.close();
}
return true;
}
});
} catch (Http2Exception e) {
PlatformDependent.throwException(e);
}
}

Expand All @@ -182,17 +189,24 @@ public boolean goAwaySent() {
}

@Override
public void goAwaySent(int lastKnownStream, long errorCode, ByteBuf debugData) {
public void goAwaySent(final int lastKnownStream, long errorCode, ByteBuf debugData) {
remoteEndpoint.lastKnownStream(lastKnownStream);
for (Listener listener : listeners) {
listener.onGoAwaySent(lastKnownStream, errorCode, debugData);
}

Http2Stream[] streams = new Http2Stream[numActiveStreams()];
for (Http2Stream stream : activeStreams().toArray(streams)) {
if (stream.id() > lastKnownStream && remoteEndpoint.createdStreamId(stream.id())) {
stream.close();
}
try {
forEachActiveStream(new StreamVisitor() {
@Override
public boolean visit(Http2Stream stream) {
if (stream.id() > lastKnownStream && remoteEndpoint.createdStreamId(stream.id())) {
stream.close();
}
return true;
}
});
} catch (Http2Exception e) {
PlatformDependent.throwException(e);
}
}

Expand Down Expand Up @@ -384,15 +398,7 @@ public Http2Stream open(boolean halfClosed) throws Http2Exception {
throw streamError(id, PROTOCOL_ERROR, "Attempting to open a stream in an invalid state: " + state);
}

if (activeStreams.add(this)) {
// Update the number of active streams initiated by the endpoint.
createdBy().numActiveStreams++;

// Notify the listeners.
for (int i = 0; i < listeners.size(); i++) {
listeners.get(i).onStreamActive(this);
}
}
activeStreams.activate(this);
return this;
}

Expand All @@ -404,20 +410,8 @@ public Http2Stream close() {

state = CLOSED;
decrementPrioritizableForTree(1);
if (activeStreams.remove(this)) {
try {
// Update the number of active streams initiated by the endpoint.
createdBy().numActiveStreams--;

// Notify the listeners.
for (int i = 0; i < listeners.size(); i++) {
listeners.get(i).onStreamClosed(this);
}
} finally {
// Mark this stream for removal.
removalPolicy.markForRemoval(this);
}
}
activeStreams.deactivate(this);
return this;
}

Expand Down Expand Up @@ -790,8 +784,9 @@ private final class DefaultEndpoint<F extends Http2FlowController> implements En
private int lastKnownStream = -1;
private boolean pushToAllowed = true;
private F flowController;
private int numActiveStreams;
private int maxActiveStreams;
// Fields accessed by inner classes
int numActiveStreams;

DefaultEndpoint(boolean server) {
this.server = server;
Expand Down Expand Up @@ -951,8 +946,8 @@ private void checkNewStreamAllowed(int streamId) throws Http2Exception {
throw new Http2NoMoreStreamIdsException();
}
if (!createdStreamId(streamId)) {
throw connectionError(PROTOCOL_ERROR, "Request stream %d is not correct for %s connection",
streamId, server ? "server" : "client");
throw connectionError(PROTOCOL_ERROR, "Request stream %d is not correct for %s connection", streamId,
server ? "server" : "client");
}
// This check must be after all id validated checks, but before the max streams check because it may be
// recoverable to some degree for handling frames which can be sent on closed streams.
Expand All @@ -969,4 +964,116 @@ private boolean isLocal() {
return this == localEndpoint;
}
}

/**
* Default implementation of the {@link ActiveStreams} class.
*/
private static final class ActiveStreams {
/**
* Allows events which would modify {@link #streams} to be queued while iterating over {@link #streams}.
*/
interface Event {
/**
* Trigger the original intention of this event. Expect to modify {@link #streams}.
*/
void process();
}

private final List<Listener> listeners;
private final Http2StreamRemovalPolicy removalPolicy;
private final Queue<Event> pendingEvents = new ArrayDeque<Event>(4);
private final Set<Http2Stream> streams = new LinkedHashSet<Http2Stream>();
private int pendingIterations;

public ActiveStreams(List<Listener> listeners, Http2StreamRemovalPolicy removalPolicy) {
this.listeners = listeners;
this.removalPolicy = removalPolicy;
}

public int size() {
return streams.size();
}

public void activate(final DefaultStream stream) {
if (allowModifications()) {
addToActiveStreams(stream);
} else {
pendingEvents.add(new Event() {
@Override
public void process() {
addToActiveStreams(stream);
}
});
}
}

public void deactivate(final DefaultStream stream) {
if (allowModifications()) {
removeFromActiveStreams(stream);
} else {
pendingEvents.add(new Event() {
@Override
public void process() {
removeFromActiveStreams(stream);
}
});
}
}

public Http2Stream forEachActiveStream(StreamVisitor visitor) throws Http2Exception {
++pendingIterations;
Http2Stream resultStream = null;
try {
for (Http2Stream stream : streams) {
if (!visitor.visit(stream)) {
resultStream = stream;
break;
}
}
return resultStream;
} finally {
--pendingIterations;
if (allowModifications()) {
for (;;) {
Event event = pendingEvents.poll();
if (event == null) {
break;
}
event.process();
}
}
}
}

void addToActiveStreams(DefaultStream stream) {
if (streams.add(stream)) {
// Update the number of active streams initiated by the endpoint.
stream.createdBy().numActiveStreams++;

for (int i = 0; i < listeners.size(); i++) {
listeners.get(i).onStreamActive(stream);
}
}
}

void removeFromActiveStreams(DefaultStream stream) {
if (streams.remove(stream)) {
try {
// Update the number of active streams initiated by the endpoint.
stream.createdBy().numActiveStreams--;

for (int i = 0; i < listeners.size(); i++) {
listeners.get(i).onStreamClosed(stream);
}
} finally {
// Mark this stream for removal.
removalPolicy.markForRemoval(stream);
}
}
}

private boolean allowModifications() {
return pendingIterations == 0;
}
}
}
Loading

0 comments on commit 83ce8a9

Please sign in to comment.