Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix column mapping and logical condition parsing in VisitBinary, refactor for modularity #110

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
335 changes: 191 additions & 144 deletions Postgrest/Linq/WhereExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,74 +26,213 @@ internal class WhereExpressionVisitor : ExpressionVisitor

/// <summary>
/// An entry point that will be used to populate <see cref="Filter"/>.
/// This method handles comparisons, logical operations, and simple arithmetic expressions in a Where clause.
///
/// Invoked like:
/// `Table&lt;Movies&gt;().Where(x => x.Name == "Top Gun").Get();`
/// Examples:
/// <code>Table&lt;Movies&gt;().Where(x => x.Name == "Top Gun").Get();</code>
/// <code>Table&lt;Movies&gt;().Where(x => x.Rating > 5 &amp;&amp; x.Year == 1986).Get();</code>
/// <code>Table&lt;Movies&gt;().Where(x => x.Rating >= maxRating - 1).Get();</code>
/// </summary>
/// <param name="node"></param>
/// <returns></returns>
/// <exception cref="ArgumentException"></exception>
/// <param name="node">The binary expression to process, such as a comparison (e.g., x.Name == "Top Gun") or logical operation (e.g., x.Rating > 5 && x.Year == 1986).</param>
/// <returns>The processed expression, typically the input <paramref name="node"/>.</returns>
/// <exception cref="ArgumentException">Thrown if the left side of the expression does not correspond to a property with a <see cref="ColumnAttribute"/> or <see cref="PrimaryKeyAttribute"/>.</exception>
/// <exception cref="NotSupportedException">Thrown if the right side of the expression cannot be evaluated to a constant value.</exception>
/// <exception cref="InvalidOperationException">Thrown if the <see cref="Filter"/> is not set after processing the expression.</exception>
protected override Expression VisitBinary(BinaryExpression node)
{
var op = GetMappedOperator(node);

// In the event this is a nested expression (n.Name == "Example" || n.Id = 3)
switch (node.NodeType)
// Handle logical operations (e.g., x.Rating > 5 && x.Year == 1986)
if (IsLogicalOperation(node.NodeType))
{
case ExpressionType.And:
case ExpressionType.Or:
case ExpressionType.AndAlso:
case ExpressionType.OrElse:
var leftVisitor = new WhereExpressionVisitor();
leftVisitor.Visit(node.Left);
var conditions = FlattenLogicalConditions(node, op);
Filter = new QueryFilter(op, conditions);
return node;
}

// Handle simple comparisons (e.g., x.Name == "Top Gun" or x.Rating >= maxRating - 1)
var column = ExtractColumnName(node.Left);
var rightValue = EvaluateRightExpression(node.Right);

// Define the filter for a simple comparison
Filter = new QueryFilter(column, op, rightValue);
return node;
}

var rightVisitor = new WhereExpressionVisitor();
rightVisitor.Visit(node.Right);
/// <summary>
/// Flattens a tree of logical conditions (e.g., AND, OR) into a single list of conditions at the same level.
/// </summary>
/// <param name="node">The binary expression node representing a logical operation.</param>
/// <param name="op">The operator (e.g., AND, OR) for the logical operation.</param>
/// <returns>A list of filters representing all conditions at the same level.</returns>
private List<IPostgrestQueryFilter> FlattenLogicalConditions(BinaryExpression node, Operator op)
{
var conditions = new List<IPostgrestQueryFilter>();

// Recursively flatten the left and right sides
FlattenLogicalConditionsRecursive(node, op, conditions);

Filter = new QueryFilter(op,
new List<IPostgrestQueryFilter> { leftVisitor.Filter!, rightVisitor.Filter! });
return conditions;
}

return node;
/// <summary>
/// Recursively flattens a tree of logical conditions into a list of filters.
/// </summary>
/// <param name="node">The current binary expression node.</param>
/// <param name="op">The operator (e.g., AND, OR) for the logical operation.</param>
/// <param name="conditions">The list to accumulate the flattened conditions.</param>
private void FlattenLogicalConditionsRecursive(BinaryExpression node, Operator op, List<IPostgrestQueryFilter> conditions)
{
// If the node is a logical operation with the same operator, recurse into its children
if (IsLogicalOperation(node.NodeType) && GetMappedOperator(node) == op)
{
if (node.Left is BinaryExpression leftBinary)
{
FlattenLogicalConditionsRecursive(leftBinary, op, conditions);
}
else
{
conditions.Add(ProcessSubExpression(node.Left));
}

if (node.Right is BinaryExpression rightBinary)
{
FlattenLogicalConditionsRecursive(rightBinary, op, conditions);
}
else
{
conditions.Add(ProcessSubExpression(node.Right));
}
}
else
{
// If the node is not a logical operation (or has a different operator), process it as a single condition
conditions.Add(ProcessSubExpression(node));
}
}

