Skip to content

Commit

Permalink
Allow for specify return value on System.Linq.Enumerable.*OrDefault m…
Browse files Browse the repository at this point in the history
…ethods (dotnet#48886)

* Fix dotnet#20064

* Add API to ref assembly

* Make overloads with defaultValue not nullable

* Add unit tests, simplify implementation

* Add LastOrDefault tests

* Add Queryable tests

* Additional tests. Reformatting TryGet Methods.

* Update src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs

* Apply suggestions from code review

Co-authored-by: Eirik Tsarpalis <[email protected]>

* Fix ref methods

* Further adjust nullability

* Fix more nullables

* fix failing tests

* Restore coding style

Co-authored-by: Eirik Tsarpalis <[email protected]>
  • Loading branch information
Foxtrek64 and eiriktsarpalis authored Mar 18, 2021
1 parent 8170c06 commit 122c438
Show file tree
Hide file tree
Showing 14 changed files with 526 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ public static partial class Queryable
public static System.Linq.IQueryable<TSource> Except<TSource>(this System.Linq.IQueryable<TSource> source1, System.Collections.Generic.IEnumerable<TSource> source2) { throw null; }
public static System.Linq.IQueryable<TSource> Except<TSource>(this System.Linq.IQueryable<TSource> source1, System.Collections.Generic.IEnumerable<TSource> source2, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
public static TSource? FirstOrDefault<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
public static TSource FirstOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, TSource defaultValue) { throw null; }
public static TSource? FirstOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate) { throw null; }
public static TSource FirstOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate, TSource defaultValue) { throw null; }
public static TSource First<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
public static TSource First<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate) { throw null; }
public static System.Linq.IQueryable<System.Linq.IGrouping<TKey, TSource>> GroupBy<TSource, TKey>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, TKey>> keySelector) { throw null; }
Expand All @@ -104,7 +106,9 @@ public static partial class Queryable
public static System.Linq.IQueryable<TResult> Join<TOuter, TInner, TKey, TResult>(this System.Linq.IQueryable<TOuter> outer, System.Collections.Generic.IEnumerable<TInner> inner, System.Linq.Expressions.Expression<System.Func<TOuter, TKey>> outerKeySelector, System.Linq.Expressions.Expression<System.Func<TInner, TKey>> innerKeySelector, System.Linq.Expressions.Expression<System.Func<TOuter, TInner, TResult>> resultSelector) { throw null; }
public static System.Linq.IQueryable<TResult> Join<TOuter, TInner, TKey, TResult>(this System.Linq.IQueryable<TOuter> outer, System.Collections.Generic.IEnumerable<TInner> inner, System.Linq.Expressions.Expression<System.Func<TOuter, TKey>> outerKeySelector, System.Linq.Expressions.Expression<System.Func<TInner, TKey>> innerKeySelector, System.Linq.Expressions.Expression<System.Func<TOuter, TInner, TResult>> resultSelector, System.Collections.Generic.IEqualityComparer<TKey>? comparer) { throw null; }
public static TSource? LastOrDefault<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
public static TSource LastOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, TSource defaultValue) { throw null; }
public static TSource? LastOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate) { throw null; }
public static TSource LastOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate, TSource defaultValue) { throw null; }
public static TSource Last<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
public static TSource Last<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate) { throw null; }
public static long LongCount<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
Expand All @@ -129,7 +133,9 @@ public static partial class Queryable
public static bool SequenceEqual<TSource>(this System.Linq.IQueryable<TSource> source1, System.Collections.Generic.IEnumerable<TSource> source2) { throw null; }
public static bool SequenceEqual<TSource>(this System.Linq.IQueryable<TSource> source1, System.Collections.Generic.IEnumerable<TSource> source2, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
public static TSource? SingleOrDefault<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
public static TSource SingleOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, TSource defaultValue) { throw null; }
public static TSource? SingleOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate) { throw null; }
public static TSource SingleOrDefault<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate, TSource defaultValue) { throw null; }
public static TSource Single<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
public static TSource Single<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate) { throw null; }
public static System.Linq.IQueryable<TSource> SkipLast<TSource>(this System.Linq.IQueryable<TSource> source, int count) { throw null; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,20 @@ public static MethodInfo FirstOrDefault_TSource_2(Type TSource) =>
(s_FirstOrDefault_TSource_2 ??= new Func<IQueryable<object>, Expression<Func<object, bool>>, object?>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition())
.MakeGenericMethod(TSource);

private static MethodInfo? s_FirstOrDefault_TSource_3;

public static MethodInfo FirstOrDefault_TSource_3(Type TSource) =>
(s_FirstOrDefault_TSource_3 ??
(s_FirstOrDefault_TSource_3 = new Func<IQueryable<object>, object, object>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition()))
.MakeGenericMethod(TSource);

private static MethodInfo? s_FirstOrDefault_TSource_4;

public static MethodInfo FirstOrDefault_TSource_4(Type TSource) =>
(s_FirstOrDefault_TSource_4 ??
(s_FirstOrDefault_TSource_4 = new Func<IQueryable<object>, Expression<Func<object, bool>>, object, object>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition()))
.MakeGenericMethod(TSource);

private static MethodInfo? s_GroupBy_TSource_TKey_2;

public static MethodInfo GroupBy_TSource_TKey_2(Type TSource, Type TKey) =>
Expand Down Expand Up @@ -392,6 +406,20 @@ public static MethodInfo LastOrDefault_TSource_2(Type TSource) =>
(s_LastOrDefault_TSource_2 ??= new Func<IQueryable<object>, Expression<Func<object, bool>>, object?>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition())
.MakeGenericMethod(TSource);

private static MethodInfo? s_LastOrDefault_TSource_3;

public static MethodInfo LastOrDefault_TSource_3(Type TSource) =>
(s_LastOrDefault_TSource_3 ??
(s_LastOrDefault_TSource_3 = new Func<IQueryable<object>, object, object>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition()))
.MakeGenericMethod(TSource);

private static MethodInfo? s_LastOrDefault_TSource_4;

public static MethodInfo LastOrDefault_TSource_4(Type TSource) =>
(s_LastOrDefault_TSource_4 ??
(s_LastOrDefault_TSource_4 = new Func<IQueryable<object>, Expression<Func<object, bool>>, object, object>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition()))
.MakeGenericMethod(TSource);

private static MethodInfo? s_LongCount_TSource_1;

public static MethodInfo LongCount_TSource_1(Type TSource) =>
Expand Down Expand Up @@ -536,6 +564,20 @@ public static MethodInfo SingleOrDefault_TSource_2(Type TSource) =>
(s_SingleOrDefault_TSource_2 ??= new Func<IQueryable<object>, Expression<Func<object, bool>>, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())
.MakeGenericMethod(TSource);

private static MethodInfo? s_SingleOrDefault_TSource_3;

public static MethodInfo SingleOrDefault_TSource_3(Type TSource) =>
(s_SingleOrDefault_TSource_3 ??
(s_SingleOrDefault_TSource_3 = new Func<IQueryable<object>, object, object>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition()))
.MakeGenericMethod(TSource);

private static MethodInfo? s_SingleOrDefault_TSource_4;

public static MethodInfo SingleOrDefault_TSource_4(Type TSource) =>
(s_SingleOrDefault_TSource_4 ??
(s_SingleOrDefault_TSource_4 = new Func<IQueryable<object>, Expression<Func<object, bool>>, object, object>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition()))
.MakeGenericMethod(TSource);

private static MethodInfo? s_Skip_TSource_2;

public static MethodInfo Skip_TSource_2(Type TSource) =>
Expand Down
82 changes: 82 additions & 0 deletions src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,18 @@ public static TSource First<TSource>(this IQueryable<TSource> source, Expression
CachedReflectionInfo.FirstOrDefault_TSource_1(typeof(TSource)), source.Expression));
}

