Skip to content

Commit

Permalink
Delete reflection blocking on GetThreadDeserializationTracker (dotnet…
Browse files Browse the repository at this point in the history
  • Loading branch information
jkotas authored Jan 6, 2021
1 parent 5c34679 commit 9053d5d
Show file tree
Hide file tree
Showing 9 changed files with 0 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,6 @@ partial void ThreadNameChanged(string? value)
[DllImport(RuntimeHelpers.QCall, CharSet = CharSet.Unicode)]
private static extern void InformThreadNameChange(ThreadHandle t, string? name, int len);

[MethodImpl(MethodImplOptions.InternalCall)]
internal static extern DeserializationTracker GetThreadDeserializationTracker(ref StackCrawlMark stackMark);

/// <summary>Returns true if the thread has been started and is not dead.</summary>
public extern bool IsAlive
{
Expand Down
24 changes: 0 additions & 24 deletions src/coreclr/vm/comsynchronizable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1158,30 +1158,6 @@ BOOL QCALLTYPE ThreadNative::YieldThread()
return ret;
}

FCIMPL1(Object*, ThreadNative::GetThreadDeserializationTracker, StackCrawlMark* stackMark)
{
FCALL_CONTRACT;
OBJECTREF refRetVal = NULL;
HELPER_METHOD_FRAME_BEGIN_RET_1(refRetVal)

// To avoid reflection trying to bypass deserialization tracking, check the caller
// and only allow SerializationInfo to call into this method.
MethodTable* pCallerMT = SystemDomain::GetCallersType(stackMark);
if (pCallerMT != CoreLibBinder::GetClass(CLASS__SERIALIZATION_INFO))
{
COMPlusThrowArgumentException(W("stackMark"), NULL);
}

Thread* pThread = GetThread();

refRetVal = ObjectFromHandle(pThread->GetOrCreateDeserializationTracker());

HELPER_METHOD_FRAME_END();

return OBJECTREFToObject(refRetVal);
}
FCIMPLEND

FCIMPL0(INT32, ThreadNative::GetCurrentProcessorNumber)
{
FCALL_CONTRACT;
Expand Down
1 change: 0 additions & 1 deletion src/coreclr/vm/comsynchronizable.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ friend class ThreadBaseObject;
#endif //FEATURE_COMINTEROP
static FCDECL1(FC_BOOL_RET,IsThreadpoolThread, ThreadBaseObject* thread);
static FCDECL1(void, SetIsThreadpoolThread, ThreadBaseObject* thread);
static FCDECL1(Object*, GetThreadDeserializationTracker, StackCrawlMark* stackMark);

static FCDECL0(INT32, GetCurrentProcessorNumber);

Expand Down
3 changes: 0 additions & 3 deletions src/coreclr/vm/corelib.h
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,6 @@ DEFINE_METHOD(COMWRAPPERS, RELEASE_OBJECTS, CallReleaseObjects,
DEFINE_METHOD(COMWRAPPERS, CALL_ICUSTOMQUERYINTERFACE, CallICustomQueryInterface, SM_Obj_RefGuid_RefIntPtr_RetInt)
#endif //FEATURE_COMINTEROP

DEFINE_CLASS(SERIALIZATION_INFO, Serialization, SerializationInfo)
DEFINE_CLASS(DESERIALIZATION_TRACKER, Serialization, DeserializationTracker)

DEFINE_CLASS(IENUMERATOR, Collections, IEnumerator)

DEFINE_CLASS(IENUMERABLE, Collections, IEnumerable)
Expand Down
1 change: 0 additions & 1 deletion src/coreclr/vm/ecalllist.h
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,6 @@ FCFuncStart(gThreadFuncs)
FCFuncElement("Join", ThreadNative::Join)
QCFuncElement("GetOptimalMaxSpinWaitsPerSpinIterationInternal", ThreadNative::GetOptimalMaxSpinWaitsPerSpinIteration)
FCFuncElement("GetCurrentProcessorNumber", ThreadNative::GetCurrentProcessorNumber)
FCFuncElement("GetThreadDeserializationTracker", ThreadNative::GetThreadDeserializationTracker)
FCFuncEnd()

FCFuncStart(gThreadPoolFuncs)
Expand Down
33 changes: 0 additions & 33 deletions src/coreclr/vm/threads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1593,7 +1593,6 @@ Thread::Thread()
memset(&m_activityId, 0, sizeof(m_activityId));
#endif // FEATURE_PERFTRACING
m_HijackReturnKind = RT_Illegal;
m_DeserializationTracker = NULL;

m_currentPrepareCodeConfig = nullptr;
m_isInForbidSuspendForDebuggerRegion = false;
Expand Down Expand Up @@ -2630,11 +2629,6 @@ Thread::~Thread()
// Destroy any handles that we're using to hold onto exception objects
SafeSetThrowables(NULL);

if (m_DeserializationTracker != NULL)
{
DestroyGlobalStrongHandle(m_DeserializationTracker);
}

DestroyShortWeakHandle(m_ExposedObject);
DestroyStrongHandle(m_StrongHndToExposedObject);
}
Expand Down Expand Up @@ -8544,30 +8538,3 @@ ThreadStore::EnumMemoryRegions(CLRDataEnumMemoryFlags flags)
}

#endif // #ifdef DACCESS_COMPILE

OBJECTHANDLE Thread::GetOrCreateDeserializationTracker()
{
CONTRACTL
{
THROWS;
GC_TRIGGERS;
MODE_COOPERATIVE;
}
CONTRACTL_END;

#if !defined (DACCESS_COMPILE)
if (m_DeserializationTracker != NULL)
{
return m_DeserializationTracker;
}

_ASSERTE(this == GetThread());

MethodTable* pMT = CoreLibBinder::GetClass(CLASS__DESERIALIZATION_TRACKER);
m_DeserializationTracker = CreateGlobalStrongHandle(AllocateObject(pMT));

_ASSERTE(m_DeserializationTracker != NULL);
#endif // !defined (DACCESS_COMPILE)

return m_DeserializationTracker;
}
6 changes: 0 additions & 6 deletions src/coreclr/vm/threads.h
Original file line number Diff line number Diff line change
Expand Up @@ -4605,12 +4605,6 @@ class Thread
}
#endif // FEATURE_HIJACK

