Skip to content

Commit

Permalink
Nullable annotate System.Linq.Queryable (dotnet#979)
Browse files Browse the repository at this point in the history
* Nullable annotate System.Linq.Queryable

* update annotations based on System.Linq.Expressions annotation

* address feedback

* revert System.Linq.Expressions reference source

* make IQueryable LINQ methods compatible with annotated IEnumerable signatures

* remove unnecessary nullable annotation

* address feedback
  • Loading branch information
eiriktsarpalis authored Jan 15, 2020
1 parent c8ca0f3 commit d4d38b6
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 211 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ public partial interface IQueryProvider
{
System.Linq.IQueryable CreateQuery(System.Linq.Expressions.Expression expression);
System.Linq.IQueryable<TElement> CreateQuery<TElement>(System.Linq.Expressions.Expression expression);
object Execute(System.Linq.Expressions.Expression expression);
object? Execute(System.Linq.Expressions.Expression expression);
TResult Execute<TResult>(System.Linq.Expressions.Expression expression);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public interface IQueryProvider
/// <remarks>
/// The <see cref="Execute"/> method executes queries that return a single value (instead of an enumerable sequence of values). Expression trees that represent queries that return enumerable results are executed when their associated <see cref="IQueryable"/> object is enumerated.
/// </remarks>
object Execute(Expression expression);
object? Execute(Expression expression);

/// <summary>
/// Executes the strongly-typed query represented by a specified expression tree.
Expand Down
49 changes: 30 additions & 19 deletions src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<Configurations>$(NetCoreAppCurrent)-Debug;$(NetCoreAppCurrent)-Release</Configurations>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<Compile Include="System.Linq.Queryable.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
<AssemblyName>System.Linq.Queryable</AssemblyName>
<RootNamespace>System.Linq.Queryable</RootNamespace>
<Configurations>$(NetCoreAppCurrent)-Debug;$(NetCoreAppCurrent)-Release</Configurations>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<Compile Include="System\Linq\CachedReflection.cs" />
Expand All @@ -22,4 +23,4 @@
<Reference Include="System.Resources.ResourceManager" />
<Reference Include="System.Runtime" />
</ItemGroup>
</Project>
</Project>
252 changes: 126 additions & 126 deletions src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ namespace System.Linq
{
public abstract class EnumerableExecutor
{
internal abstract object ExecuteBoxed();
internal abstract object? ExecuteBoxed();

internal EnumerableExecutor() { }

internal static EnumerableExecutor Create(Expression expression)
{
Type execType = typeof(EnumerableExecutor<>).MakeGenericType(expression.Type);
return (EnumerableExecutor)Activator.CreateInstance(execType, expression);
return (EnumerableExecutor)Activator.CreateInstance(execType, expression)!;
}
}

Expand All @@ -29,13 +29,13 @@ public EnumerableExecutor(Expression expression)
_expression = expression;
}

internal override object ExecuteBoxed() => Execute();
internal override object? ExecuteBoxed() => Execute();

internal T Execute()
{
EnumerableRewriter rewriter = new EnumerableRewriter();
Expression body = rewriter.Visit(_expression);
Expression<Func<T>> f = Expression.Lambda<Func<T>>(body, (IEnumerable<ParameterExpression>)null);
Expression<Func<T>> f = Expression.Lambda<Func<T>>(body, (IEnumerable<ParameterExpression>?)null);
Func<T> func = f.Compile();
return func();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,27 @@ namespace System.Linq
public abstract class EnumerableQuery
{
internal abstract Expression Expression { get; }
internal abstract IEnumerable Enumerable { get; }
internal abstract IEnumerable? Enumerable { get; }

internal EnumerableQuery() { }

internal static IQueryable Create(Type elementType, IEnumerable sequence)
{
Type seqType = typeof(EnumerableQuery<>).MakeGenericType(elementType);
return (IQueryable)Activator.CreateInstance(seqType, sequence);
return (IQueryable)Activator.CreateInstance(seqType, sequence)!;
}

internal static IQueryable Create(Type elementType, Expression expression)
{
Type seqType = typeof(EnumerableQuery<>).MakeGenericType(elementType);
return (IQueryable)Activator.CreateInstance(seqType, expression);
return (IQueryable)Activator.CreateInstance(seqType, expression)!;
}
}

public class EnumerableQuery<T> : EnumerableQuery, IOrderedQueryable<T>, IQueryProvider
{
private readonly Expression _expression;
private IEnumerable<T> _enumerable;
private IEnumerable<T>? _enumerable;

IQueryProvider IQueryable.Provider => this;

Expand All @@ -48,7 +48,7 @@ public EnumerableQuery(Expression expression)

internal override Expression Expression => _expression;

internal override IEnumerable Enumerable => _enumerable;
internal override IEnumerable? Enumerable => _enumerable;

Expression IQueryable.Expression => _expression;

Expand All @@ -58,7 +58,7 @@ IQueryable IQueryProvider.CreateQuery(Expression expression)
{
if (expression == null)
throw Error.ArgumentNull(nameof(expression));
Type iqType = TypeHelper.FindGenericType(typeof(IQueryable<>), expression.Type);
Type? iqType = TypeHelper.FindGenericType(typeof(IQueryable<>), expression.Type);
if (iqType == null)
throw Error.ArgumentNotValid(nameof(expression));
return Create(iqType.GetGenericArguments()[0], expression);
Expand All @@ -75,7 +75,7 @@ IQueryable<TElement> IQueryProvider.CreateQuery<TElement>(Expression expression)
return new EnumerableQuery<TElement>(expression);
}

object IQueryProvider.Execute(Expression expression)
object? IQueryProvider.Execute(Expression expression)
{
if (expression == null)
throw Error.ArgumentNull(nameof(expression));
Expand All @@ -101,7 +101,7 @@ private IEnumerator<T> GetEnumerator()
{
EnumerableRewriter rewriter = new EnumerableRewriter();
Expression body = rewriter.Visit(_expression);
Expression<Func<IEnumerable<T>>> f = Expression.Lambda<Func<IEnumerable<T>>>(body, (IEnumerable<ParameterExpression>)null);
Expression<Func<IEnumerable<T>>> f = Expression.Lambda<Func<IEnumerable<T>>>(body, (IEnumerable<ParameterExpression>?)null);
IEnumerable<T> enumerable = f.Compile()();
if (enumerable == this)
throw Error.EnumeratingNullEnumerableExpression();
Expand All @@ -110,10 +110,9 @@ private IEnumerator<T> GetEnumerator()
return _enumerable.GetEnumerator();
}

public override string ToString()
public override string? ToString()
{
ConstantExpression c = _expression as ConstantExpression;
if (c != null && c.Value == this)
if (_expression is ConstantExpression c && c.Value == this)
{
if (_enumerable != null)
return _enumerable.ToString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
using System.Reflection;

Expand All @@ -15,22 +16,22 @@ internal class EnumerableRewriter : ExpressionVisitor
{
// We must ensure that if a LabelTarget is rewritten that it is always rewritten to the same new target
// or otherwise expressions using it won't match correctly.
private Dictionary<LabelTarget, LabelTarget> _targetCache;
private Dictionary<LabelTarget, LabelTarget>? _targetCache;
// Finding equivalent types can be relatively expensive, and hitting with the same types repeatedly is quite likely.
private Dictionary<Type, Type> _equivalentTypeCache;
private Dictionary<Type, Type>? _equivalentTypeCache;

protected override Expression VisitMethodCall(MethodCallExpression m)
{
Expression obj = Visit(m.Object);
Expression? obj = Visit(m.Object);
ReadOnlyCollection<Expression> args = Visit(m.Arguments);

// check for args changed
if (obj != m.Object || args != m.Arguments)
{
MethodInfo mInfo = m.Method;
Type[] typeArgs = (mInfo.IsGenericMethod) ? mInfo.GetGenericArguments() : null;
Type[]? typeArgs = (mInfo.IsGenericMethod) ? mInfo.GetGenericArguments() : null;

if ((mInfo.IsStatic || mInfo.DeclaringType.IsAssignableFrom(obj.Type))
if ((mInfo.IsStatic || mInfo.DeclaringType!.IsAssignableFrom(obj!.Type))
&& ArgsMatch(mInfo, args, typeArgs))
{
// current method is still valid
Expand All @@ -46,7 +47,7 @@ protected override Expression VisitMethodCall(MethodCallExpression m)
else
{
// rebind to new method
MethodInfo method = FindMethod(mInfo.DeclaringType, mInfo.Name, args, typeArgs);
MethodInfo method = FindMethod(mInfo.DeclaringType!, mInfo.Name, args, typeArgs);
args = FixupQuotedArgs(method, args);
return Expression.Call(obj, method, args);
}
Expand All @@ -59,7 +60,7 @@ private ReadOnlyCollection<Expression> FixupQuotedArgs(MethodInfo mi, ReadOnlyCo
ParameterInfo[] pis = mi.GetParameters();
if (pis.Length > 0)
{
List<Expression> newArgs = null;
List<Expression>? newArgs = null;
for (int i = 0, n = pis.Length; i < n; i++)
{
Expression arg = argList[i];
Expand Down Expand Up @@ -98,7 +99,7 @@ private Expression FixupQuotedExpression(Type type, Expression expression)
Type strippedType = StripExpression(expr.Type);
if (type.IsAssignableFrom(strippedType))
{
Type elementType = type.GetElementType();
Type elementType = type.GetElementType()!;
NewArrayExpression na = (NewArrayExpression)expr;
List<Expression> exprs = new List<Expression>(na.Expressions.Count);
for (int i = 0, n = na.Expressions.Count; i < n; i++)
Expand Down Expand Up @@ -136,7 +137,7 @@ private static Type GetPublicType(Type t)

private Type GetEquivalentType(Type type)
{
Type equiv;
Type? equiv;
if (_equivalentTypeCache == null)
{
// Pre-loading with the non-generic IQueryable and IEnumerable not only covers this case
Expand Down Expand Up @@ -170,7 +171,7 @@ private Type GetEquivalentType(Type type)
.Where(i => i.IsGenericType && i.GenericTypeArguments.Length == 1)
.Select(i => new { Info = i, GenType = i.GetGenericTypeDefinition() })
.ToArray();
Type typeArg = singleTypeGenInterfacesWithGetType
Type? typeArg = singleTypeGenInterfacesWithGetType
.Where(i => i.GenType == typeof(IOrderedQueryable<>) || i.GenType == typeof(IOrderedEnumerable<>))
.Select(i => i.Info.GenericTypeArguments[0])
.Distinct()
Expand All @@ -194,8 +195,7 @@ private Type GetEquivalentType(Type type)

protected override Expression VisitConstant(ConstantExpression c)
{
EnumerableQuery sq = c.Value as EnumerableQuery;
if (sq != null)
if (c.Value is EnumerableQuery sq)
{
if (sq.Enumerable != null)
{
Expand All @@ -211,21 +211,21 @@ protected override Expression VisitConstant(ConstantExpression c)



private static ILookup<string, MethodInfo> s_seqMethods;
private static MethodInfo FindEnumerableMethod(string name, ReadOnlyCollection<Expression> args, params Type[] typeArgs)
private static ILookup<string, MethodInfo>? s_seqMethods;
private static MethodInfo FindEnumerableMethod(string name, ReadOnlyCollection<Expression> args, params Type[]? typeArgs)
{
if (s_seqMethods == null)
{
s_seqMethods = typeof(Enumerable).GetStaticMethods().ToLookup(m => m.Name);
}
MethodInfo mi = s_seqMethods[name].FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
MethodInfo? mi = s_seqMethods[name].FirstOrDefault(m => ArgsMatch(m, args, typeArgs));
Debug.Assert(mi != null, "All static methods with arguments on Queryable have equivalents on Enumerable.");
if (typeArgs != null)
return mi.MakeGenericMethod(typeArgs);
return mi;
}

private static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<Expression> args, Type[] typeArgs)
private static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<Expression> args, Type[]? typeArgs)
{
using (IEnumerator<MethodInfo> en = type.GetStaticMethods().Where(m => m.Name == name).GetEnumerator())
{
Expand All @@ -241,7 +241,7 @@ private static MethodInfo FindMethod(Type type, string name, ReadOnlyCollection<
throw Error.NoMethodOnTypeMatchingArguments(name, type);
}

private static bool ArgsMatch(MethodInfo m, ReadOnlyCollection<Expression> args, Type[] typeArgs)
private static bool ArgsMatch(MethodInfo m, ReadOnlyCollection<Expression> args, Type[]? typeArgs)
{
ParameterInfo[] mParams = m.GetParameters();
if (mParams.Length != args.Count)
Expand Down Expand Up @@ -269,7 +269,7 @@ private static bool ArgsMatch(MethodInfo m, ReadOnlyCollection<Expression> args,
if (parameterType == null)
return false;
if (parameterType.IsByRef)
parameterType = parameterType.GetElementType();
parameterType = parameterType.GetElementType()!;
Expression arg = args[i];
if (!parameterType.IsAssignableFrom(arg.Type))
{
Expand All @@ -290,8 +290,8 @@ private static bool ArgsMatch(MethodInfo m, ReadOnlyCollection<Expression> args,
private static Type StripExpression(Type type)
{
bool isArray = type.IsArray;
Type tmp = isArray ? type.GetElementType() : type;
Type eType = TypeHelper.FindGenericType(typeof(Expression<>), tmp);
Type tmp = isArray ? type.GetElementType()! : type;
Type? eType = TypeHelper.FindGenericType(typeof(Expression<>), tmp);
if (eType != null)
tmp = eType.GetGenericArguments()[0];
if (isArray)
Expand Down Expand Up @@ -333,22 +333,22 @@ protected override Expression VisitBlock(BlockExpression node)

protected override Expression VisitGoto(GotoExpression node)
{
Type type = node.Value.Type;
Type type = node.Value!.Type;
if (!typeof(IQueryable).IsAssignableFrom(type))
return base.VisitGoto(node);
LabelTarget target = VisitLabelTarget(node.Target);
Expression value = Visit(node.Value);
return Expression.MakeGoto(node.Kind, target, value, GetEquivalentType(typeof(EnumerableQuery).IsAssignableFrom(type) ? value.Type : type));
}

protected override LabelTarget VisitLabelTarget(LabelTarget node)
protected override LabelTarget VisitLabelTarget(LabelTarget? node)
{
LabelTarget newTarget;
LabelTarget? newTarget;
if (_targetCache == null)
_targetCache = new Dictionary<LabelTarget, LabelTarget>();
else if (_targetCache.TryGetValue(node, out newTarget))
else if (_targetCache.TryGetValue(node!, out newTarget))
return newTarget;
Type type = node.Type;
Type type = node!.Type;
if (!typeof(IQueryable).IsAssignableFrom(type))
newTarget = base.VisitLabelTarget(node);
else
Expand Down
Loading

0 comments on commit d4d38b6

Please sign in to comment.