[DynamicDependency("FirstOrDefault`1", typeof(Enumerable))]
public static TSource FirstOrDefault<TSource>(this IQueryable<TSource> source, TSource defaultValue)
{
if (source == null)
throw Error.ArgumentNull(nameof(source));
return source.Provider.Execute<TSource>(
Expression.Call(
null,
CachedReflectionInfo.FirstOrDefault_TSource_3(typeof(TSource)),
source.Expression, Expression.Constant(defaultValue, typeof(TSource))));
}

[DynamicDependency("FirstOrDefault`1", typeof(Enumerable))]
public static TSource? FirstOrDefault<TSource>(this IQueryable<TSource> source, Expression<Func<TSource, bool>> predicate)
{
Expand All @@ -879,6 +891,21 @@ public static TSource First<TSource>(this IQueryable<TSource> source, Expression
));
}

[DynamicDependency("FirstOrDefault`1", typeof(Enumerable))]
public static TSource FirstOrDefault<TSource>(this IQueryable<TSource> source, Expression<Func<TSource, bool>> predicate, TSource defaultValue)
{
if (source == null)
throw Error.ArgumentNull(nameof(source));
if (predicate == null)
throw Error.ArgumentNull(nameof(predicate));
return source.Provider.Execute<TSource>(
Expression.Call(
null,
CachedReflectionInfo.FirstOrDefault_TSource_4(typeof(TSource)),
source.Expression, Expression.Quote(predicate), Expression.Constant(defaultValue, typeof(TSource))
));
}

