Skip to content

Commit

Permalink
ChannelInitializer may be invoked multiple times when used with custo…
Browse files Browse the repository at this point in the history
…m EventExecutor. (netty#8620)

Motivation:

The ChannelInitializer may be invoked multipled times when used with a custom EventExecutor as removal operation may be done asynchronously. We need to guard against this.

Modifications:

- Change Map to Set which is more correct in terms of how we use it.
- Ensure we only modify the internal Set when the handler was removed yet
- Add unit test.

Result:

Fixes netty#8616.
  • Loading branch information
normanmaurer authored Dec 5, 2018
1 parent 6739755 commit 8331248
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 7 deletions.
32 changes: 25 additions & 7 deletions transport/src/main/java/io/netty/channel/ChannelInitializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.util.concurrent.ConcurrentMap;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
* A special {@link ChannelInboundHandler} which offers an easy way to initialize a {@link Channel} once it was
Expand Down Expand Up @@ -53,9 +54,10 @@
public abstract class ChannelInitializer<C extends Channel> extends ChannelInboundHandlerAdapter {

private static final InternalLogger logger = InternalLoggerFactory.getInstance(ChannelInitializer.class);
// We use a ConcurrentMap as a ChannelInitializer is usually shared between all Channels in a Bootstrap /
// We use a Set as a ChannelInitializer is usually shared between all Channels in a Bootstrap /
// ServerBootstrap. This way we can reduce the memory usage compared to use Attributes.
private final ConcurrentMap<ChannelHandlerContext, Boolean> initMap = PlatformDependent.newConcurrentHashMap();
private final Set<ChannelHandlerContext> initMap = Collections.newSetFromMap(
new ConcurrentHashMap<ChannelHandlerContext, Boolean>());

/**
* This method will be called once the {@link Channel} was registered. After the method returns this instance
Expand Down Expand Up @@ -108,9 +110,14 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
}
}

@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
initMap.remove(ctx);
}

@SuppressWarnings("unchecked")
private boolean initChannel(ChannelHandlerContext ctx) throws Exception {
if (initMap.putIfAbsent(ctx, Boolean.TRUE) == null) { // Guard against re-entrance.
if (initMap.add(ctx)) { // Guard against re-entrance.
try {
initChannel((C) ctx.channel());
} catch (Throwable cause) {
Expand All @@ -125,14 +132,25 @@ private boolean initChannel(ChannelHandlerContext ctx) throws Exception {
return false;
}

private void remove(ChannelHandlerContext ctx) {
private void remove(final ChannelHandlerContext ctx) {
try {
ChannelPipeline pipeline = ctx.pipeline();
if (pipeline.context(this) != null) {
pipeline.remove(this);
}
} finally {
initMap.remove(ctx);
// The removal may happen in an async fashion if the EventExecutor we use does something funky.
if (ctx.isRemoved()) {
initMap.remove(ctx);
} else {
// Ensure we always remove from the Map in all cases to not produce a memory leak.
ctx.channel().closeFuture().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
initMap.remove(ctx);
}
});
}
}
}
}
126 changes: 126 additions & 0 deletions transport/src/test/java/io/netty/channel/ChannelInitializerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -35,6 +39,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame;

public class ChannelInitializerTest {
Expand Down Expand Up @@ -249,6 +254,127 @@ private void testChannelRegisteredEventPropagation(ChannelInitializer<LocalChann
}
}

@SuppressWarnings("deprecation")
@Test(timeout = 10000)
public void testChannelInitializerEventExecutor() throws Throwable {
final AtomicInteger invokeCount = new AtomicInteger();
final AtomicInteger completeCount = new AtomicInteger();
final AtomicReference<Throwable> errorRef = new AtomicReference<Throwable>();
LocalAddress addr = new LocalAddress("test");

final EventExecutor executor = new DefaultEventLoop() {
private final ScheduledExecutorService execService = Executors.newSingleThreadScheduledExecutor();

@Override
public void shutdown() {
execService.shutdown();
}

@Override
public boolean inEventLoop(Thread thread) {
// Always return false which will ensure we always call execute(...)
return false;
}

@Override
public boolean isShuttingDown() {
return false;
}

@Override
public Future<?> shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) {
throw new IllegalStateException();
}

@Override
public Future<?> terminationFuture() {
throw new IllegalStateException();
}

@Override
public boolean isShutdown() {
return execService.isShutdown();
}

@Override
public boolean isTerminated() {
return execService.isTerminated();
}

@Override
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return execService.awaitTermination(timeout, unit);
}

@Override
public void execute(Runnable command) {
execService.execute(command);
}
};

ServerBootstrap serverBootstrap = new ServerBootstrap()
.channel(LocalServerChannel.class)
.group(group)
.localAddress(addr)
.childHandler(new ChannelInitializer<LocalChannel>() {
@Override
protected void initChannel(LocalChannel ch) {
ch.pipeline().addLast(executor, new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) {
invokeCount.incrementAndGet();
ChannelHandlerContext ctx = ch.pipeline().context(this);
assertNotNull(ctx);
ch.pipeline().addAfter(ctx.executor(),
ctx.name(), null, new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
// just drop on the floor.
}
});
completeCount.incrementAndGet();
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
errorRef.set(cause);
}
});
}
});

Channel server = serverBootstrap.bind().sync().channel();

Bootstrap clientBootstrap = new Bootstrap()
.channel(LocalChannel.class)
.group(group)
.remoteAddress(addr)
.handler(new ChannelInboundHandlerAdapter());

Channel client = clientBootstrap.connect().sync().channel();
client.writeAndFlush("Hello World").sync();

client.close().sync();
server.close().sync();

client.closeFuture().sync();
server.closeFuture().sync();

// Give some time to execute everything that was submitted before.
Thread.sleep(1000);

executor.shutdown();
assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS));

assertEquals(invokeCount.get(), 1);
assertEquals(invokeCount.get(), completeCount.get());

Throwable cause = errorRef.get();
if (cause != null) {
throw cause;
}
}

private static void closeChannel(Channel c) {
if (c != null) {
c.close().syncUninterruptibly();
Expand Down

0 comments on commit 8331248

Please sign in to comment.