Skip to content

Commit

Permalink
Fix race condition in HystrixSampleSseServlet for response writes
Browse files Browse the repository at this point in the history
  • Loading branch information
erichhsun authored and ericsun-insikt committed Jan 17, 2018
1 parent 6825138 commit da05b3b
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ public abstract class HystrixSampleSseServlet extends HttpServlet {

private final int pausePollerThreadDelayInMs;

/* response is not thread-safe */
private final Object responseWriteLock = new Object();

/* Set to true upon shutdown, so it's OK to be shared among all SampleSseServlets */
private static volatile boolean isDestroyed = false;

Expand Down Expand Up @@ -147,12 +150,15 @@ public void onError(Throwable e) {
public void onNext(String sampleDataAsString) {
if (sampleDataAsString != null) {
try {
writer.print("data: " + sampleDataAsString + "\n\n");
// explicitly check for client disconnect - PrintWriter does not throw exceptions
if (writer.checkError()) {
moreDataWillBeSent.set(false);
// avoid concurrent writes with ping
synchronized (responseWriteLock) {
writer.print("data: " + sampleDataAsString + "\n\n");
// explicitly check for client disconnect - PrintWriter does not throw exceptions
if (writer.checkError()) {
moreDataWillBeSent.set(false);
}
writer.flush();
}
writer.flush();
} catch (Exception ex) {
moreDataWillBeSent.set(false);
}
Expand All @@ -164,12 +170,16 @@ public void onNext(String sampleDataAsString) {
try {
Thread.sleep(pausePollerThreadDelayInMs);
//in case stream has not started emitting yet, catch any clients which connect/disconnect before emits start
writer.print("ping: \n\n");
// explicitly check for client disconnect - PrintWriter does not throw exceptions
if (writer.checkError()) {
moreDataWillBeSent.set(false);

// avoid concurrent writes with sample
synchronized (responseWriteLock) {
writer.print("ping: \n\n");
// explicitly check for client disconnect - PrintWriter does not throw exceptions
if (writer.checkError()) {
moreDataWillBeSent.set(false);
}
writer.flush();
}
writer.flush();
} catch (Exception ex) {
moreDataWillBeSent.set(false);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/**
* Copyright 2016 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.netflix.hystrix.contrib.sample.stream;

import com.netflix.config.DynamicIntProperty;
import com.netflix.config.DynamicPropertyFactory;
import com.netflix.hystrix.config.HystrixConfiguration;
import com.netflix.hystrix.config.HystrixConfigurationStream;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import rx.Observable;
import rx.Subscriber;
import rx.functions.Func1;
import rx.schedulers.Schedulers;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.when;

public class HystrixSampleSseServletTest {

private static final String INTERJECTED_CHARACTER = "a";

@Mock HttpServletRequest mockReq;
@Mock HttpServletResponse mockResp;
@Mock HystrixConfiguration mockConfig;
@Mock PrintWriter mockPrintWriter;

TestHystrixConfigSseServlet servlet;

@Before
public void init() {
MockitoAnnotations.initMocks(this);
}

@After
public void tearDown() {
servlet.destroy();
servlet.shutdown();
}

@Test
public void testNoConcurrentResponseWrites() throws IOException, InterruptedException {
final Observable<HystrixConfiguration> limitedOnNexts = Observable.create(new Observable.OnSubscribe<HystrixConfiguration>() {
@Override
public void call(Subscriber<? super HystrixConfiguration> subscriber) {
try {
for (int i = 0; i < 500; i++) {
Thread.sleep(10);
subscriber.onNext(mockConfig);
}

} catch (InterruptedException ex) {
ex.printStackTrace();
} catch (Exception e) {
subscriber.onCompleted();
}
}
}).subscribeOn(Schedulers.computation());

servlet = new TestHystrixConfigSseServlet(limitedOnNexts, 1);
try {
servlet.init();
} catch (ServletException ex) {

}

final StringBuilder buffer = new StringBuilder();

when(mockReq.getParameter("delay")).thenReturn("100");
when(mockResp.getWriter()).thenReturn(mockPrintWriter);
Mockito.doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
String written = (String) invocation.getArguments()[0];
if (written.contains("ping")) {
buffer.append(INTERJECTED_CHARACTER);
} else {
// slow down the append to increase chances to interleave
for (int i = 0; i < written.length(); i++) {
Thread.sleep(5);
buffer.append(written.charAt(i));
}
}
return null;
}
}).when(mockPrintWriter).print(Mockito.anyString());

Runnable simulateClient = new Runnable() {
@Override
public void run() {
try {
servlet.doGet(mockReq, mockResp);
} catch (ServletException ex) {
fail(ex.getMessage());
} catch (IOException ex) {
fail(ex.getMessage());
}
}
};

Thread t = new Thread(simulateClient);
t.start();

try {
Thread.sleep(1000);
System.out.println(System.currentTimeMillis() + " Woke up from sleep : " + Thread.currentThread().getName());
} catch (InterruptedException ex) {
fail(ex.getMessage());
}

Pattern pattern = Pattern.compile("\\{[" + INTERJECTED_CHARACTER + "]+\\}");
boolean hasInterleaved = pattern.matcher(buffer).find();
assertFalse(hasInterleaved);
}

private static class TestHystrixConfigSseServlet extends HystrixSampleSseServlet {

private static AtomicInteger concurrentConnections = new AtomicInteger(0);
private static DynamicIntProperty maxConcurrentConnections = DynamicPropertyFactory.getInstance().getIntProperty("hystrix.config.stream.maxConcurrentConnections", 5);

public TestHystrixConfigSseServlet() {
this(HystrixConfigurationStream.getInstance().observe(), DEFAULT_PAUSE_POLLER_THREAD_DELAY_IN_MS);
}

TestHystrixConfigSseServlet(Observable<HystrixConfiguration> sampleStream, int pausePollerThreadDelayInMs) {
super(sampleStream.map(new Func1<HystrixConfiguration, String>() {
@Override
public String call(HystrixConfiguration hystrixConfiguration) {
return "{}";
}
}), pausePollerThreadDelayInMs);
}

@Override
protected int getMaxNumberConcurrentConnectionsAllowed() {
return maxConcurrentConnections.get();
}

@Override
protected int getNumberCurrentConnections() {
return concurrentConnections.get();
}

@Override
protected int incrementAndGetCurrentConcurrentConnections() {
return concurrentConnections.incrementAndGet();
}

@Override
protected void decrementCurrentConcurrentConnections() {
concurrentConnections.decrementAndGet();
}
}
}

0 comments on commit da05b3b

Please sign in to comment.