Skip to content

Commit

Permalink
[CALCITE-1945] Make return types of AVG, VARIANCE, STDDEV and COVAR c…
Browse files Browse the repository at this point in the history
…ustomizable via RelDataTypeSystem

* Introduce VARIANCE and STDDEV as aliases for VAR_SAMP and STDDEV_SAMP

Close apache#518
  • Loading branch information
minji-kim authored and julianhyde committed Aug 29, 2017
1 parent 6d2fc4e commit 4208d80
Show file tree
Hide file tree
Showing 13 changed files with 303 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,11 @@ public SqlNode toSql(RexProgram program, RexNode rex) {
}

final RexCall call = (RexCall) stripCastFromString(rex);
final SqlOperator op = call.getOperator();
SqlOperator op = call.getOperator();
switch (op.getKind()) {
case SUM0:
op = SqlStdOperatorTable.SUM;
}
final List<SqlNode> nodeList = toSql(program, call.getOperands());
switch (call.getKind()) {
case CAST:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlAvgAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
Expand Down Expand Up @@ -117,14 +116,27 @@ public void onMatch(RelOptRuleCall ruleCall) {
*/
private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
for (AggregateCall call : aggCallList) {
if (call.getAggregation() instanceof SqlAvgAggFunction
|| call.getAggregation() instanceof SqlSumAggFunction) {
if (isReducible(call.getAggregation().getKind())) {
return true;
}
}
return false;
}

/**
 * Returns whether the given aggregate function kind is reducible.
 *
 * <p>A kind is reducible if it is {@code SUM} or a member of
 * {@link SqlKind#AVG_AGG_FUNCTIONS}.
 *
 * @param kind kind of the aggregate function being tested
 * @return {@code true} if the kind is reducible, {@code false} otherwise
 */
private boolean isReducible(final SqlKind kind) {
  switch (kind) {
  case SUM:
    return true;
  default:
    // Everything other than SUM is reducible only if it belongs to the
    // AVG / STDDEV / VAR family.
    return SqlKind.AVG_AGG_FUNCTIONS.contains(kind);
  }
}

/**
* Reduces all calls to AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP in
* the aggregates list to.
Expand Down Expand Up @@ -187,17 +199,16 @@ private RexNode reduceAgg(
List<AggregateCall> newCalls,
Map<AggregateCall, RexNode> aggCallMapping,
List<RexNode> inputExprs) {
if (oldCall.getAggregation() instanceof SqlSumAggFunction) {
// replace original SUM(x) with
// case COUNT(x) when 0 then null else SUM0(x) end
return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
}
if (oldCall.getAggregation() instanceof SqlAvgAggFunction) {
final SqlKind kind = oldCall.getAggregation().getKind();
final SqlKind kind = oldCall.getAggregation().getKind();
if (isReducible(kind)) {
switch (kind) {
case SUM:
// replace original SUM(x) with
// case COUNT(x) when 0 then null else SUM0(x) end
return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
case AVG:
// replace original AVG(x) with SUM(x) / COUNT(x)
return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs);
case STDDEV_POP:
// replace original STDDEV_POP(x) with
// SQRT(
Expand Down Expand Up @@ -243,19 +254,39 @@ private RexNode reduceAgg(
}
}

/**
 * Creates an unnamed call to {@code aggFunction} over a single argument,
 * inferring the call's return type from an operand binding.
 *
 * <p>The distinct flag and the filter argument are copied from
 * {@code oldCall}.
 *
 * @param typeFactory type factory handed to the binding
 * @param aggFunction aggregate function to call
 * @param operandType type of the sole operand, used for return-type inference
 * @param oldAggRel aggregate whose group count is passed to the binding
 * @param oldCall call that supplies the distinct flag and filter argument
 * @param argOrdinal ordinal of the single argument of the new call
 * @return the newly created aggregate call
 */
private AggregateCall createAggregateCallWithBinding(
    RelDataTypeFactory typeFactory,
    SqlAggFunction aggFunction,
    RelDataType operandType,
    Aggregate oldAggRel,
    AggregateCall oldCall,
    int argOrdinal) {
  final int filterArg = oldCall.filterArg;
  // A non-negative filter ordinal means the original call had a filter.
  final boolean hasFilter = filterArg >= 0;
  final Aggregate.AggCallBinding binding =
      new Aggregate.AggCallBinding(typeFactory, aggFunction,
          ImmutableList.of(operandType), oldAggRel.getGroupCount(),
          hasFilter);
  final RelDataType returnType = aggFunction.inferReturnType(binding);
  return AggregateCall.create(aggFunction,
      oldCall.isDistinct(),
      ImmutableIntList.of(argOrdinal),
      filterArg,
      returnType,
      null);
}

private RexNode reduceAvg(
Aggregate oldAggRel,
AggregateCall oldCall,
List<AggregateCall> newCalls,
Map<AggregateCall, RexNode> aggCallMapping) {
Map<AggregateCall, RexNode> aggCallMapping,
List<RexNode> inputExprs) {
final int nGroups = oldAggRel.getGroupCount();
RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
int iAvgInput = oldCall.getArgList().get(0);
RelDataType avgInputType =
final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
final int iAvgInput = oldCall.getArgList().get(0);
final RelDataType avgInputType =
getFieldType(
oldAggRel.getInput(),
iAvgInput);
AggregateCall sumCall =
final AggregateCall sumCall =
AggregateCall.create(
SqlStdOperatorTable.SUM,
oldCall.isDistinct(),
Expand All @@ -265,7 +296,7 @@ private RexNode reduceAvg(
oldAggRel.getInput(),
null,
null);
AggregateCall countCall =
final AggregateCall countCall =
AggregateCall.create(
SqlStdOperatorTable.COUNT,
oldCall.isDistinct(),
Expand All @@ -285,17 +316,20 @@ private RexNode reduceAvg(
newCalls,
aggCallMapping,
ImmutableList.of(avgInputType));
RexNode denominatorRef =
final RexNode denominatorRef =
rexBuilder.addAggCall(countCall,
nGroups,
oldAggRel.indicator,
newCalls,
aggCallMapping,
ImmutableList.of(avgInputType));

final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
final RelDataType avgType = typeFactory.createTypeWithNullability(
oldCall.getType(), numeratorRef.getType().isNullable());
numeratorRef = rexBuilder.ensureType(avgType, numeratorRef, true);
final RexNode divideRef =
rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE,
numeratorRef,
denominatorRef);
rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef);
return rexBuilder.makeCast(oldCall.getType(), divideRef);
}

Expand Down Expand Up @@ -381,36 +415,30 @@ private RexNode reduceStddev(

assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
final int argOrdinal = oldCall.getArgList().get(0);
final RelDataType argType =
getFieldType(
oldAggRel.getInput(),
argOrdinal);
final RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal);
final RelDataType oldCallType =
typeFactory.createTypeWithNullability(oldCall.getType(),
argOrdinalType.isNullable());

final RexNode argRef = inputExprs.get(argOrdinal);
final RexNode argSquared =
rexBuilder.makeCall(
SqlStdOperatorTable.MULTIPLY, argRef, argRef);
final RexNode argRef =
rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), true);
final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);

