Skip to content

Commit

Permalink
[FLINK-33563] Implement type inference for Agg functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dawidwys committed Nov 16, 2023
1 parent 596864d commit 2e5face
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -672,14 +672,16 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL)
BuiltInFunctionDefinition.newBuilder()
.name("collect")
.kind(AGGREGATE)
.outputTypeStrategy(TypeStrategies.MISSING)
.inputTypeStrategy(sequence(ANY))
.outputTypeStrategy(SpecificTypeStrategies.COLLECT)
.build();

public static final BuiltInFunctionDefinition DISTINCT =
BuiltInFunctionDefinition.newBuilder()
.name("distinct")
.kind(AGGREGATE)
.outputTypeStrategy(TypeStrategies.MISSING)
.inputTypeStrategy(sequence(ANY))
.outputTypeStrategy(argument(0))
.build();

// --------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.types.inference.strategies;

import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.TypeStrategy;

import java.util.List;
import java.util.Optional;

/**
* Type strategy that returns a {@link DataTypes#MULTISET(DataType)} with element type equal to the
* type of the first argument.
*/
@Internal
public class CollectTypeStrategy implements TypeStrategy {

@Override
public Optional<DataType> inferType(CallContext callContext) {
List<DataType> argumentDataTypes = callContext.getArgumentDataTypes();
if (argumentDataTypes.size() != 1) {
return Optional.empty();
}

return Optional.of(DataTypes.MULTISET(argumentDataTypes.get(0)).notNull());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ public final class SpecificTypeStrategies {
/** See {@link MapTypeStrategy}. */
public static final TypeStrategy MAP = new MapTypeStrategy();

/** See {@link CollectTypeStrategy}. */
public static final TypeStrategy COLLECT = new CollectTypeStrategy();

/** See {@link IfNullTypeStrategy}. */
public static final TypeStrategy IF_NULL = new IfNullTypeStrategy();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.types.inference.strategies;

import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.types.inference.TypeStrategiesTestBase;

import java.util.stream.Stream;

/** Tests for {@link CollectTypeStrategy}. */
class CollectTypeStrategyTest extends TypeStrategiesTestBase {

@Override
protected Stream<TestSpec> testData() {
return Stream.of(
TestSpec.forStrategy("Infer a collect type", SpecificTypeStrategies.COLLECT)
.inputTypes(DataTypes.BIGINT())
.expectDataType(DataTypes.MULTISET(DataTypes.BIGINT()).notNull()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,6 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp
case fd: FunctionDefinition =>
fd match {

case DISTINCT =>
assert(args.size == 1)
DistinctAgg(args.head)

case COLLECT =>
assert(args.size == 1)
Collect(args.head)

case ORDER_ASC =>
assert(args.size == 1)
Asc(args.head)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,39 +48,6 @@ case class ApiResolvedAggregateCallExpression(resolvedCall: CallExpression) exte
.fromDataTypeToLegacyInfo(resolvedCall.getOutputDataType)
}

case class DistinctAgg(child: PlannerExpression) extends Aggregation {

def distinct: PlannerExpression = DistinctAgg(child)

override private[flink] def resultType: TypeInformation[_] = child.resultType

override private[flink] def validateInput(): ValidationResult = {
super.validateInput()
child match {
case agg: Aggregation =>
child.validateInput()
case _ =>
ValidationFailure(
s"Distinct modifier cannot be applied to $child! " +
s"It can only be applied to an aggregation expression, for example, " +
s"'a.count.distinct which is equivalent with COUNT(DISTINCT a).")
}
}

override private[flink] def children = Seq(child)
}

/** Returns a multiset aggregates. */
case class Collect(child: PlannerExpression) extends Aggregation {

override private[flink] def children: Seq[PlannerExpression] = Seq(child)

override private[flink] def resultType: TypeInformation[_] =
MultisetTypeInfo.getInfoFor(child.resultType)

override def toString: String = s"collect($child)"
}

/** Expression for calling a user-defined (table)aggregate function. */
case class AggFunctionCall(
aggregateFunction: ImperativeAggregateFunction[_, _],
Expand Down

0 comments on commit 2e5face

Please sign in to comment.