Skip to content

Commit

Permalink
wasm: create threadstatics using GetThreadStaticBaseForType (dotnet#6769
Browse files Browse the repository at this point in the history
)

Previously the support for `[ThreadStatic]` was incomplete.  This change adds some more tests (the last 3 of which fail prior to these changes).  Fixes dotnet#6733 - I've included a call to `Thread.Sleep` to test that the assert is no longer hit.  The sleep time is short so as not to affect the CI, but if its changed to something noticeable, like 10s, it can be seen to take effect.  

Also included is a small comment typo.
  • Loading branch information
yowl authored and MichalStrehovsky committed Jan 11, 2019
1 parent 21bd5a1 commit ef7c9e6
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 35 deletions.
122 changes: 88 additions & 34 deletions src/ILCompiler.WebAssembly/src/CodeGen/ILToWebAssemblyImporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,11 @@ private void GenerateProlog()
}

MetadataType metadataType = (MetadataType)_thisType;
if (!metadataType.IsBeforeFieldInit)
if (!metadataType.IsBeforeFieldInit
&& (!_method.IsStaticConstructor && _method.Signature.IsStatic || _method.IsConstructor || (_thisType.IsValueType && !_method.Signature.IsStatic))
&& _compilation.TypeSystemContext.HasLazyStaticConstructor(metadataType))
{
if (!_method.IsStaticConstructor && _method.Signature.IsStatic || _method.IsConstructor || (_thisType.IsValueType && !_method.Signature.IsStatic))
{
TriggerCctor(metadataType);
}
TriggerCctor(metadataType);
}

LLVMBasicBlockRef block0 = GetLLVMBasicBlockForBlock(_basicBlocks[0]);
Expand Down Expand Up @@ -2883,12 +2882,19 @@ private LLVMValueRef GetFieldAddress(FieldDesc field, bool isStatic)
MetadataType owningType = (MetadataType)field.OwningType;
LLVMValueRef staticBase;
int fieldOffset;
// If the type is non-BeforeFieldInit, this is handled before calling any methods on it
bool needsCctorCheck = (owningType.IsBeforeFieldInit || (!owningType.IsBeforeFieldInit && owningType != _thisType)) && _compilation.TypeSystemContext.HasLazyStaticConstructor(owningType);

if (field.HasRva)
{
node = (ISymbolNode)_compilation.GetFieldRvaData(field);
staticBase = LoadAddressOfSymbolNode(node);
fieldOffset = 0;
// Run static constructor if necessary
if (needsCctorCheck)
{
TriggerCctor(owningType);
}
}
else
{
Expand All @@ -2897,36 +2903,38 @@ private LLVMValueRef GetFieldAddress(FieldDesc field, bool isStatic)
if (field.IsThreadStatic)
{
// TODO: We need the right thread static per thread
node = _compilation.NodeFactory.TypeThreadStaticsSymbol(owningType);
staticBase = LoadAddressOfSymbolNode(node);
}
else if (field.HasGCStaticBase)
{
node = _compilation.NodeFactory.TypeGCStaticsSymbol(owningType);

// We can't use GCStatics in the data section until we can successfully call
// InitializeModules on startup, so stick with globals for now
//LLVMValueRef basePtrPtr = LoadAddressOfSymbolNode(node);
//staticBase = LLVM.BuildLoad(_builder, LLVM.BuildLoad(_builder, LLVM.BuildPointerCast(_builder, basePtrPtr, LLVM.PointerType(LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type(), 0), 0), 0), "castBasePtrPtr"), "basePtr"), "base");
staticBase = WebAssemblyObjectWriter.EmitGlobal(Module, field, _compilation.NameMangler);
fieldOffset = 0;
ExpressionEntry returnExp;
node = TriggerCctorWithThreadStaticStorage(owningType, needsCctorCheck, out returnExp);
staticBase = returnExp.ValueAsType(returnExp.Type, _builder);
}
else
{
node = _compilation.NodeFactory.TypeNonGCStaticsSymbol(owningType);
staticBase = LoadAddressOfSymbolNode(node);
if (field.HasGCStaticBase)
{
node = _compilation.NodeFactory.TypeGCStaticsSymbol(owningType);

// We can't use GCStatics in the data section until we can successfully call
// InitializeModules on startup, so stick with globals for now
//LLVMValueRef basePtrPtr = LoadAddressOfSymbolNode(node);
//staticBase = LLVM.BuildLoad(_builder, LLVM.BuildLoad(_builder, LLVM.BuildPointerCast(_builder, basePtrPtr, LLVM.PointerType(LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type(), 0), 0), 0), "castBasePtrPtr"), "basePtr"), "base");
staticBase = WebAssemblyObjectWriter.EmitGlobal(Module, field, _compilation.NameMangler);
fieldOffset = 0;
}
else
{
node = _compilation.NodeFactory.TypeNonGCStaticsSymbol(owningType);
staticBase = LoadAddressOfSymbolNode(node);
}
// Run static constructor if necessary
if (needsCctorCheck)
{
TriggerCctor(owningType);
}
}
}