// Otherwise, the base case.
/// <summary>
/// Determines if the node type represents a logical operation (AND, OR).
/// </summary>
/// <param name="nodeType">The type of the expression node.</param>
/// <returns>True if the node type is a logical operation; otherwise, false.</returns>
private static bool IsLogicalOperation(ExpressionType nodeType)
{
return nodeType == ExpressionType.And ||
nodeType == ExpressionType.Or ||
nodeType == ExpressionType.AndAlso ||
nodeType == ExpressionType.OrElse;
}

var left = Visit(node.Left);
var right = Visit(node.Right);
/// <summary>
/// Processes a subexpression and returns the resulting filter.
/// </summary>
/// <param name="expression">The subexpression to process.</param>
/// <returns>The filter generated by the subexpression.</returns>
/// <exception cref="InvalidOperationException">Thrown if the subexpression does not produce a valid filter.</exception>
private IPostgrestQueryFilter ProcessSubExpression(Expression expression)
{
var visitor = new WhereExpressionVisitor();
visitor.Visit(expression);
return visitor.Filter ?? throw new InvalidOperationException($"Subexpression '{expression}' did not produce a valid filter.");
}

string? column = null;
/// <summary>
/// Extracts the column name from the left side of a binary expression.
/// </summary>
/// <param name="left">The left side expression, expected to be a property access.</param>
/// <returns>The column name corresponding to the property.</returns>
/// <exception cref="ArgumentException">Thrown if the left side does not correspond to a property with a <see cref="ColumnAttribute"/> or <see cref="PrimaryKeyAttribute"/>.</exception>
private string ExtractColumnName(Expression left)
{
if (left is MemberExpression leftMember)
{
column = GetColumnFromMemberExpression(leftMember);
} //To handle properly if it's a Convert ExpressionType generally with nullable properties
else if (left is UnaryExpression leftUnary && leftUnary.NodeType == ExpressionType.Convert &&
leftUnary.Operand is MemberExpression leftOperandMember)
return GetColumnFromMemberExpression(leftMember);
}
if (left is UnaryExpression leftUnary && leftUnary.NodeType == ExpressionType.Convert &&
leftUnary.Operand is MemberExpression leftOperandMember)
{
column = GetColumnFromMemberExpression(leftOperandMember);
return GetColumnFromMemberExpression(leftOperandMember);
}

if (column == null)
throw new ArgumentException(
$"Left side of expression: '{node}' is expected to be property with a ColumnAttribute or PrimaryKeyAttribute");
throw new ArgumentException(
$"Left side of expression: '{left}' is expected to be a property with a ColumnAttribute or PrimaryKeyAttribute");
}

/// <summary>
/// Evaluates the right side of a binary expression to produce a constant value, applying special handling for certain types.
/// </summary>
/// <param name="right">The right side expression to evaluate.</param>
/// <returns>The evaluated value of the expression, formatted appropriately for use in a PostgREST query.</returns>
/// <exception cref="NotSupportedException">Thrown if the right side cannot be evaluated to a constant value.</exception>
private object EvaluateRightExpression(Expression right)
{
right = Visit(right); // Process the right expression

object value = right switch
{
ConstantExpression constant => constant.Value,
MemberExpression member => EvaluateExpression(member),
NewExpression newExpr => EvaluateExpression(newExpr),
UnaryExpression unary => EvaluateExpression(unary),
BinaryExpression binary => EvaluateBinaryExpression(binary) ?? throw new NotSupportedException(
$"Binary expression '{binary}' on the right side is not supported. Only constant values or simple expressions are allowed."),
_ => throw new NotSupportedException(
$"Right side of expression: '{right}' is not supported. Expected a constant, member, new, unary, or simple binary expression.")
};

if (right is ConstantExpression rightConstant)
return value switch
{
HandleConstantExpression(column, op, rightConstant);
DateTime dateTime => dateTime,
DateTimeOffset dateTimeOffset => dateTimeOffset,
Guid guid => guid.ToString(),
Enum enumValue => enumValue,
_ => value
};
}

/// <summary>
/// Evaluates an expression to produce a constant value.
/// </summary>
/// <typeparam name="TExpression">The type of the expression to evaluate (e.g., MemberExpression, NewExpression, UnaryExpression).</typeparam>
/// <param name="expression">The expression to evaluate.</param>
/// <returns>The evaluated value of the expression.</returns>
/// <exception cref="InvalidOperationException">Thrown if the expression cannot be evaluated.</exception>
private object EvaluateExpression<TExpression>(TExpression expression) where TExpression : Expression
{
try
{
var lambda = Expression.Lambda(expression);
var compiled = lambda.Compile();
return compiled.DynamicInvoke();
}
else if (right is MemberExpression memberExpression)
catch (Exception ex)
{
HandleMemberExpression(column, op, memberExpression);
throw new InvalidOperationException($"Failed to evaluate {typeof(TExpression).Name.ToLower()}: '{expression}'.", ex);
}
else if (right is NewExpression newExpression)
}

