Skip to content

Commit

Permalink
Refactor count function to use new function registry.
Browse files Browse the repository at this point in the history
  • Loading branch information
kovrus authored and mergify[bot] committed Apr 2, 2020
1 parent 25ada36 commit 1183ae1
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

package io.crate.execution.engine.aggregation.impl;

import com.google.common.collect.ImmutableList;
import io.crate.Streamer;
import io.crate.breaker.RamAccounting;
import io.crate.data.Input;
Expand All @@ -30,13 +29,10 @@
import io.crate.expression.symbol.Literal;
import io.crate.expression.symbol.Symbol;
import io.crate.memory.MemoryManager;
import io.crate.metadata.BaseFunctionResolver;
import io.crate.metadata.FunctionIdent;
import io.crate.metadata.FunctionImplementation;
import io.crate.metadata.FunctionInfo;
import io.crate.metadata.TransactionContext;
import io.crate.metadata.functions.params.FuncParams;
import io.crate.metadata.functions.params.Param;
import io.crate.metadata.functions.Signature;
import io.crate.types.DataType;
import io.crate.types.DataTypes;
import io.crate.types.FixedWidthType;
Expand All @@ -48,6 +44,9 @@
import java.io.IOException;
import java.util.List;

import static io.crate.metadata.functions.TypeVariableConstraint.typeVariable;
import static io.crate.types.TypeSignature.parseTypeSignature;

public class CountAggregation extends AggregationFunction<CountAggregation.LongState, Long> {

public static final String NAME = "count";
Expand All @@ -58,31 +57,33 @@ public class CountAggregation extends AggregationFunction<CountAggregation.LongS
DataTypes.register(CountAggregation.LongStateType.ID, in -> CountAggregation.LongStateType.INSTANCE);
}

public static final FunctionInfo COUNT_STAR_FUNCTION = new FunctionInfo(new FunctionIdent(NAME,
ImmutableList.of()), DataTypes.LONG, FunctionInfo.Type.AGGREGATE);
public static final FunctionInfo COUNT_STAR_FUNCTION = new FunctionInfo(
new FunctionIdent(NAME, List.of()), DataTypes.LONG, FunctionInfo.Type.AGGREGATE);

public static void register(AggregationImplModule mod) {
mod.register(NAME, new CountAggregationFunctionResolver());
}

private static class CountAggregationFunctionResolver extends BaseFunctionResolver {

CountAggregationFunctionResolver() {
super(FuncParams.builder()
.withVarArgs(Param.ANY).limitVarArgOccurrences(1)
.build());
}

@Override
public FunctionImplementation getForTypes(List<DataType> dataTypes) throws IllegalArgumentException {
if (dataTypes.size() == 0) {
return new CountAggregation(COUNT_STAR_FUNCTION, false);
} else {
return new CountAggregation(
new FunctionInfo(new FunctionIdent(NAME, dataTypes),
DataTypes.LONG, FunctionInfo.Type.AGGREGATE), true);
}
}
mod.register(
Signature.builder()
.name(NAME)
.kind(FunctionInfo.Type.AGGREGATE)
.typeVariableConstraints(typeVariable("V"))
.argumentTypes(parseTypeSignature("V"))
.returnType(DataTypes.LONG.getTypeSignature())
.build(),
args -> new CountAggregation(
new FunctionInfo(
new FunctionIdent(NAME, args),
DataTypes.LONG,
FunctionInfo.Type.AGGREGATE),
true)
);
mod.register(
Signature.builder()
.name(NAME)
.kind(FunctionInfo.Type.AGGREGATE)
.returnType(DataTypes.LONG.getTypeSignature())
.build(),
args -> new CountAggregation(COUNT_STAR_FUNCTION, false)
);
}

private CountAggregation(FunctionInfo info, boolean hasArgs) {
Expand Down Expand Up @@ -126,7 +127,7 @@ public Symbol normalizeSymbol(Function function, TransactionContext txnCtx) {
if (((Input) arg).value() == null) {
return Literal.of(0L);
} else {
return new Function(COUNT_STAR_FUNCTION, ImmutableList.of());
return new Function(COUNT_STAR_FUNCTION, List.of());
}
}
}
Expand Down
11 changes: 11 additions & 0 deletions sql/src/main/java/io/crate/expression/AbstractFunctionModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@
import io.crate.metadata.FunctionImplementation;
import io.crate.metadata.FunctionName;
import io.crate.metadata.FunctionResolver;
import io.crate.metadata.functions.Signature;
import io.crate.types.DataType;
import org.elasticsearch.common.inject.AbstractModule;
import org.elasticsearch.common.inject.TypeLiteral;
import org.elasticsearch.common.inject.multibindings.MapBinder;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public abstract class AbstractFunctionModule<T extends FunctionImplementation> extends AbstractModule {

Expand All @@ -57,6 +61,13 @@ public void register(FunctionName qualifiedName, FunctionResolver functionResolv
resolver.put(qualifiedName, functionResolver);
}

public void register(Signature signature, Function<List<DataType>, FunctionImplementation> factory) {
List<FuncResolver> functions = functionImplementations.computeIfAbsent(
signature.getName(),
k -> new ArrayList<>());
functions.add(new FuncResolver(signature, factory));
}

public abstract void configureFunctions();

@Override
Expand Down

0 comments on commit 1183ae1

Please sign in to comment.