Skip to content

Commit

Permalink
more and more RAII
Browse files Browse the repository at this point in the history
  • Loading branch information
guerro323 committed Apr 7, 2022
1 parent ebb1b7a commit f942f24
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 30 deletions.
40 changes: 28 additions & 12 deletions revghost/Utility/NativeAllocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ public struct Data
/// <typeparam name="T">The type of the managed object</typeparam>
/// <returns>Return the managed object on native memory</returns>
/// <remarks><see cref="additionalSize"/> can be used for strings (<see cref="NativeAllocatorExtensions.AllocString"/>)</remarks>
public readonly T AllocZeroed<T>(int additionalSize = 0)
public readonly T NewZeroed<T>(int additionalSize = 0)
where T : class
{
var size = (nuint) (((int*) typeof(T).TypeHandle.Value)![1] + additionalSize);
var memory = (byte*) Context->Alloc(ref *Context, ContextManagedObject, size);
Expand All @@ -48,6 +49,7 @@ public readonly T AllocZeroed<T>(int additionalSize = 0)
/// <returns>Return the managed object on native memory</returns>
/// <remarks><see cref="additionalSize"/> can be used for strings (<see cref="NativeAllocatorExtensions.AllocString"/>)</remarks>
public readonly T New<T>(int additionalSize = 0)
where T : class
{
var size = (nuint) (((int*) typeof(T).TypeHandle.Value)![1] + additionalSize);
var memory = (byte*) Context->Alloc(ref *Context, ContextManagedObject, size);
Expand All @@ -74,19 +76,12 @@ public readonly ref byte GetObjectBaseMemory(object obj)
/// <typeparam name="T">The type of the managed object</typeparam>
/// <returns>Whether or not it was successfully freed (if you don't use a tracking allocator the result will always be true)</returns>
public readonly bool Free<T>(ref T obj)
where T : class
{
if (obj == null)
return false;

ref var memory = ref Unsafe.NullRef<byte>();
if (typeof(T).IsValueType)
{
memory = Unsafe.As<T, byte>(ref obj);
}
else
{
memory = ref GetObjectBaseMemory(obj);
}
ref var memory = ref GetObjectBaseMemory(obj);

obj = default;
return Context->Free(ref *Context, ContextManagedObject, Unsafe.AsPointer(ref memory));
Expand Down Expand Up @@ -151,7 +146,7 @@ private static bool DefaultFree(ref Data data, object _, void* memory)
var hashset = Unsafe.As<object, HashSet<IntPtr>>(ref companion);
var memory = NativeMemory.Alloc(size);
hashset.Add((IntPtr) memory);

return memory;
}

Expand All @@ -174,6 +169,7 @@ private static void TrackingDispose(ref Data data, object companion)
{
NativeMemory.Free((void*) ptr);
}
hashset.Clear();
}
}

Expand Down Expand Up @@ -215,9 +211,29 @@ public static string Join(this in NativeAllocator allocator, ReadOnlySpan<char>
return final;
}

public static string Format<T>(this in NativeAllocator allocator, T value, ReadOnlySpan<char> format = default, IFormatProvider? provider = null, int size = 128)
where T : ISpanFormattable
{
var str = NewString(allocator, size);
var span = MemoryMarshal.CreateSpan(ref MemoryMarshal.GetReference(str.AsSpan()), str.Length);

value.TryFormat(span, out var charsWritten, format, provider);

// resize string
var memorySpan = MemoryMarshal.CreateSpan(
ref allocator.GetObjectBaseMemory(str),
// Header + Length + FirstChar
sizeof(long) + sizeof(int) + sizeof(char)
);
MemoryMarshal.Cast<byte, int>(memorySpan)[2] = charsWritten;

return str;
}

public static GuardAllocation<T> GuardAlloc<T>(this in NativeAllocator allocator)
where T : class
{
return new GuardAllocation<T>(allocator, allocator.AllocZeroed<T>());
return new GuardAllocation<T>(allocator, allocator.NewZeroed<T>());
}
}

Expand Down
185 changes: 167 additions & 18 deletions revghost/Utility/RAII.cs
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;

namespace revghost.Utility;

