forked from Cysharp/MemoryPack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathExtensions.cs
146 lines (130 loc) · 5.03 KB
/
Extensions.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace MemoryPack.Generator;
/// <summary>
/// Roslyn symbol / enumerable helper extensions used by the MemoryPack source generator.
/// </summary>
internal static class Extensions
{
    /// <summary>
    /// Joins <paramref name="source"/> into a single string separated by <see cref="Environment.NewLine"/>.
    /// </summary>
    public static string NewLine(this IEnumerable<string> source)
    {
        return string.Join(Environment.NewLine, source);
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="symbol"/> has an attribute whose class is exactly
    /// <paramref name="attribtue"/> (compared with <see cref="SymbolEqualityComparer.Default"/>).
    /// </summary>
    /// <remarks>The misspelled parameter name is kept so named-argument call sites keep compiling.</remarks>
    public static bool ContainsAttribute(this ISymbol symbol, INamedTypeSymbol attribtue)
    {
        return symbol.GetAttributes().Any(x => SymbolEqualityComparer.Default.Equals(x.AttributeClass, attribtue));
    }

    /// <summary>
    /// Returns the first attribute on <paramref name="symbol"/> whose class is exactly
    /// <paramref name="attribtue"/> (compared with <see cref="SymbolEqualityComparer.Default"/>),
    /// or <c>null</c> when there is none.
    /// </summary>
    public static AttributeData? GetAttribute(this ISymbol symbol, INamedTypeSymbol attribtue)
    {
        return symbol.GetAttributes().FirstOrDefault(x => SymbolEqualityComparer.Default.Equals(x.AttributeClass, attribtue));
    }

    /// <summary>
    /// Returns the first attribute on <paramref name="symbol"/> whose class, or any class it derives from,
    /// matches <paramref name="implAttribtue"/> once both sides are reduced to their unbound generic form
    /// (see <see cref="EqualsUnconstructedGenericType"/>). Returns <c>null</c> when nothing matches.
    /// </summary>
    public static AttributeData? GetImplAttribute(this ISymbol symbol, INamedTypeSymbol implAttribtue)
    {
        return symbol.GetAttributes().FirstOrDefault(x =>
            x.AttributeClass is { } attributeClass
            && (attributeClass.EqualsUnconstructedGenericType(implAttribtue)
                || attributeClass.GetAllBaseTypes().Any(baseType => baseType.EqualsUnconstructedGenericType(implAttribtue))));
    }

    /// <summary>
    /// Enumerates the members of <paramref name="symbol"/> and of all its base types, ordered from the
    /// top-most base type down to <paramref name="symbol"/> itself (parent -> derived).
    /// </summary>
    /// <param name="symbol">The type whose member hierarchy is walked.</param>
    /// <param name="withoutOverride">
    /// When <c>true</c> (default), members marked <c>override</c> are skipped at every level of the
    /// hierarchy, on the premise that the overridden member was already yielded from an ancestor type.
    /// When <c>false</c>, every member of every level is yielded, overrides included.
    /// </param>
    public static IEnumerable<ISymbol> GetAllMembers(this INamedTypeSymbol symbol, bool withoutOverride = true)
    {
        // Iterate Parent -> Derived.
        if (symbol.BaseType is not null)
        {
            // Forward withoutOverride so the flag applies to the whole hierarchy. Previously the recursive
            // call always used the default (true), which dropped base-type overrides even when the caller
            // asked for them. The recursive call already applies the filter, so no second check is needed here.
            foreach (var item in GetAllMembers(symbol.BaseType, withoutOverride))
            {
                yield return item;
            }
        }

        foreach (var item in symbol.GetMembers())
        {
            // An override was already iterated in the parent type.
            if (!withoutOverride || !item.IsOverride)
            {
                yield return item;
            }
        }
    }

    /// <summary>
    /// Reads the <c>MemoryPackable</c> attribute (as resolved by <paramref name="references"/>) on
    /// <paramref name="symbol"/> and decodes its constructor arguments.
    /// </summary>
    /// <param name="generateType">
    /// <c>GenerateType.NoGenerate</c> when the attribute is absent; otherwise the decoded value
    /// (default <c>GenerateType.Object</c>).
    /// </param>
    /// <param name="serializeLayout">
    /// The decoded layout (default <c>SerializeLayout.Sequential</c>). Forced to <c>SerializeLayout.Explicit</c>
    /// for <c>VersionTolerant</c> / <c>CircularReference</c> in the two-argument form.
    /// </param>
    /// <returns>
    /// <c>false</c> when the attribute is absent, or when the type is static or abstract (the union case).
    /// In the latter case the out values are still populated from the attribute. <c>true</c> otherwise.
    /// </returns>
    public static bool TryGetMemoryPackableType(this ITypeSymbol symbol, ReferenceSymbols references, out GenerateType generateType, out SerializeLayout serializeLayout)
    {
        var packableCtorArgs = symbol.GetAttribute(references.MemoryPackableAttribute)?.ConstructorArguments;
        if (packableCtorArgs is null)
        {
            // No [MemoryPackable] attribute: nothing to generate.
            generateType = GenerateType.NoGenerate;
            serializeLayout = SerializeLayout.Sequential;
            return false;
        }

        // Values used by a parameterless [MemoryPackable].
        generateType = GenerateType.Object;
        serializeLayout = SerializeLayout.Sequential;

        // MemoryPackable has two constructors:
        //   (SerializeLayout serializeLayout)
        //   (GenerateType generateType = GenerateType.Object, SerializeLayout serializeLayout = SerializeLayout.Sequential)
        var ctorArgs = packableCtorArgs.Value;
        if (ctorArgs.Length == 1)
        {
            // NOTE(review): assumes the only one-argument constructor takes SerializeLayout. If a
            // MemoryPackable(GenerateType) overload is added, this branch must inspect ctorArgs[0].Type.
            serializeLayout = (SerializeLayout)(ctorArgs[0].Value ?? SerializeLayout.Sequential);
        }
        else if (ctorArgs.Length != 0)
        {
            generateType = (GenerateType)(ctorArgs[0].Value ?? GenerateType.Object);
            serializeLayout = (SerializeLayout)(ctorArgs[1].Value ?? SerializeLayout.Sequential);
            if (generateType is GenerateType.VersionTolerant or GenerateType.CircularReference)
            {
                serializeLayout = SerializeLayout.Explicit; // version-tolerant, always explicit.
            }
        }

        if (symbol.IsStatic || symbol.IsAbstract)
        {
            // A static or abstract class is a Union.
            return false;
        }

        return true;
    }

    /// <summary>
    /// Returns <c>true</c> when <paramref name="symbol"/> is abstract and carries the attribute resolved as
    /// <c>references.MemoryPackUnionAttribute</c>.
    /// </summary>
    public static bool IsWillImplementMemoryPackUnion(this ITypeSymbol symbol, ReferenceSymbols references)
    {
        return symbol.IsAbstract && symbol.ContainsAttribute(references.MemoryPackUnionAttribute);
    }

    /// <summary>
    /// Returns <c>true</c> as soon as any element of <paramref name="source"/> has already been seen.
    /// </summary>
    /// <param name="source">Sequence to scan. It is enumerated at most once and stops at the first duplicate.</param>
    /// <param name="comparer">
    /// Equality comparer to use (e.g. <see cref="SymbolEqualityComparer.Default"/> for Roslyn symbols).
    /// <c>null</c> (the default) uses <see cref="EqualityComparer{T}.Default"/>.
    /// </param>
    public static bool HasDuplicate<T>(this IEnumerable<T> source, IEqualityComparer<T>? comparer = null)
    {
        var set = new HashSet<T>(comparer);
        foreach (var item in source)
        {
            if (!set.Add(item))
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>
    /// Lazily yields the base-type chain of <paramref name="symbol"/>, starting with its immediate base type
    /// and ending with the type whose <c>BaseType</c> is <c>null</c>. <paramref name="symbol"/> itself is not yielded.
    /// </summary>
    public static IEnumerable<INamedTypeSymbol> GetAllBaseTypes(this INamedTypeSymbol symbol)
    {
        var t = symbol.BaseType;
        while (t is not null)
        {
            yield return t;
            t = t.BaseType;
        }
    }

    /// <summary>
    /// Formats <paramref name="symbol"/> with <see cref="SymbolDisplayFormat.FullyQualifiedFormat"/>.
    /// </summary>
    public static string FullyQualifiedToString(this ISymbol symbol)
    {
        return symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
    }

    /// <summary>
    /// Compares two named types after reducing each generic type to its unbound form
    /// (<c>Foo&lt;int&gt;</c> and <c>Foo&lt;string&gt;</c> both become <c>Foo&lt;&gt;</c>). Non-generic
    /// types are compared as-is with <see cref="SymbolEqualityComparer.Default"/>.
    /// </summary>
    public static bool EqualsUnconstructedGenericType(this INamedTypeSymbol left, INamedTypeSymbol right)
    {
        var l = left.IsGenericType ? left.ConstructUnboundGenericType() : left;
        var r = right.IsGenericType ? right.ConstructUnboundGenericType() : right;
        return SymbolEqualityComparer.Default.Equals(l, r);
    }
}