[DynamicDependency("Last`1", typeof(Enumerable))]
public static TSource Last<TSource>(this IQueryable<TSource> source)
{
Expand Down Expand Up @@ -916,6 +943,18 @@ public static TSource Last<TSource>(this IQueryable<TSource> source, Expression<
CachedReflectionInfo.LastOrDefault_TSource_1(typeof(TSource)), source.Expression));
}

[DynamicDependency("LastOrDefault`1", typeof(Enumerable))]
public static TSource LastOrDefault<TSource>(this IQueryable<TSource> source, TSource defaultValue)
{
if (source == null)
throw Error.ArgumentNull(nameof(source));
return source.Provider.Execute<TSource>(
Expression.Call(
null,
CachedReflectionInfo.LastOrDefault_TSource_3(typeof(TSource)),
source.Expression, Expression.Constant(defaultValue, typeof(TSource))));
}

[DynamicDependency("LastOrDefault`1", typeof(Enumerable))]
public static TSource? LastOrDefault<TSource>(this IQueryable<TSource> source, Expression<Func<TSource, bool>> predicate)
{
Expand All @@ -931,6 +970,21 @@ public static TSource Last<TSource>(this IQueryable<TSource> source, Expression<
));
}

[DynamicDependency("LastOrDefault`1", typeof(Enumerable))]
public static TSource LastOrDefault<TSource>(this IQueryable<TSource> source, Expression<Func<TSource, bool>> predicate, TSource defaultValue)
{
if (source == null)
throw Error.ArgumentNull(nameof(source));
if (predicate == null)
throw Error.ArgumentNull(nameof(predicate));
return source.Provider.Execute<TSource>(
Expression.Call(
null,
CachedReflectionInfo.LastOrDefault_TSource_4(typeof(TSource)),
source.Expression, Expression.Quote(predicate), Expression.Constant(defaultValue, typeof(TSource))
));
}

[DynamicDependency("Single`1", typeof(Enumerable))]
public static TSource Single<TSource>(this IQueryable<TSource> source)
{
Expand Down Expand Up @@ -968,6 +1022,19 @@ public static TSource Single<TSource>(this IQueryable<TSource> source, Expressio
CachedReflectionInfo.SingleOrDefault_TSource_1(typeof(TSource)), source.Expression));
}

[DynamicDependency("SingleOrDefault`1", typeof(Enumerable))]
public static TSource SingleOrDefault<TSource>(this IQueryable<TSource> source, TSource defaultValue)
{
if (source == null)
throw Error.ArgumentNull(nameof(source));
return source.Provider.Execute<TSource>(
Expression.Call(
null,
CachedReflectionInfo.SingleOrDefault_TSource_3(typeof(TSource)),
source.Expression, Expression.Constant(defaultValue, typeof(TSource))));

}

[DynamicDependency("SingleOrDefault`1", typeof(Enumerable))]
public static TSource? SingleOrDefault<TSource>(this IQueryable<TSource> source, Expression<Func<TSource, bool>> predicate)
{
Expand All @@ -983,6 +1050,21 @@ public static TSource Single<TSource>(this IQueryable<TSource> source, Expressio
));
}

[DynamicDependency("SingleOrDefault`1", typeof(Enumerable))]
public static TSource SingleOrDefault<TSource>(this IQueryable<TSource> source, Expression<Func<TSource, bool>> predicate, TSource defaultValue)
{
if (source == null)
throw Error.ArgumentNull(nameof(source));
if (predicate == null)
throw Error.ArgumentNull(nameof(predicate));
return source.Provider.Execute<TSource>(
Expression.Call(
null,
CachedReflectionInfo.SingleOrDefault_TSource_4(typeof(TSource)),
source.Expression, Expression.Quote(predicate), Expression.Constant(defaultValue, typeof(TSource))
));
}

[DynamicDependency("ElementAt`1", typeof(Enumerable))]
public static TSource ElementAt<TSource>(this IQueryable<TSource> source, int index)
{
Expand Down
46 changes: 46 additions & 0 deletions src/libraries/System.Linq.Queryable/tests/FirstOrDefaultTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ public void Empty()
Assert.Equal(0, source.AsQueryable().FirstOrDefault());
}

[Fact]
public void EmptyDefault()
{
int[] source = { };
int defaultValue = 5;

Assert.Equal(defaultValue, source.AsQueryable().FirstOrDefault(defaultValue));
}

