Skip to content

Commit

Permalink
Fix race condition in GetAppdomainStaticAddress test (dotnet#42437)
Browse files Browse the repository at this point in the history
  • Loading branch information
davmason authored Oct 1, 2020
1 parent e7c0978 commit e51be02
Show file tree
Hide file tree
Showing 5 changed files with 271 additions and 204 deletions.
110 changes: 110 additions & 0 deletions src/tests/profiler/native/event.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#pragma once

#include <mutex>
#include <condition_variable>
#include <functional>

class AutoEvent
{
private:
std::mutex m_mtx;
std::condition_variable m_cv;
bool m_set = false;

static VOID DoNothing()
{

}

public:
AutoEvent() = default;
~AutoEvent() = default;
AutoEvent(AutoEvent& other) = delete;
AutoEvent(AutoEvent &&other) = delete;
AutoEvent &operator=(AutoEvent &other) = delete;
AutoEvent &operator=(AutoEvent &&other) = delete;

void Wait(std::function<void()> spuriousCallback = DoNothing)
{
std::unique_lock<std::mutex> lock(m_mtx);
while (!m_set)
{
m_cv.wait(lock, [&]() { return m_set; });
if (!m_set)
{
spuriousCallback();
}
}
m_set = false;
}

void WaitFor(int milliseconds, std::function<void()> spuriousCallback = DoNothing)
{
std::unique_lock<std::mutex> lock(m_mtx);
while (!m_set)
{
m_cv.wait_for(lock, std::chrono::milliseconds(milliseconds), [&]() { return m_set; });
if (!m_set)
{
spuriousCallback();
}
}
m_set = false;
}

void Signal()
{
std::unique_lock<std::mutex> lock(m_mtx);
m_set = true;
m_cv.notify_one();
}
};

class ManualEvent
{
private:
std::mutex m_mtx;
std::condition_variable m_cv;
bool m_set = false;

static VOID DoNothing()
{

}

public:
ManualEvent() = default;
~ManualEvent() = default;
ManualEvent(ManualEvent& other) = delete;
ManualEvent(ManualEvent&& other) = delete;
ManualEvent& operator= (ManualEvent& other) = delete;
ManualEvent& operator= (ManualEvent&& other) = delete;

void Wait(std::function<void()> spuriousCallback = DoNothing)
{
std::unique_lock<std::mutex> lock(m_mtx);
while (!m_set)
{
m_cv.wait(lock, [&]() { return m_set; });
if (!m_set)
{
spuriousCallback();
}
}
}

void Signal()
{
std::unique_lock<std::mutex> lock(m_mtx);
m_set = true;
}

void Reset()
{
std::unique_lock<std::mutex> lock(m_mtx);
m_set = false;
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ HRESULT GetAppDomainStaticAddress::Initialize(IUnknown *pICorProfilerInfoUnk)

while (true)
{
std::this_thread::sleep_for(std::chrono::milliseconds(100));

gcWaitEvent.Wait();

if (!IsRuntimeExecutingManagedCode())
Expand All @@ -96,7 +94,6 @@ HRESULT GetAppDomainStaticAddress::Initialize(IUnknown *pICorProfilerInfoUnk)
};

gcTriggerThread = thread(gcTriggerLambda);
gcWaitEvent.Signal();

return S_OK;
}
Expand All @@ -105,8 +102,6 @@ HRESULT GetAppDomainStaticAddress::Shutdown()
{
Profiler::Shutdown();

gcWaitEvent.Reset();

if (this->pCorProfilerInfo != nullptr)
{
this->pCorProfilerInfo->Release();
Expand Down Expand Up @@ -144,103 +139,106 @@ HRESULT GetAppDomainStaticAddress::ModuleLoadFinished(ModuleID moduleId, HRESULT
++failures;
}

if (DEBUG_OUT)
{
wprintf(L"Module 0x%" PRIxPTR " (%s) loaded\n", moduleId, name);
}

wprintf(L"Module 0x%" PRIxPTR " (%s) loaded\n", moduleId, name);

printf("Forcing GC due to module load\n");
gcWaitEvent.Signal();
return S_OK;
}

HRESULT GetAppDomainStaticAddress::ModuleUnloadStarted(ModuleID moduleId)
{
lock_guard<mutex> guard(classADMapLock);
constexpr size_t nameLen = 1024;
WCHAR name[nameLen];
HRESULT hr = pCorProfilerInfo->GetModuleInfo2(moduleId,
NULL,
nameLen,
NULL,
name,
NULL,
NULL);
if (FAILED(hr))
{
printf("GetModuleInfo2 failed with hr=0x%x\n", hr);
++failures;
return E_FAIL;
}

if (DEBUG_OUT)
{
wprintf(L"Module 0x%" PRIxPTR " (%s) unload started\n", moduleId, name);
}
{
printf("Forcing GC due to module unload\n");
gcWaitEvent.Signal();

for (auto it = classADMap.begin(); it != classADMap.end(); )
{
ClassID classId = it->first;

ModuleID modId;
hr = pCorProfilerInfo->GetClassIDInfo(classId, &modId, NULL);
lock_guard<mutex> guard(classADMapLock);
constexpr size_t nameLen = 1024;
WCHAR name[nameLen];
HRESULT hr = pCorProfilerInfo->GetModuleInfo2(moduleId,
NULL,
nameLen,
NULL,
name,
NULL,
NULL);
if (FAILED(hr))
{
printf("Failed to get ClassIDInfo hr=0x%x\n", hr);
printf("GetModuleInfo2 failed with hr=0x%x\n", hr);
++failures;
return E_FAIL;
}

if (modId == moduleId)
{
if (DEBUG_OUT)
{
printf("ClassID 0x%" PRIxPTR " being removed due to parent module unloading\n", classId);
}

it = classADMap.erase(it);
continue;
}
wprintf(L"Module 0x%" PRIxPTR " (%s) unload started\n", moduleId, name);

// Now check the generic arguments
bool shouldEraseClassId = false;
vector<ClassID> genericTypes = GetGenericTypeArgs(classId);
for (auto genericIt = genericTypes.begin(); genericIt != genericTypes.end(); ++genericIt)
for (auto it = classADMap.begin(); it != classADMap.end(); )
{
ClassID typeArg = *genericIt;
ModuleID typeArgModId;
ClassID classId = it->first;

if (DEBUG_OUT)
{
printf("Checking generic argument 0x%" PRIxPTR " of class 0x%" PRIxPTR "\n", typeArg, classId);
}

hr = pCorProfilerInfo->GetClassIDInfo(typeArg, &typeArgModId, NULL);
ModuleID modId;
hr = pCorProfilerInfo->GetClassIDInfo(classId, &modId, NULL);
if (FAILED(hr))
{
printf("Failed to get ClassIDInfo hr=0x%x\n", hr);
++failures;
return E_FAIL;
}

if (typeArgModId == moduleId)
if (modId == moduleId)
{
if (DEBUG_OUT)
{
wprintf(L"ClassID 0x%" PRIxPTR " (%s) being removed due to generic argument 0x%" PRIxPTR " (%s) belonging to the parent module 0x%" PRIxPTR " unloading\n",
classId, GetClassIDName(classId).ToCStr(), typeArg, GetClassIDName(typeArg).ToCStr(), typeArgModId);
printf("ClassID 0x%" PRIxPTR " being removed due to parent module unloading\n", classId);
}

shouldEraseClassId = true;
break;
it = classADMap.erase(it);
continue;
}
}

if (shouldEraseClassId)
{
it = classADMap.erase(it);
}
else
{
++it;
// Now check the generic arguments
bool shouldEraseClassId = false;
vector<ClassID> genericTypes = GetGenericTypeArgs(classId);
for (auto genericIt = genericTypes.begin(); genericIt != genericTypes.end(); ++genericIt)
{
ClassID typeArg = *genericIt;
ModuleID typeArgModId;

if (DEBUG_OUT)
{
printf("Checking generic argument 0x%" PRIxPTR " of class 0x%" PRIxPTR "\n", typeArg, classId);
}

hr = pCorProfilerInfo->GetClassIDInfo(typeArg, &typeArgModId, NULL);
if (FAILED(hr))
{
printf("Failed to get ClassIDInfo hr=0x%x\n", hr);
++failures;
return E_FAIL;
}

if (typeArgModId == moduleId)
{
if (DEBUG_OUT)
{
wprintf(L"ClassID 0x%" PRIxPTR " (%s) being removed due to generic argument 0x%" PRIxPTR " (%s) belonging to the parent module 0x%" PRIxPTR " unloading\n",
classId, GetClassIDName(classId).ToCStr(), typeArg, GetClassIDName(typeArg).ToCStr(), typeArgModId);
}

shouldEraseClassId = true;
break;
}
}

if (shouldEraseClassId)
{
it = classADMap.erase(it);
}
else
{
++it;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,59 +14,12 @@
#include <string>
#include <thread>
#include <chrono>
#include <condition_variable>
#include <functional>
#include "cor.h"
#include "corprof.h"

typedef HRESULT (*GetDispenserFunc) (const CLSID &pClsid, const IID &pIid, void **ppv);

class ManualEvent
{
private:
std::mutex m_mtx;
std::condition_variable m_cv;
bool m_set = false;

static void DoNothing()
{

}

public:
ManualEvent() = default;
~ManualEvent() = default;
ManualEvent(ManualEvent& other) = delete;
ManualEvent(ManualEvent&& other) = delete;
ManualEvent& operator= (ManualEvent& other) = delete;
ManualEvent& operator= (ManualEvent&& other) = delete;

void Wait(std::function<void()> spuriousCallback = DoNothing)
{
std::unique_lock<std::mutex> lock(m_mtx);
while (!m_set)
{
m_cv.wait(lock, [&]() { return m_set; });
if (!m_set)
{
spuriousCallback();
}
}
}

void Signal()
{
std::unique_lock<std::mutex> lock(m_mtx);
m_set = true;
}

void Reset()
{
std::unique_lock<std::mutex> lock(m_mtx);
m_set = false;
}
};

class GetAppDomainStaticAddress : public Profiler
{
public:
Expand All @@ -93,7 +46,7 @@ class GetAppDomainStaticAddress : public Profiler

std::atomic<int> jitEventCount;
std::thread gcTriggerThread;
ManualEvent gcWaitEvent;
AutoEvent gcWaitEvent;

typedef std::map<ClassID, AppDomainID>ClassAppDomainMap;
ClassAppDomainMap classADMap;
Expand Down
Loading

0 comments on commit e51be02

Please sign in to comment.