final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY,
argRef, argRef);
final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);

final Aggregate.AggCallBinding binding =
new Aggregate.AggCallBinding(typeFactory, SqlStdOperatorTable.SUM,
ImmutableList.of(argRef.getType()), oldAggRel.getGroupCount(),
oldCall.filterArg >= 0);
final AggregateCall sumArgSquaredAggCall =
AggregateCall.create(
SqlStdOperatorTable.SUM,
oldCall.isDistinct(),
ImmutableIntList.of(argSquaredOrdinal),
oldCall.filterArg,
SqlStdOperatorTable.SUM.inferReturnType(binding),
null);
createAggregateCallWithBinding(typeFactory, SqlStdOperatorTable.SUM,
argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);

final RexNode sumArgSquared =
rexBuilder.addAggCall(sumArgSquaredAggCall,
nGroups,
oldAggRel.indicator,
newCalls,
aggCallMapping,
ImmutableList.of(argType));
ImmutableList.of(sumArgSquaredAggCall.getType()));

final AggregateCall sumArgAggCall =
AggregateCall.create(
Expand All @@ -422,17 +450,18 @@ private RexNode reduceStddev(
oldAggRel.getInput(),
null,
null);

final RexNode sumArg =
rexBuilder.addAggCall(sumArgAggCall,
nGroups,
oldAggRel.indicator,
newCalls,
aggCallMapping,
ImmutableList.of(argType));

ImmutableList.of(sumArgAggCall.getType()));
final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
final RexNode sumSquaredArg =
rexBuilder.makeCall(
SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);
SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);

final AggregateCall countArgAggCall =
AggregateCall.create(
Expand All @@ -441,21 +470,21 @@ private RexNode reduceStddev(
oldCall.getArgList(),
oldCall.filterArg,
oldAggRel.getGroupCount(),
oldAggRel.getInput(),
oldAggRel,
null,
null);

final RexNode countArg =
rexBuilder.addAggCall(countArgAggCall,
nGroups,
oldAggRel.indicator,
newCalls,
aggCallMapping,
ImmutableList.of(argType));
ImmutableList.of(argOrdinalType));

final RexNode avgSumSquaredArg =
rexBuilder.makeCall(
SqlStdOperatorTable.DIVIDE,
sumSquaredArg, countArg);
SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);