_dependencies.Add(node);

// Run static constructor if necessary
// If the type is non-BeforeFieldInit, this is handled before calling any methods on it
if (owningType.IsBeforeFieldInit || (!owningType.IsBeforeFieldInit && owningType != _thisType))
{
TriggerCctor(owningType);
}

LLVMValueRef castStaticBase = LLVM.BuildPointerCast(_builder, staticBase, LLVM.PointerType(LLVM.Int8Type(), 0), owningType.Name + "_statics");
LLVMValueRef fieldAddr = LLVM.BuildGEP(_builder, castStaticBase, new LLVMValueRef[] { BuildConstInt32(fieldOffset) }, field.Name + "_addr");

Expand All @@ -2944,7 +2952,31 @@ private LLVMValueRef GetFieldAddress(FieldDesc field, bool isStatic)
/// </summary>
private void TriggerCctor(MetadataType type)
{
if (_compilation.TypeSystemContext.HasLazyStaticConstructor(type))
ISymbolNode classConstructionContextSymbol = _compilation.NodeFactory.TypeNonGCStaticsSymbol(type);
_dependencies.Add(classConstructionContextSymbol);
LLVMValueRef firstNonGcStatic = LoadAddressOfSymbolNode(classConstructionContextSymbol);

// TODO: Codegen could check whether it has already run rather than calling into EnsureClassConstructorRun
// but we'd have to figure out how to manage the additional basic blocks
LLVMValueRef classConstructionContextPtr = LLVM.BuildGEP(_builder, firstNonGcStatic, new LLVMValueRef[] { BuildConstInt32(-2) }, "classConstructionContext");
StackEntry classConstructionContext = new AddressExpressionEntry(StackValueKind.NativeInt, "classConstructionContext", classConstructionContextPtr, GetWellKnownType(WellKnownType.IntPtr));
CallRuntime("System.Runtime.CompilerServices", _compilation.TypeSystemContext, ClassConstructorRunner, "EnsureClassConstructorRun", new StackEntry[] { classConstructionContext });
}