/// <summary>
/// Evaluates a binary expression to compute its constant value, if possible.
/// </summary>
/// <param name="binaryExpression">The binary expression to evaluate (e.g., 'x - 5').</param>
/// <returns>The computed value of the expression as an object, or null if the expression cannot be evaluated.</returns>
/// <remarks>
/// Returns null if the expression cannot be evaluated due to unresolved variables or invalid operations.
/// The calling code should handle the null return value appropriately.
/// </remarks>
private object? EvaluateBinaryExpression(BinaryExpression binaryExpression)
{
try
{
HandleNewExpression(column, op, newExpression);
var lambda = Expression.Lambda(binaryExpression);
var compiled = lambda.Compile();
return compiled.DynamicInvoke();
}
else if (right is UnaryExpression unaryExpression)
catch (Exception)
{
HandleUnaryExpression(column, op, unaryExpression);
return null;
}

return node;
}

/// <summary>
Expand Down Expand Up @@ -135,100 +274,6 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
return node;
}

/// <summary>
/// A constant expression parser (i.e. x => x.Id == 5 &lt;- where '5' is the constant)
/// </summary>
/// <param name="column"></param>
/// <param name="op"></param>
/// <param name="constantExpression"></param>
private void HandleConstantExpression(string column, Operator op, ConstantExpression constantExpression)
{
if (constantExpression.Type.IsEnum)
{
var enumValue = constantExpression.Value;
Filter = new QueryFilter(column, op, enumValue);
}
else
{
Filter = new QueryFilter(column, op, constantExpression.Value);
}
}

/// <summary>
/// A member expression parser (i.e. => x.Id == Example.Id &lt;- where both `x.Id` and `Example.Id` are parsed as 'members')
/// </summary>
/// <param name="column"></param>
/// <param name="op"></param>
/// <param name="memberExpression"></param>
private void HandleMemberExpression(string column, Operator op, MemberExpression memberExpression)
{
Filter = new QueryFilter(column, op, GetMemberExpressionValue(memberExpression));
}

/// <summary>
/// A unary expression parser (i.e. => x.Id == 1 &lt;- where both `1` is considered unary)
/// </summary>
/// <param name="column"></param>
/// <param name="op"></param>
/// <param name="unaryExpression"></param>
private void HandleUnaryExpression(string column, Operator op, UnaryExpression unaryExpression)
{
if (unaryExpression.Operand is ConstantExpression constantExpression)
{
HandleConstantExpression(column, op, constantExpression);
}
else if (unaryExpression.Operand is MemberExpression memberExpression)
{
HandleMemberExpression(column, op, memberExpression);
}
else if (unaryExpression.Operand is NewExpression newExpression)
{
HandleNewExpression(column, op, newExpression);
}
}

/// <summary>
/// An instantiated class parser (i.e. x => x.CreatedAt &lt;= new DateTime(2022, 08, 20) &lt;- where `new DateTime(...)` is an instantiated expression.
/// </summary>
/// <param name="column"></param>
/// <param name="op"></param>
/// <param name="newExpression"></param>
private void HandleNewExpression(string column, Operator op, NewExpression newExpression)
{
var argumentValues = new List<object>();
foreach (var argument in newExpression.Arguments)
{
var lambda = Expression.Lambda(argument);
var func = lambda.Compile();
argumentValues.Add(func.DynamicInvoke());
}

var constructor = newExpression.Constructor;
var instance = constructor.Invoke(argumentValues.ToArray());

switch (instance)
{
case DateTime dateTime:
Filter = new QueryFilter(column, op, dateTime);
break;
case DateTimeOffset dateTimeOffset:
Filter = new QueryFilter(column, op, dateTimeOffset);
break;
case Guid guid:
Filter = new QueryFilter(column, op, guid.ToString());
break;
default:
{
if (instance.GetType().IsEnum)
{
Filter = new QueryFilter(column, op, instance);
}

break;
}
}
}

/// <summary>
/// Gets a column name (postgrest) from a Member Expression (used on BaseModel)
/// </summary>
Expand All @@ -238,19 +283,21 @@ private string GetColumnFromMemberExpression(MemberExpression node)
{
var type = node.Member.ReflectedType;
var prop = type?.GetProperty(node.Member.Name);
var attrs = prop?.GetCustomAttributes(true);
if (prop == null)
{
return node.Member.Name;
}

if (attrs == null) return node.Member.Name;
var columnAttr = prop.GetCustomAttribute<ColumnAttribute>(true);
if (columnAttr != null)
{
return columnAttr.ColumnName;
}

foreach (var attr in attrs)
var primaryKeyAttr = prop.GetCustomAttribute<PrimaryKeyAttribute>(true);
if (primaryKeyAttr != null)
{
switch (attr)
{
case ColumnAttribute columnAttr:
return columnAttr.ColumnName;
case PrimaryKeyAttribute primaryKeyAttr:
return primaryKeyAttr.ColumnName;
}
return primaryKeyAttr.ColumnName;
}

return node.Member.Name;
Expand Down