final RexNode diff =
rexBuilder.makeCall(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,21 @@ public interface RelDataTypeSystem {
* 0 means "not applicable". */
int getNumTypeRadix(SqlTypeName typeName);

/**
* Returns the return type of a call to the {@code SUM} aggregate function
* inferred from its argument type.
/** Returns the return type of a call to the {@code SUM} aggregate function,
* inferred from its argument type. */
RelDataType deriveSumType(RelDataTypeFactory typeFactory,
RelDataType argumentType);

/** Returns the return type of a call to the {@code AVG}, {@code STDDEV} or
* {@code VAR} aggregate functions, inferred from its argument type.
*/
RelDataType deriveSumType(RelDataTypeFactory typeFactory, RelDataType argumentType);
RelDataType deriveAvgAggType(RelDataTypeFactory typeFactory,
RelDataType argumentType);

/** Returns the return type of a call to the {@code COVAR} aggregate function,
* inferred from its argument types. */
RelDataType deriveCovarType(RelDataTypeFactory typeFactory,
RelDataType arg0Type, RelDataType arg1Type);

/** Returns the return type of the {@code CUME_DIST} and {@code PERCENT_RANK}
* aggregate functions. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,21 @@ && getDefaultPrecision(typeName) != -1) {
return 0;
}

@Override public RelDataType deriveSumType(
RelDataTypeFactory typeFactory, RelDataType argumentType) {
@Override public RelDataType deriveSumType(RelDataTypeFactory typeFactory,
RelDataType argumentType) {
return argumentType;
}

// Returns argumentType unchanged; typeFactory is not used here.
@Override public RelDataType deriveAvgAggType(RelDataTypeFactory typeFactory,
RelDataType argumentType) {
return argumentType;
}

// Returns the type of the first argument unchanged; typeFactory and
// arg1Type are not used here.
@Override public RelDataType deriveCovarType(RelDataTypeFactory typeFactory,
RelDataType arg0Type, RelDataType arg1Type) {
return arg0Type;
}

@Override public RelDataType deriveFractionalRankType(RelDataTypeFactory typeFactory) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.DOUBLE), false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,10 @@ public static double power(double b0, double b1) {
return Math.pow(b0, b1);
}

/** Raises {@code b0} to the power {@code b1}, after converting the
 * exponent to {@code double}. */
public static double power(double b0, BigDecimal b1) {
  final double exponent = b1.doubleValue();
  return Math.pow(b0, exponent);
}

/** Raises {@code b0} to the power {@code b1}; both operands are widened to
 * {@code double} before the computation. */
public static double power(long b0, long b1) {
  final double base = b0;
  final double exponent = b1;
  return Math.pow(base, exponent);
}
Expand Down
9 changes: 9 additions & 0 deletions core/src/main/java/org/apache/calcite/sql/SqlKind.java
Original file line number Diff line number Diff line change
Expand Up @@ -1118,6 +1118,15 @@ public enum SqlKind {
public static final Set<SqlKind> FUNCTION =
EnumSet.of(OTHER_FUNCTION, ROW, TRIM, LTRIM, RTRIM, CAST, JDBC_FN);

/**
* Category of the aggregate function kinds accepted by
* {@code SqlAvgAggFunction}.
*
* <p>Consists of {@link #AVG}, {@link #STDDEV_POP}, {@link #STDDEV_SAMP},
* {@link #VAR_POP}, {@link #VAR_SAMP}.
*/
public static final Set<SqlKind> AVG_AGG_FUNCTIONS =
EnumSet.of(AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP);

/**
* Category of comparison operators.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,27 @@
* double</code>), and the result is the same type.
*/
public class SqlAvgAggFunction extends SqlAggFunction {

//~ Constructors -----------------------------------------------------------

/**
* Creates a SqlAvgAggFunction.
*/
public SqlAvgAggFunction(SqlKind kind) {
super(kind.name(),
this(kind.name(), kind);
}

SqlAvgAggFunction(String name, SqlKind kind) {
super(name,
null,
kind,
ReturnTypes.ARG0_NULLABLE_IF_EMPTY,
ReturnTypes.AVG_AGG_FUNCTION,
null,
OperandTypes.NUMERIC,
SqlFunctionCategory.NUMERIC,
false,
false);
Preconditions.checkArgument(kind == SqlKind.AVG
|| kind == SqlKind.STDDEV_POP
|| kind == SqlKind.STDDEV_SAMP
|| kind == SqlKind.VAR_POP
|| kind == SqlKind.VAR_SAMP);
Preconditions.checkArgument(SqlKind.AVG_AGG_FUNCTIONS.contains(kind), "unsupported sql kind");
}

@Deprecated // to be removed before 2.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public SqlCovarAggFunction(SqlKind kind) {
super(kind.name(),
null,
kind,
ReturnTypes.ARG0_NULLABLE_IF_EMPTY,
ReturnTypes.COVAR_FUNCTION,
null,
OperandTypes.NUMERIC_NUMERIC,
SqlFunctionCategory.NUMERIC,
Expand Down
Loading

0 comments on commit 4208d80

Please sign in to comment.