Skip to content

Commit

Permalink
fix RpcInvokeContext not thread safe (sofastack#1081)
Browse files Browse the repository at this point in the history
* fix RpcInvokeContext not thread safe
  • Loading branch information
OrezzerO authored Sep 13, 2021
1 parent 47c430c commit f6864bb
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@
import com.alipay.sofa.rpc.core.invoke.SofaResponseCallback;
import com.alipay.sofa.rpc.message.ResponseFuture;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
* 基于ThreadLocal的面向业务开发者使用的上下文传递对象
Expand All @@ -37,13 +35,52 @@ public class RpcInvokeContext {
* 线程上下文变量
*/
protected static final ThreadLocal<RpcInvokeContext> LOCAL = new ThreadLocal<RpcInvokeContext>();

/**
* 是否开启上下文透传功能
*
* @since 5.1.2
*/
private static final boolean BAGGAGE_ENABLE = RpcConfigs.getBooleanValue(RpcOptions.INVOKE_BAGGAGE_ENABLE);
/**
* 自定义 header ,用完一次即删
*/
protected HashMap<String, String> customHeader = new HashMap<>();


protected Map<String, String> customHeader = new ConcurrentHashMap<>();
/**
* 用户自定义超时时间,单次调用生效
*/
protected Integer timeout;
/**
* 用户自定义对方地址,单次调用生效
*/
protected String targetURL;
/**
* 用户自定义对方分组
*/
protected String targetGroup;
/**
* 用户自定义Callback,单次调用生效
*/
protected SofaResponseCallback responseCallback;
/**
* The Future.
*/
protected ResponseFuture<?> future;
/**
* 自定义属性
*/
protected Map<String, Object> map = new ConcurrentHashMap<>();
/**
* 请求上的透传数据
*
* @since 5.1.2
*/
protected Map<String, String> requestBaggage = BAGGAGE_ENABLE ? new ConcurrentHashMap<>() : null;
/**
* 响应上的透传数据
*
* @since 5.1.2
*/
protected Map<String, String> responseBaggage = BAGGAGE_ENABLE ? new ConcurrentHashMap<>() : null;

/**
* 得到上下文,没有则初始化
Expand All @@ -59,6 +96,45 @@ public static RpcInvokeContext getContext() {
return context;
}

/**
* 设置上下文
*
* @param context 调用上下文
*/
public static void setContext(RpcInvokeContext context) {
LOCAL.set(RpcInvokeContext.clone(context));
}

private static RpcInvokeContext clone(RpcInvokeContext parent) {
if (parent == null) {
return null;
}
RpcInvokeContext child = new RpcInvokeContext();
//timeout
child.setTimeout(parent.getTimeout());
//targetURL
child.setTargetURL(parent.getTargetURL());
//targetGroup
child.setTargetGroup(parent.getTargetGroup());
//responseCallback
child.setResponseCallback(parent.getResponseCallback());
//future
child.setFuture(parent.getFuture());
//map
child.map.putAll(parent.map);
//customHeader
child.customHeader.putAll(parent.customHeader);

if (BAGGAGE_ENABLE) {
//requestBaggage
child.requestBaggage.putAll(parent.requestBaggage);
//responseBaggage
child.responseBaggage.putAll(parent.responseBaggage);
}

return child;
}

/**
* 查看上下文
*
Expand All @@ -75,22 +151,6 @@ public static void removeContext() {
LOCAL.remove();
}

/**
* 设置上下文
*
* @param context 调用上下文
*/
public static void setContext(RpcInvokeContext context) {
LOCAL.set(context);
}

/**
* 是否开启上下文透传功能
*
* @since 5.1.2
*/
private static final boolean BAGGAGE_ENABLE = RpcConfigs.getBooleanValue(RpcOptions.INVOKE_BAGGAGE_ENABLE);

/**
* 是否启用RPC透传功能
*
Expand All @@ -100,50 +160,6 @@ public static boolean isBaggageEnable() {
return BAGGAGE_ENABLE;
}

/**
* 用户自定义超时时间,单次调用生效
*/
protected Integer timeout;

/**
* 用户自定义对方地址,单次调用生效
*/
protected String targetURL;

/**
* 用户自定义对方分组
*/
protected String targetGroup;

/**
* 用户自定义Callback,单次调用生效
*/
protected SofaResponseCallback responseCallback;

/**
* The Future.
*/
protected ResponseFuture<?> future;

/**
* 自定义属性
*/
protected ConcurrentMap<String, Object> map = new ConcurrentHashMap<String, Object>();

/**
* 请求上的透传数据
*
* @since 5.1.2
*/
protected Map<String, String> requestBaggage = BAGGAGE_ENABLE ? new HashMap<String, String>() : null;

/**
* 响应上的透传数据
*
* @since 5.1.2
*/
protected Map<String, String> responseBaggage = BAGGAGE_ENABLE ? new HashMap<String, String>() : null;

/**
* 得到调用级别超时时间
*
Expand Down Expand Up @@ -251,7 +267,7 @@ public Map<String, String> getAllRequestBaggage() {

/**
* 设置全部请求透传数据
*
*
* @param requestBaggage 请求透传数据
*/
public void putAllRequestBaggage(Map<String, String> requestBaggage) {
Expand Down Expand Up @@ -309,7 +325,7 @@ public Map<String, String> getAllResponseBaggage() {

/**
* 设置全部响应透传数据
*
*
* @param responseBaggage 响应透传数据
*/
public void putAllResponseBaggage(Map<String, String> responseBaggage) {
Expand Down Expand Up @@ -402,15 +418,15 @@ public RpcInvokeContext setFuture(ResponseFuture<?> future) {


public Map<String, String> getCustomHeader() {
return new HashMap<>(customHeader);
return customHeader;
}

/**
* 设置请求头,与 RequestBaggage 相比
* 1. 不受 enable baggage 开关影响,始终生效
* 2. 仅对一次调用生效,调用完成之会被清空
*
* @param key header key
* @param key header key
* @param value header value
*/
public void addCustomHeader(String key, String value) {
Expand All @@ -433,6 +449,7 @@ public String toString() {
sb.append(", map=").append(map);
sb.append(", requestBaggage=").append(requestBaggage);
sb.append(", responseBaggage=").append(responseBaggage);
sb.append(", customHeader=").append(customHeader);
sb.append('}');
return sb.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

/**
*
*
* @author <a href="mailto:[email protected]">GengZhang</a>
*/
public class RpcInvokeContextTest {
Expand Down Expand Up @@ -50,7 +53,7 @@ public void getContext() throws Exception {
context = new RpcInvokeContext();
RpcInvokeContext.setContext(context);
Assert.assertTrue(RpcInvokeContext.getContext() != null);
Assert.assertEquals(RpcInvokeContext.getContext(), context);
Assert.assertNotEquals(RpcInvokeContext.getContext(), context);

RpcInvokeContext.removeContext();
Assert.assertTrue(RpcInvokeContext.peekContext() == null);
Expand All @@ -67,4 +70,99 @@ public void getContext() throws Exception {
public void peekContext() throws Exception {
}

@Test
public void testThreadSafe() {
RpcInvokeContext context = RpcInvokeContext.getContext();
CountDownLatch countDownLatch = new CountDownLatch(2);
Runnable runnable = new Runnable() {
@Override
public void run() {
countDownLatch.countDown();
try {
countDownLatch.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
long start = System.currentTimeMillis();
RpcInvokeContext.setContext(context);
long now = System.currentTimeMillis();
int i = 0;
while (now - start < 100) {
now = System.currentTimeMillis();
i++;
RpcInvokeContext.getContext().addCustomHeader(i + "", i + "");
try {
new HashMap<>().putAll(RpcInvokeContext.getContext().getCustomHeader());
} catch (Exception e) {
System.out.println(i);
throw e;
}
}

}
};

new Thread(runnable).start();
runnable.run();
}

@Test
public void testSetContext() {
RpcInvokeContext context = new RpcInvokeContext();
context.setTargetGroup("target");
context.setTargetURL("url");
context.setTimeout(111);
context.addCustomHeader("A", "B");
context.put("C", "D");
RpcInvokeContext.setContext(context);
Assert.assertEquals(context.getTargetGroup(), RpcInvokeContext.getContext().getTargetGroup());
Assert.assertEquals(context.getTargetURL(), RpcInvokeContext.getContext().getTargetURL());
Assert.assertEquals(context.getTimeout(), RpcInvokeContext.getContext().getTimeout());
Assert.assertEquals("B", RpcInvokeContext.getContext().getCustomHeader().get("A"));
Assert.assertEquals("D", RpcInvokeContext.getContext().get("C"));
Assert.assertTrue(context != RpcInvokeContext.getContext());
RpcInvokeContext.removeContext();
}

@Test
public void testConcurrentModify() throws InterruptedException {
for (int i = 0; i < 10; i++) {
RpcInvokeContext.getContext().put("" + i, "" + i);
}

CountDownLatch countDownLatch = new CountDownLatch(2);
RpcInvokeContext mainContext = RpcInvokeContext.getContext();
AtomicReference<RuntimeException> exceptionHolder = new AtomicReference<>();
new Thread(() -> {
countDownLatch.countDown();
try {
countDownLatch.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
long start = System.currentTimeMillis();
try {
while (System.currentTimeMillis() - start < 100) {
RpcInvokeContext.setContext(mainContext);
}
} catch (RuntimeException e) {
exceptionHolder.set(e);
throw e;
}

}).start();

Map<String, String> headers = RpcInvokeContext.getContext().getCustomHeader();
countDownLatch.countDown();
countDownLatch.await();
long start = System.currentTimeMillis();
int i = 0;
while (System.currentTimeMillis() - start < 100) {
if (exceptionHolder.get() != null) {
throw exceptionHolder.get();
}
headers.put("" + i, "" + i);
i++;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,24 @@ protected void recordClientElapseTime() {

protected void pickupBaggage(SofaResponse response) {
if (RpcInvokeContext.isBaggageEnable()) {
RpcInvokeContext invokeCtx = null;
RpcInvokeContext old = null;
RpcInvokeContext newContext = null;
if (context != null) {
invokeCtx = (RpcInvokeContext) context.getAttachment(RpcConstants.HIDDEN_KEY_INVOKE_CONTEXT);
old = (RpcInvokeContext) context.getAttachment(RpcConstants.HIDDEN_KEY_INVOKE_CONTEXT);
}
if (invokeCtx == null) {
invokeCtx = RpcInvokeContext.getContext();
if (old == null) {
newContext = RpcInvokeContext.getContext();
} else {
RpcInvokeContext.setContext(invokeCtx);
RpcInvokeContext.setContext(old);
newContext = RpcInvokeContext.getContext();
}
BaggageResolver.pickupFromResponse(invokeCtx, response);
BaggageResolver.pickupFromResponse(newContext, response);

if (old != null) {
old.getAllResponseBaggage().putAll(newContext.getAllResponseBaggage());
old.getAllRequestBaggage().putAll(newContext.getAllRequestBaggage());
}

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ protected void pickupBaggage(SofaResponse response) {
invokeCtx = RpcInvokeContext.getContext();
} else {
RpcInvokeContext.setContext(invokeCtx);
invokeCtx = RpcInvokeContext.getContext();
}
BaggageResolver.pickupFromResponse(invokeCtx, response);
}
Expand Down
Loading

0 comments on commit f6864bb

Please sign in to comment.