[Fact]
public void ManyElementsFirstIsDefault()
{
Expand All @@ -37,37 +46,74 @@ public void OneElementTruePredicate()
Assert.Equal(4, source.AsQueryable().FirstOrDefault(i => i % 2 == 0));
}

[Fact]
public void OneElementTruePredicateDefault()
{
int[] source = { 4 };
Assert.Equal(4, source.AsQueryable().FirstOrDefault(i => i % 2 == 0, 5));
}

[Fact]
public void OneElementFalsePredicate()
{
int[] source = { 3 };
Assert.Equal(0, source.AsQueryable().FirstOrDefault(i => i % 2 == 0));
}

[Fact]
public void OneElementFalsePredicateDefault()
{
int[] source = { 3 };
Assert.Equal(5, source.AsQueryable().FirstOrDefault(i => i % 2 == 0, 5));
}

[Fact]
public void ManyElementsPredicateFalseForAll()
{
int[] source = { 9, 5, 1, 3, 17, 21 };
Assert.Equal(0, source.AsQueryable().FirstOrDefault(i => i % 2 == 0));
}

[Fact]
public void ManyElementsPredicateFalseForAllDefault()
{
int[] source = { 9, 5, 1, 3, 17, 21 };
Assert.Equal(2, source.AsQueryable().FirstOrDefault(i => i % 2 == 0, 2));
}

[Fact]
public void PredicateTrueForSome()
{
int[] source = { 3, 7, 10, 7, 9, 2, 11, 17, 13, 8 };
Assert.Equal(10, source.AsQueryable().FirstOrDefault(i => i % 2 == 0));
}
[Fact]
public void PredicateTrueForSomeDefault()
{
int[] source = { 3, 7, 10, 7, 9, 2, 11, 17, 13, 8 };
Assert.Equal(10, source.AsQueryable().FirstOrDefault(i => i % 2 == 0, 5));
}

[Fact]
public void NullSource()
{
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IQueryable<int>)null).FirstOrDefault());
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IQueryable<int>)null).FirstOrDefault(5));
}

[Fact]
public void NullSourcePredicateUsed()
{
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IQueryable<int>)null).FirstOrDefault(i => i != 2));
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IQueryable<int>)null).FirstOrDefault(i => i != 2, 5));
}

[Fact]
public void NullPredicate()
{
Expression<Func<int, bool>> predicate = null;
AssertExtensions.Throws<ArgumentNullException>("predicate", () => Enumerable.Range(0, 3).AsQueryable().FirstOrDefault(predicate));
AssertExtensions.Throws<ArgumentNullException>("predicate", () => Enumerable.Range(0, 3).AsQueryable().FirstOrDefault(predicate, 5));
}

[Fact]
Expand Down
15 changes: 15 additions & 0 deletions src/libraries/System.Linq.Queryable/tests/LastOrDefaultTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,28 @@ public void Empty()
Assert.Null(Enumerable.Empty<int?>().AsQueryable().LastOrDefault());
}

[Fact]
public void EmptyDefault()
{
int[] source = { };
int defaultValue = 5;
Assert.Equal(defaultValue, source.AsQueryable().LastOrDefault(defaultValue));
}

[Fact]
public void OneElement()
{
int[] source = { 5 };
Assert.Equal(5, source.AsQueryable().LastOrDefault());
}

[Fact]
public void OneElementFalsePredicate()
{
int[] source = { 3 };
Assert.Equal(5, source.AsQueryable().LastOrDefault(i => i % 2 == 0, 5));
}

[Fact]
public void ManyElementsLastIsDefault()
{
Expand Down
14 changes: 14 additions & 0 deletions src/libraries/System.Linq.Queryable/tests/SingleOrDefaultTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,26 @@ public void Empty()
Assert.Null(Enumerable.Empty<int?>().AsQueryable().SingleOrDefault());
}

[Fact]
public void EmptyDefault()
{
int[] source = { };
int defaultValue = 5;
Assert.Equal(defaultValue, source.AsQueryable().SingleOrDefault(5));
}

[Fact]
public void EmptySourceWithPredicate()
{
Assert.Null(Enumerable.Empty<int?>().AsQueryable().SingleOrDefault(i => i % 2 == 0));
}

[Fact]
public void EmptySourceWithPredicateDefault()
{
Assert.Equal(5, Enumerable.Empty<int?>().AsQueryable().SingleOrDefault(i => i % 2 == 0, 5));
}

[Theory]
[InlineData(1, 100)]
[InlineData(42, 100)]
Expand Down
Loading

0 comments on commit 122c438

Please sign in to comment.