public static unsafe class RAII
{
public enum DebugReferenceOption
{
None,
Simple,
Full
}

public static readonly NativeAllocator Allocator;

public static bool IsTracking = true;
public static bool DebugMemory = false;
public static DebugReferenceOption DebugReference = DebugReferenceOption.None;

private struct Information
{

}

private static readonly Dictionary<IntPtr, Information> Tracked = new();
private static readonly ConcurrentDictionary<IntPtr, Information> Tracked = new();

private static readonly NativeAllocator.Data MethodTable = new()
{
Expand Down Expand Up @@ -48,14 +60,44 @@ public static bool IsRAIIObject(object obj)

public static void Increment(object obj)
{
GetCounter(obj) += 1;
var result = Interlocked.Increment(ref GetCounter(obj));
if (DebugReference != DebugReferenceOption.None)
{
var additional = string.Empty;
if (DebugReference == DebugReferenceOption.Full)
{
var stackTrace = new StackTrace(1, true);
additional = $"\n{stackTrace}";
}

var ptr = Unsafe.As<object, IntPtr>(ref obj);
HostLogger.Output.Info(
$"Increment {IntPtr.Subtract(ptr, sizeof(int))}, count={result}{additional}",
"RAII",
"ref/inc");
}
}

public static bool Decrement(object obj)
{
ref var counter = ref GetCounter(obj);
counter -= 1;
if (counter <= 0)
var result = Interlocked.Decrement(ref GetCounter(obj));
if (DebugReference != DebugReferenceOption.None)
{
var additional = string.Empty;
if (DebugReference == DebugReferenceOption.Full)
{
var stackTrace = new StackTrace(1, true);
additional = $"\n{stackTrace}";
}

var ptr = Unsafe.As<object, IntPtr>(ref obj);
HostLogger.Output.Info(
$"Decrement {IntPtr.Subtract(ptr, sizeof(int))}, count={result}{additional}",
"RAII",
"ref/dec");
}

if (result <= 0)
{
Allocator.Free(ref obj);
return true;
Expand All @@ -68,34 +110,73 @@ public static bool Decrement(object obj)
{
var memory = (byte*) NativeMemory.Alloc(size + sizeof(int));
Unsafe.AsRef<int>(memory) = 0;

if (DebugMemory)
{
HostLogger.Output.Info(
$"Allocated pointer at: {(IntPtr) memory} (size={size})",
"RAII",
"alloc");
}

Console.WriteLine($"alloc at {(IntPtr) memory}");

Tracked.Add((IntPtr) memory, new Information());
if (!Tracked.TryAdd((IntPtr) memory, new Information()))
throw new InvalidOperationException("invalid synchronization");

return memory + sizeof(int);
}

private static bool Free(ref NativeAllocator.Data data, object companion, void* ptr)
{
var managedPtr = (IntPtr) (byte*) ptr - sizeof(int);
Console.WriteLine($"free at {managedPtr}");

if (Tracked.ContainsKey(managedPtr))
{
NativeMemory.Free((byte*) ptr - sizeof(int));
Tracked.Remove(managedPtr);

Tracked.TryRemove(managedPtr, out _);

if (DebugMemory)
{
HostLogger.Output.Info(
$"Freed tracked pointer: {managedPtr}",
"RAII",
"free/success");
}

return true;
}


if (DebugMemory)
{
HostLogger.Output.Warn(
$"Tried to freed an address that wasn't tracked: {managedPtr}",
"RAII",
"free/not-tracked");
}

return false;
}

private static void Dispose(ref NativeAllocator.Data data, object companion)
{

}

public static void ReadTracked<TList>(TList list)
where TList : IList<TrackedInfo>
{
foreach (var (address, info) in Tracked)
{
list.Add(new TrackedInfo
{
Address = address
});
}
}

public struct TrackedInfo
{
public IntPtr Address;
public StackTrace? StackTrace;
}
}

public unsafe struct RefClass<T> : IDisposable
Expand Down Expand Up @@ -125,15 +206,31 @@ private void StructuralChange()
_creationAddr = currentAddr;
}

public void Create()
internal void CoreCreate()
{
StructuralChange();

_object = RAII.Allocator.New<T>();
RAII.Increment(_object);
}

public void Set(T replace)
[StackTraceHidden]
internal void CoreSetNoIncrement(T replace)
{
StructuralChange();

var previous = _object;

_object = replace;

if (RAII.IsRAIIObject(previous))
{
RAII.Decrement(previous);
}
}

[StackTraceHidden]
internal void CoreSet(T replace)
{
StructuralChange();

Expand All @@ -150,7 +247,27 @@ public void Set(T replace)
RAII.Decrement(previous);
}
}

[StackTraceHidden]
internal void CoreSet(RefClass<T> replace)
{
StructuralChange();

var previous = _object;
if (RAII.IsRAIIObject(replace._object))
{
RAII.Increment(replace._object);
}

_object = replace._object;

if (RAII.IsRAIIObject(previous))
{
RAII.Decrement(previous);
}
}

[StackTraceHidden]
public T Get()
{
if (RAII.IsRAIIObject(_object))
Expand All @@ -166,6 +283,7 @@ public T GetUnsafe()
return _object;
}

[StackTraceHidden]
public TRet Act<TArg, TRet>(Func<T, TArg, TRet> func, TArg arg)
{
TRet ret;
Expand All @@ -183,6 +301,7 @@ public TRet Act<TArg, TRet>(Func<T, TArg, TRet> func, TArg arg)
return ret;
}

[StackTraceHidden]
public TRet Act<TRet>(Func<T, TRet> func)
{
TRet ret;
Expand All @@ -200,6 +319,7 @@ public TRet Act<TRet>(Func<T, TRet> func)
return ret;
}

[StackTraceHidden]
public void Act<TArg>(Action<T, TArg> action, TArg arg)
{
if (RAII.IsRAIIObject(_object))
Expand All @@ -214,6 +334,7 @@ public void Act<TArg>(Action<T, TArg> action, TArg arg)
}
}

[StackTraceHidden]
public void Act(Action<T> action)
{
if (RAII.IsRAIIObject(_object))
Expand All @@ -235,12 +356,40 @@ public void Dispose()
RAII.Decrement(_object);
}
}
}

public static implicit operator RefClass<T>(T obj)
public static class RefClass
{
public static RefClass<T> Return<T>(T obj) where T : class
{
var ret = new RefClass<T>();
ret.Set(obj);

ret.CoreSet(obj);
return ret;
}

public static RefClass<T> Argument<T>(T obj) where T : class
{
var ret = new RefClass<T>();
ret.CoreSetNoIncrement(obj);

return ret;
}
}

public static class RefClassExtension
{
[StackTraceHidden]
public static void Set<T>(this in RefClass<T> refClass, T obj) where T : class
{
ref var bypassReadonly = ref Unsafe.AsRef(in refClass);
bypassReadonly.CoreSet(obj);
}

[StackTraceHidden]
public static void Set<T>(this in RefClass<T> refClass, RefClass<T> other) where T : class
{
ref var bypassReadonly = ref Unsafe.AsRef(in refClass);
bypassReadonly.CoreSet(other);
}
}

0 comments on commit f942f24

Please sign in to comment.