/// <summary>
/// Triggers creation of thread static storage and the static constructor if present
/// </summary>
private ISymbolNode TriggerCctorWithThreadStaticStorage(MetadataType type, bool needsCctorCheck, out ExpressionEntry returnExp)
{
ISymbolNode threadStaticIndexSymbol = _compilation.NodeFactory.TypeThreadStaticIndex(type);
LLVMValueRef threadStaticIndex = LoadAddressOfSymbolNode(threadStaticIndexSymbol);

StackEntry typeManagerSlotEntry = new LoadExpressionEntry(StackValueKind.ValueType, "typeManagerSlot", threadStaticIndex, GetWellKnownType(WellKnownType.Int32));
LLVMValueRef typeTlsIndexPtr =
LLVM.BuildGEP(_builder, threadStaticIndex, new LLVMValueRef[] { BuildConstInt32(1) }, "typeTlsIndexPtr"); // index is the second field after the ptr.
StackEntry tlsIndexExpressionEntry = new LoadExpressionEntry(StackValueKind.ValueType, "typeTlsIndex", typeTlsIndexPtr, GetWellKnownType(WellKnownType.Int32));

if (needsCctorCheck)
{
ISymbolNode classConstructionContextSymbol = _compilation.NodeFactory.TypeNonGCStaticsSymbol(type);
_dependencies.Add(classConstructionContextSymbol);
Expand All @@ -2953,10 +2985,25 @@ private void TriggerCctor(MetadataType type)
// TODO: Codegen could check whether it has already run rather than calling into EnsureClassConstructorRun
// but we'd have to figure out how to manage the additional basic blocks
LLVMValueRef classConstructionContextPtr = LLVM.BuildGEP(_builder, firstNonGcStatic, new LLVMValueRef[] { BuildConstInt32(-2) }, "classConstructionContext");
StackEntry classConstructionContext = new AddressExpressionEntry(StackValueKind.NativeInt, "classConstructionContext", classConstructionContextPtr, GetWellKnownType(WellKnownType.IntPtr));
MetadataType helperType = _compilation.TypeSystemContext.SystemModule.GetKnownType("System.Runtime.CompilerServices", "ClassConstructorRunner");
MethodDesc helperMethod = helperType.GetKnownMethod("EnsureClassConstructorRun", null);
HandleCall(helperMethod, helperMethod.Signature, new StackEntry[] { classConstructionContext });
StackEntry classConstructionContext = new AddressExpressionEntry(StackValueKind.NativeInt, "classConstructionContext", classConstructionContextPtr,
GetWellKnownType(WellKnownType.IntPtr));

returnExp = CallRuntime("System.Runtime.CompilerServices", _compilation.TypeSystemContext, ClassConstructorRunner, "CheckStaticClassConstructionReturnThreadStaticBase", new StackEntry[]
{
typeManagerSlotEntry,
tlsIndexExpressionEntry,
classConstructionContext
});
return threadStaticIndexSymbol;
}
else
{
returnExp = CallRuntime("Internal.Runtime", _compilation.TypeSystemContext, ThreadStatics, "GetThreadStaticBaseForType", new StackEntry[]
{
typeManagerSlotEntry,
tlsIndexExpressionEntry
});
return threadStaticIndexSymbol;
}
}

Expand Down Expand Up @@ -3205,12 +3252,19 @@ private void ImportFallthrough(BasicBlock next)
private const string InternalCalls = "InternalCalls";
private const string TypeCast = "TypeCast";
private const string DispatchResolve = "DispatchResolve";
private const string ThreadStatics = "ThreadStatics";
private const string ClassConstructorRunner = "ClassConstructorRunner";

private ExpressionEntry CallRuntime(TypeSystemContext context, string className, string methodName, StackEntry[] arguments, TypeDesc forcedReturnType = null)
{
MetadataType helperType = context.SystemModule.GetKnownType("System.Runtime", className);
return CallRuntime("System.Runtime", context, className, methodName, arguments, forcedReturnType);
}

