Skip to content

Commit

Permalink
Add EventLoopRule and EventLoopGroupRule (line#1304)
Browse files Browse the repository at this point in the history
Motivation:

It would be useful to have JUnit `TestRule`s for Netty `EventLoop` and
`EventLoopGroup`.

Modifications:

- Add `EventLoopRule` and `EventLoopGroupRule`
- Update our tests to leverate the rule.
- Fix a bug where `TransportType` rejects `EpollEventLoop`

Result:

- Easier to prepare an event loop in a test case
  • Loading branch information
trustin authored Jul 25, 2018
1 parent 3d4e662 commit f894e91
Show file tree
Hide file tree
Showing 15 changed files with 363 additions and 126 deletions.
19 changes: 19 additions & 0 deletions core/src/main/java/com/linecorp/armeria/internal/ChannelUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,28 @@

import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.EpollEventLoopGroup;

public final class ChannelUtil {

private static final Class<? extends EventLoopGroup> EPOLL_EVENT_LOOP_CLASS;

static {
try {
//noinspection unchecked
EPOLL_EVENT_LOOP_CLASS = (Class<? extends EventLoopGroup>)
Class.forName("io.netty.channel.epoll.EpollEventLoop", false,
EpollEventLoopGroup.class.getClassLoader());
} catch (Exception e) {
throw new IllegalStateException("failed to locate EpollEventLoop class", e);
}
}

public static Class<? extends EventLoopGroup> epollEventLoopClass() {
return EPOLL_EVENT_LOOP_CLASS;
}

public static CompletableFuture<Void> close(Iterable<? extends Channel> channels) {
final List<Channel> channelsCopy = ImmutableList.copyOf(channels);
if (channelsCopy.isEmpty()) {
Expand Down
54 changes: 33 additions & 21 deletions core/src/main/java/com/linecorp/armeria/internal/TransportType.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
*/
package com.linecorp.armeria.internal;

import java.util.Set;
import java.util.concurrent.ThreadFactory;
import java.util.function.BiFunction;
import java.util.function.Function;

import javax.annotation.Nullable;

import com.google.common.base.Ascii;
import com.google.common.collect.ImmutableSet;

import com.linecorp.armeria.common.Flags;

Expand All @@ -30,6 +34,7 @@
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.nio.NioEventLoop;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.SocketChannel;
Expand All @@ -41,27 +46,29 @@
* Native transport types.
*/
public enum TransportType {

NIO(NioServerSocketChannel.class, NioSocketChannel.class, NioDatagramChannel.class,
NioEventLoopGroup.class, NioEventLoopGroup::new),
NioEventLoopGroup::new, NioEventLoopGroup.class, NioEventLoop.class),

EPOLL(EpollServerSocketChannel.class, EpollSocketChannel.class, EpollDatagramChannel.class,
EpollEventLoopGroup.class, EpollEventLoopGroup::new);
EpollEventLoopGroup::new, EpollEventLoopGroup.class, ChannelUtil.epollEventLoopClass());

private final Class<? extends ServerChannel> serverChannelClass;
private final Class<? extends SocketChannel> socketChannelClass;
private final Class<? extends DatagramChannel> datagramClass;
private final Class<? extends EventLoopGroup> eventLoopGroupClass;
private final Set<Class<? extends EventLoopGroup>> eventLoopGroupClasses;
private final BiFunction<Integer, ThreadFactory, ? extends EventLoopGroup> eventLoopGroupConstructor;

@SafeVarargs
TransportType(Class<? extends ServerChannel> serverChannelClass,
Class<? extends SocketChannel> socketChannelClass,
Class<? extends DatagramChannel> datagramClass,
Class<? extends EventLoopGroup> eventLoopGroupClass,
BiFunction<Integer, ThreadFactory, ? extends EventLoopGroup> eventLoopGroupConstructor) {
BiFunction<Integer, ThreadFactory, ? extends EventLoopGroup> eventLoopGroupConstructor,
Class<? extends EventLoopGroup>... eventLoopGroupClasses) {
this.serverChannelClass = serverChannelClass;
this.socketChannelClass = socketChannelClass;
this.datagramClass = datagramClass;
this.eventLoopGroupClass = eventLoopGroupClass;
this.eventLoopGroupClasses = ImmutableSet.copyOf(eventLoopGroupClasses);
this.eventLoopGroupConstructor = eventLoopGroupConstructor;
}

Expand Down Expand Up @@ -96,24 +103,14 @@ public static TransportType detectTransportType() {
* Returns the available {@link SocketChannel} class for {@code eventLoopGroup}.
*/
public static Class<? extends SocketChannel> socketChannelType(EventLoopGroup eventLoopGroup) {
for (TransportType type : values()) {
if (type.eventLoopGroupClass.isAssignableFrom(eventLoopGroup.getClass())) {
return type.socketChannelClass;
}
}
throw unsupportedEventLoopType(eventLoopGroup);
return find(eventLoopGroup).socketChannelClass;
}

/**
* Returns the available {@link DatagramChannel} class for {@code eventLoopGroup}.
*/
public static Class<? extends DatagramChannel> datagramChannelType(EventLoopGroup eventLoopGroup) {
for (TransportType type : values()) {
if (type.eventLoopGroupClass.isAssignableFrom(eventLoopGroup.getClass())) {
return type.datagramClass;
}
}
throw unsupportedEventLoopType(eventLoopGroup);
return find(eventLoopGroup).datagramClass;
}

/**
Expand All @@ -131,12 +128,27 @@ public static boolean isSupported(EventLoop eventLoop) {
* Returns whether the specified {@link EventLoopGroup} supports any {@link TransportType}.
*/
public static boolean isSupported(EventLoopGroup eventLoopGroup) {
return findOrNull(eventLoopGroup) != null;
}

private static TransportType find(EventLoopGroup eventLoopGroup) {
final TransportType found = findOrNull(eventLoopGroup);
if (found == null) {
throw unsupportedEventLoopType(eventLoopGroup);
}
return found;
}

@Nullable
private static TransportType findOrNull(EventLoopGroup eventLoopGroup) {
for (TransportType type : values()) {
if (type.eventLoopGroupClass.isAssignableFrom(eventLoopGroup.getClass())) {
return true;
for (Class<? extends EventLoopGroup> eventLoopGroupClass : type.eventLoopGroupClasses) {
if (eventLoopGroupClass.isAssignableFrom(eventLoopGroup.getClass())) {
return type;
}
}
}
return false;
return null;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,9 @@
import com.linecorp.armeria.server.AbstractHttpService;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.testing.common.EventLoopRule;
import com.linecorp.armeria.testing.server.ServerRule;

import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;

public class HttpClientPipeliningTest {

// Server-side configuration
Expand Down Expand Up @@ -86,22 +84,22 @@ protected HttpResponse doGet(ServiceRequestContext ctx, HttpRequest req) throws
};

// Client-side configuration
private static EventLoopGroup eventLoopGroup;
@ClassRule
public static final EventLoopRule eventLoopGroup = new EventLoopRule();
private static ClientFactory factoryWithPipelining;
private static ClientFactory factoryWithoutPipelining;

@BeforeClass
public static void initClientFactory() {
// Ensure only a single event loop is used so that there's only one connection pool.
// Note: Each event loop has its own connection pool.
eventLoopGroup = new NioEventLoopGroup(1);
factoryWithPipelining = new ClientFactoryBuilder()
.workerGroup(eventLoopGroup, false)
.workerGroup(eventLoopGroup.get(), false)
.useHttp1Pipelining(true)
.build();

factoryWithoutPipelining = new ClientFactoryBuilder()
.workerGroup(eventLoopGroup, false)
.workerGroup(eventLoopGroup.get(), false)
.useHttp1Pipelining(false)
.build();
}
Expand All @@ -111,7 +109,6 @@ public static void destroyClientFactory() {
ForkJoinPool.commonPool().execute(() -> {
factoryWithPipelining.close();
factoryWithoutPipelining.close();
eventLoopGroup.shutdownGracefully();
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import java.util.concurrent.TimeUnit;

import org.junit.AfterClass;
import org.junit.ClassRule;
import org.junit.Test;

import com.linecorp.armeria.client.Client;
Expand All @@ -37,18 +37,12 @@
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpResponseWriter;
import com.linecorp.armeria.common.stream.NoopSubscriber;

import io.netty.channel.DefaultEventLoop;
import io.netty.channel.EventLoop;
import com.linecorp.armeria.testing.common.EventLoopRule;

public class ConcurrencyLimitingHttpClientTest {

private static final EventLoop eventLoop = new DefaultEventLoop();

@AfterClass
public static void destroy() {
eventLoop.shutdownGracefully();
}
@ClassRule
public static final EventLoopRule eventLoop = new EventLoopRule();

/**
* Tests the request pattern that does not exceed maxConcurrency.
Expand Down Expand Up @@ -231,7 +225,7 @@ public void testUnlimitedRequestWithFaultyDelegate() throws Exception {

private static ClientRequestContext newContext() {
final ClientRequestContext ctx = mock(ClientRequestContext.class);
when(ctx.eventLoop()).thenReturn(eventLoop);
when(ctx.eventLoop()).thenReturn(eventLoop.get());
return ctx;
}

Expand All @@ -247,6 +241,6 @@ private static void closeAndDrain(HttpResponseWriter actualRes, HttpResponse def
}

private static void waitForEventLoop() {
eventLoop.submit(() -> { /* no-op */ }).syncUninterruptibly();
eventLoop.get().submit(() -> { /* no-op */ }).syncUninterruptibly();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
Expand All @@ -53,12 +53,12 @@
import com.linecorp.armeria.common.logging.RequestLogBuilder;
import com.linecorp.armeria.common.metric.NoopMeterRegistry;
import com.linecorp.armeria.common.util.SafeCloseable;
import com.linecorp.armeria.testing.common.EventLoopRule;

import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.channel.DefaultEventLoop;
import io.netty.channel.EventLoop;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.EventExecutor;
Expand All @@ -69,12 +69,8 @@

public class RequestContextTest {

private static final EventLoop eventLoop = new DefaultEventLoop();

@AfterClass
public static void stopEventLoop() {
eventLoop.shutdownGracefully();
}
@ClassRule
public static final EventLoopRule eventLoop = new EventLoopRule();

@Rule
public MockitoRule mocks = MockitoJUnit.rule();
Expand All @@ -89,7 +85,7 @@ public static void stopEventLoop() {

@Test
public void contextAwareEventExecutor() throws Exception {
when(channel.eventLoop()).thenReturn(eventLoop);
when(channel.eventLoop()).thenReturn(eventLoop.get());
final RequestContext context = createContext();
final Set<Integer> callbacksCalled = Collections.newSetFromMap(new ConcurrentHashMap<>());
final EventExecutor executor = context.contextAwareEventLoop();
Expand Down Expand Up @@ -125,7 +121,7 @@ public void contextAwareEventExecutor() throws Exception {
progressivePromise.addListener(f -> checkCallback(18, context, callbacksCalled, latch));
progressivePromise.setSuccess("success");
latch.await();
eventLoop.shutdownGracefully().sync();
eventLoop.get().shutdownGracefully().sync();
assertThat(callbacksCalled).containsExactlyElementsOf(IntStream.rangeClosed(1, 18).boxed()::iterator);
}

Expand Down Expand Up @@ -308,7 +304,7 @@ public void contextPropagationDifferentContextAlreadySet() {

try (SafeCloseable ignored = context2.push()) {
thrown.expect(IllegalStateException.class);
context.makeContextAware((Runnable) () -> Assert.fail()).run();
context.makeContextAware((Runnable) Assert::fail).run();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public void onSubscribeBeforeOnComplete() throws Exception {
// Repeat to increase the chance of reproduction.
for (int i = 0; i < 8192; i++) {
final StreamMessageAndWriter<Integer> stream = newStreamWriter(TEN_INTEGERS);
eventLoop().execute(stream::close);
eventLoop.get().execute(stream::close);
stream.subscribe(new Subscriber<Object>() {
@Override
public void onSubscribe(Subscription s) {
Expand All @@ -82,7 +82,7 @@ public void onError(Throwable t) {
public void onComplete() {
queue.add("onComplete");
}
}, eventLoop());
}, eventLoop.get());

assertThat(queue.poll(5, TimeUnit.SECONDS)).isEqualTo("onSubscribe");
assertThat(queue.poll(5, TimeUnit.SECONDS)).isEqualTo("onComplete");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,45 +25,30 @@
import java.util.List;
import java.util.stream.IntStream;

import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import com.google.common.collect.ImmutableList;

import com.linecorp.armeria.common.util.Exceptions;
import com.linecorp.armeria.testing.common.EventLoopRule;
import com.linecorp.armeria.unsafe.ByteBufHttpData;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufHolder;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.buffer.UnpooledHeapByteBuf;
import io.netty.channel.DefaultEventLoop;
import io.netty.channel.EventLoop;

@SuppressWarnings("unchecked") // Allow using the same tests for writers and non-writers
public abstract class AbstractStreamMessageTest {

static final List<Integer> TEN_INTEGERS = IntStream.range(0, 10).boxed().collect(toImmutableList());

private static EventLoop eventLoop;

@BeforeClass
public static void startEventLoop() {
eventLoop = new DefaultEventLoop();
}

@AfterClass
public static void stopEventLoop() {
eventLoop.shutdownGracefully().syncUninterruptibly();
}

EventLoop eventLoop() {
return eventLoop;
}
@ClassRule
public static final EventLoopRule eventLoop = new EventLoopRule();

abstract <T> StreamMessage<T> newStream(List<T> inputs);

Expand Down Expand Up @@ -136,7 +121,7 @@ public void onNext(Integer value) {
public void flowControlled_writeThenDemandThenProcess_eventLoop() throws Exception {
final StreamMessage<Integer> stream = newStream(streamValues());
writeTenIntegers(stream);
eventLoop().submit(
eventLoop.get().submit(
() ->
stream.subscribe(new ResultCollectingSubscriber() {
private Subscription subscription;
Expand All @@ -152,7 +137,7 @@ public void onNext(Integer value) {
subscription.request(1);
super.onNext(value);
}
}, eventLoop())).syncUninterruptibly();
}, eventLoop.get())).syncUninterruptibly();
assertSuccess();
}

Expand Down
Loading

0 comments on commit f894e91

Please sign in to comment.