public:
OBJECTHANDLE GetOrCreateDeserializationTracker();

private:
OBJECTHANDLE m_DeserializationTracker;

public:
static uint64_t dead_threads_non_alloc_bytes;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,35 +12,23 @@ public sealed partial class SerializationInfo
{
internal static AsyncLocal<bool> AsyncDeserializationInProgress { get; } = new AsyncLocal<bool>();

#if !CORECLR
// On AoT, assume private members are reflection blocked, so there's no further protection required
// for the thread's DeserializationTracker
[ThreadStatic]
private static DeserializationTracker? t_deserializationTracker;

private static DeserializationTracker GetThreadDeserializationTracker() =>
t_deserializationTracker ??= new DeserializationTracker();
#endif // !CORECLR

// Returns true if deserialization is currently in progress
public static bool DeserializationInProgress
{
#if CORECLR
[DynamicSecurityMethod] // Methods containing StackCrawlMark local var must be marked DynamicSecurityMethod
#endif
get
{
if (AsyncDeserializationInProgress.Value)
{
return true;
}

#if CORECLR
StackCrawlMark stackMark = StackCrawlMark.LookForMe;
DeserializationTracker tracker = Thread.GetThreadDeserializationTracker(ref stackMark);
#else
DeserializationTracker tracker = GetThreadDeserializationTracker();
#endif
bool result = tracker.DeserializationInProgress;
return result;
}
Expand Down Expand Up @@ -100,19 +88,11 @@ public static void ThrowIfDeserializationInProgress(string switchSuffix, ref int
// In this state, if the SerializationGuard or other related AppContext switches are set,
// actions likely to be dangerous during deserialization, such as starting a process will be blocked.
// Returns a DeserializationToken that must be disposed to remove the deserialization state.
#if CORECLR
[DynamicSecurityMethod] // Methods containing StackCrawlMark local var must be marked DynamicSecurityMethod
#endif
public static DeserializationToken StartDeserialization()
{
if (LocalAppContextSwitches.SerializationGuard)
{
#if CORECLR
StackCrawlMark stackMark = StackCrawlMark.LookForMe;
DeserializationTracker tracker = Thread.GetThreadDeserializationTracker(ref stackMark);
#else
DeserializationTracker tracker = GetThreadDeserializationTracker();
#endif
if (!tracker.DeserializationInProgress)
{
lock (tracker)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,33 +33,6 @@ public static void BlockFileWrites()
TryPayload(new FileWriter());
}

[Fact]
[ActiveIssue("https://github.com/mono/mono/issues/15112", TestRuntimes.Mono)]
public static void BlockReflectionDodging()
{
// Ensure that the deserialization tracker cannot be called by reflection.
MethodInfo trackerMethod = typeof(Thread).GetMethod(
"GetThreadDeserializationTracker",
BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static);

Assert.NotNull(trackerMethod);

Assert.Equal(1, trackerMethod.GetParameters().Length);
object[] args = new object[1];
args[0] = Enum.ToObject(typeof(Thread).Assembly.GetType("System.Threading.StackCrawlMark"), 0);

try
{
object tracker = trackerMethod.Invoke(null, args);
throw new InvalidOperationException(tracker?.ToString() ?? "(null tracker returned)");
}
catch (TargetInvocationException ex)
{
Exception baseEx = ex.GetBaseException();
AssertExtensions.Throws<ArgumentException>("stackMark", () => throw baseEx);
}
}

[Fact]
public static void BlockAsyncDodging()
{
Expand Down

0 comments on commit 9053d5d

Please sign in to comment.