private ExpressionEntry CallRuntime(string @namespace, TypeSystemContext context, string className, string methodName, StackEntry[] arguments, TypeDesc forcedReturnType = null)
{
MetadataType helperType = context.SystemModule.GetKnownType(@namespace, className);
MethodDesc helperMethod = helperType.GetKnownMethod(methodName, null);
if((helperMethod.IsInternalCall && helperMethod.HasCustomAttribute("System.Runtime", "RuntimeImportAttribute")))
if ((helperMethod.IsInternalCall && helperMethod.HasCustomAttribute("System.Runtime", "RuntimeImportAttribute")))
return ImportRawPInvoke(helperMethod, arguments, forcedReturnType: forcedReturnType);
else
return HandleCall(helperMethod, helperMethod.Signature, arguments, forcedReturnType: forcedReturnType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace System.Runtime.CompilerServices
{
//=========================================================================================================
// This is the non-portable part of ClassConstructorRunner. It lives in a seaparate .cs file to make
// This is the non-portable part of ClassConstructorRunner. It lives in a separate .cs file to make
// it easier to include the main ClassConstructorRunner source into a desktop project for testing.
//=========================================================================================================

Expand Down
116 changes: 116 additions & 0 deletions tests/src/Simple/HelloWasm/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Threading;
using System.Runtime.InteropServices;
using System.Collections.Generic;
#if PLATFORM_WINDOWS
Expand Down Expand Up @@ -314,6 +315,8 @@ private static unsafe int Main(string[] args)

TestNativeCallback();

TestThreadStaticsForSingleThread();

// This test should remain last to get other results before stopping the debugger
PrintLine("Debugger.Break() test: Ok if debugger is open and breaks.");
System.Diagnostics.Debugger.Break();
Expand Down Expand Up @@ -727,6 +730,71 @@ private static void _CallMe(int x)
[System.Runtime.InteropServices.DllImport("*")]
private static extern void CallMe(int x);

private static void TestThreadStaticsForSingleThread()
{
var firstClass = new ClassWithFourThreadStatics();
int firstClassStatic = firstClass.GetStatic();
PrintString("Static should be initialised: ");
if (firstClassStatic == 2)
{
PrintLine("Ok.");
}
else
{
PrintLine("Failed.");
PrintLine("Was: " + firstClassStatic.ToString());
}
PrintString("Second class with same statics should be initialised: ");
int secondClassStatic = new AnotherClassWithFourThreadStatics().GetStatic();
if (secondClassStatic == 13)
{
PrintLine("Ok.");
}
else
{
PrintLine("Failed.");
PrintLine("Was: " + secondClassStatic.ToString());
}

PrintString("First class increment statics: ");
firstClass.IncrementStatics();
firstClassStatic = firstClass.GetStatic();
if (firstClassStatic == 3)
{
PrintLine("Ok.");
}
else
{
PrintLine("Failed.");
PrintLine("Was: " + firstClassStatic.ToString());
}

PrintString("Second class should not be overwritten: "); // catches a type of bug where beacuse the 2 types share the same number and types of ThreadStatics, the first class can end up overwriting the second
secondClassStatic = new AnotherClassWithFourThreadStatics().GetStatic();
if (secondClassStatic == 13)
{
PrintLine("Ok.");
}
else
{
PrintLine("Failed.");
PrintLine("Was: " + secondClassStatic.ToString());
}

PrintString("First class 2nd instance should share static: ");
int secondInstanceOfFirstClassStatic = new ClassWithFourThreadStatics().GetStatic();
if (secondInstanceOfFirstClassStatic == 3)
{
PrintLine("Ok.");
}
else
{
PrintLine("Failed.");
PrintLine("Was: " + secondInstanceOfFirstClassStatic.ToString());
}
Thread.Sleep(10);
}

[DllImport("*")]
private static unsafe extern int printf(byte* str, byte* unused);
}
Expand Down Expand Up @@ -991,6 +1059,54 @@ interface ISomeItf
int GetValue();
}

class ClassWithFourThreadStatics
{
[ThreadStatic] static int classStatic;
[ThreadStatic] static int classStatic2 = 2;
[ThreadStatic] static int classStatic3;
[ThreadStatic] static int classStatic4;
[ThreadStatic] static int classStatic5;

public int GetStatic()
{
return classStatic2;
}

public void IncrementStatics()
{
classStatic++;
classStatic2++;
classStatic3++;
classStatic4++;
classStatic5++;
}
}

class AnotherClassWithFourThreadStatics
{
[ThreadStatic] static int classStatic = 13;
[ThreadStatic] static int classStatic2;
[ThreadStatic] static int classStatic3;
[ThreadStatic] static int classStatic4;
[ThreadStatic] static int classStatic5;

public int GetStatic()
{
return classStatic;
}

/// <summary>
/// stops field unused compiler error, but never called
/// </summary>
public void IncrementStatics()
{
classStatic2++;
classStatic3++;
classStatic4++;
classStatic5++;
}
}

namespace System.Runtime.InteropServices
{
/// <summary>
Expand Down

0 comments on commit ef7c9e6

Please sign in to comment.