Skip to content

Commit

Permalink
Use new func registry for temporal arithmetic functions.
Browse files Browse the repository at this point in the history
- interval and timestamps
  • Loading branch information
kovrus authored and mergify[bot] committed Apr 17, 2020
1 parent bfee8b5 commit cd3c158
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

package io.crate.execution.engine.window;

import io.crate.expression.scalar.arithmetic.IntervalTimestampScalar;
import io.crate.expression.scalar.arithmetic.IntervalTimestampArithmeticScalar;
import io.crate.types.ByteType;
import io.crate.types.DataType;
import io.crate.types.DoubleType;
Expand All @@ -33,6 +33,7 @@
import io.crate.types.ShortType;
import io.crate.types.TimestampType;

import java.util.List;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;

Expand All @@ -54,8 +55,8 @@ static BiFunction getAddFunction(DataType fstArgDataType, DataType sndArgDataTyp
case TimestampType.ID_WITH_TZ:
case TimestampType.ID_WITHOUT_TZ:
if (IntervalType.ID == sndArgDataType.id()) {
return new IntervalTimestampScalar(
"+", "add-interval", fstArgDataType, sndArgDataType, fstArgDataType);
return new IntervalTimestampArithmeticScalar(
"+", "add-interval", List.of(fstArgDataType, sndArgDataType), fstArgDataType);
}
return ADD_LONG_FUNCTION;
case DoubleType.ID:
Expand All @@ -78,8 +79,8 @@ static BiFunction getSubtractFunction(DataType fstArgDataType, DataType sndArgDa
case TimestampType.ID_WITH_TZ:
case TimestampType.ID_WITHOUT_TZ:
if (IntervalType.ID == sndArgDataType.id()) {
return new IntervalTimestampScalar(
"-", "sub-interval", fstArgDataType, sndArgDataType, fstArgDataType);
return new IntervalTimestampArithmeticScalar(
"-", "sub-interval", List.of(fstArgDataType, sndArgDataType), fstArgDataType);
}
return SUB_LONG_FUNCTION;
case DoubleType.ID:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import io.crate.expression.scalar.arithmetic.CeilFunction;
import io.crate.expression.scalar.arithmetic.ExpFunction;
import io.crate.expression.scalar.arithmetic.FloorFunction;
import io.crate.expression.scalar.arithmetic.IntervalArithmeticScalar;
import io.crate.expression.scalar.arithmetic.IntervalTimestampArithmeticScalar;
import io.crate.expression.scalar.arithmetic.LogFunction;
import io.crate.expression.scalar.arithmetic.MapFunction;
import io.crate.expression.scalar.arithmetic.NegateFunctions;
Expand Down Expand Up @@ -143,6 +145,8 @@ protected void configure() {
RegexpReplaceFunction.register(this);

ArithmeticFunctions.register(this);
IntervalTimestampArithmeticScalar.register(this);
IntervalArithmeticScalar.register(this);

DistanceFunction.register(this);
WithinFunction.register(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,14 @@
package io.crate.expression.scalar.arithmetic;

import com.google.common.collect.ImmutableList;
import io.crate.common.collections.Lists2;
import io.crate.data.Input;
import io.crate.expression.scalar.ScalarFunctionModule;
import io.crate.expression.symbol.FuncArg;
import io.crate.expression.symbol.Function;
import io.crate.expression.symbol.Symbol;
import io.crate.metadata.BaseFunctionResolver;
import io.crate.metadata.FunctionIdent;
import io.crate.metadata.FunctionImplementation;
import io.crate.metadata.FunctionInfo;
import io.crate.metadata.Scalar;
import io.crate.metadata.TransactionContext;
import io.crate.metadata.functions.params.FuncParams;
import io.crate.metadata.functions.params.Param;
import io.crate.types.ByteType;
Expand All @@ -43,25 +39,20 @@
import io.crate.types.DoubleType;
import io.crate.types.FloatType;
import io.crate.types.IntegerType;
import io.crate.types.IntervalType;
import io.crate.types.LongType;
import io.crate.types.ShortType;
import io.crate.types.TimestampType;
import org.elasticsearch.common.util.set.Sets;
import org.joda.time.Period;

import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;

public class ArithmeticFunctions {

private static final Param ARITHMETIC_TYPE = Param.of(
DataTypes.NUMERIC_PRIMITIVE_TYPES, DataTypes.TIMESTAMPZ, DataTypes.TIMESTAMP, DataTypes.INTERVAL, DataTypes.UNDEFINED);
DataTypes.NUMERIC_PRIMITIVE_TYPES, DataTypes.TIMESTAMPZ, DataTypes.TIMESTAMP, DataTypes.UNDEFINED);

public static class Names {
public static final String ADD = "add";
Expand Down Expand Up @@ -175,25 +166,9 @@ static final class ArithmeticFunctionResolver extends BaseFunctionResolver {
this.features = features;
}

@Nullable
@Override
public List<DataType> getSignature(List<? extends FuncArg> dataTypes) {
if (dataTypes.size() == 2) {
DataType fst = dataTypes.get(0).valueType();
DataType snd = dataTypes.get(1).valueType();

if ((isInterval(fst) && isTimestamp(snd)) ||
(isTimestamp(fst) && isInterval(snd))) {
return Lists2.map(dataTypes, FuncArg::valueType);
}
}
return super.getSignature(dataTypes);
}

@Override
public FunctionImplementation getForTypes(List<DataType> dataTypes) throws IllegalArgumentException {
assert dataTypes.size() == 2 : "Arithmetic operator must receive two arguments";

DataType<?> fst = dataTypes.get(0);
DataType<?> snd = dataTypes.get(1);

Expand All @@ -213,11 +188,6 @@ public FunctionImplementation getForTypes(List<DataType> dataTypes) throws Illeg
case IntegerType.ID:
scalar = new BinaryScalar<>(integerFunction, name, DataTypes.INTEGER, features);
break;

case IntervalType.ID:
scalar = new IntervalArithmeticScalar(operator, name);
break;

case LongType.ID:
case TimestampType.ID_WITH_TZ:
case TimestampType.ID_WITHOUT_TZ:
Expand All @@ -231,28 +201,9 @@ public FunctionImplementation getForTypes(List<DataType> dataTypes) throws Illeg
return scalar;
}

if (isInterval(fst) && isTimestamp(snd)) {
return new IntervalTimestampScalar(operator, name, fst, snd, snd);
}
if (isTimestamp(fst) && isInterval(snd)) {
return new IntervalTimestampScalar(operator, name, fst, snd, fst);
}

throw new UnsupportedOperationException(
String.format(Locale.ENGLISH, "Arithmetic operation are not supported for type %s %s", fst, snd));
}

private static boolean isInterval(DataType d) {
return d.id() == IntervalType.ID;
}

private static boolean isTimestamp(DataType d) {
return TIMESTAMP_IDS.contains(d.id());
}

static final Set<Integer> TIMESTAMP_IDS = Sets.newHashSet(DataTypes.TIMESTAMP.id(),
DataTypes.TIMESTAMPZ.id());

}

public static Function of(String name, Symbol first, Symbol second, Set<FunctionInfo.Feature> features) {
Expand All @@ -267,43 +218,4 @@ public static Function of(String name, Symbol first, Symbol second, Set<Function
ImmutableList.of(first, second)
);
}

private static final class IntervalArithmeticScalar extends Scalar<Period, Object> {

private final FunctionInfo info;
private final BiFunction<Period, Period, Period> operation;

IntervalArithmeticScalar(String operator, String name) {
this.info = new FunctionInfo(
new FunctionIdent(name, Arrays.asList(DataTypes.INTERVAL, DataTypes.INTERVAL)), DataTypes.INTERVAL);
switch (operator) {
case "+":
operation = Period::plus;
break;
case "-":
operation = Period::minus;
break;
default:
operation = (a,b) -> {
throw new IllegalArgumentException("Unsupported operator for interval " + operator);
};
}
}

@Override
public Period evaluate(TransactionContext txnCtx, Input<Object>... args) {
Period fst = (Period) args[0].value();
Period snd = (Period) args[1].value();

if (fst == null || snd == null) {
return null;
}
return operation.apply(fst, snd);
}

@Override
public FunctionInfo info() {
return this.info;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Licensed to Crate under one or more contributor license agreements.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership. Crate licenses this file
* to you under the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*
* However, if you have executed another commercial license agreement
* with Crate these terms will supersede the license and you may use the
* software solely pursuant to the terms of the relevant commercial
* agreement.
*/

package io.crate.expression.scalar.arithmetic;

import io.crate.data.Input;
import io.crate.expression.scalar.ScalarFunctionModule;
import io.crate.metadata.FunctionIdent;
import io.crate.metadata.FunctionInfo;
import io.crate.metadata.Scalar;
import io.crate.metadata.TransactionContext;
import io.crate.metadata.functions.Signature;
import io.crate.types.DataTypes;
import org.joda.time.Period;

import java.util.List;
import java.util.function.BiFunction;

public class IntervalArithmeticScalar extends Scalar<Period, Object> {

public static void register(ScalarFunctionModule module) {
module.register(
Signature.scalar(
ArithmeticFunctions.Names.ADD,
DataTypes.INTERVAL.getTypeSignature(),
DataTypes.INTERVAL.getTypeSignature(),
DataTypes.INTERVAL.getTypeSignature()
),
args -> new IntervalArithmeticScalar("+", ArithmeticFunctions.Names.ADD)
);
module.register(
Signature.scalar(
ArithmeticFunctions.Names.SUBTRACT,
DataTypes.INTERVAL.getTypeSignature(),
DataTypes.INTERVAL.getTypeSignature(),
DataTypes.INTERVAL.getTypeSignature()
),
args -> new IntervalArithmeticScalar("-", ArithmeticFunctions.Names.SUBTRACT)
);
}

private final FunctionInfo info;
private final BiFunction<Period, Period, Period> operation;

IntervalArithmeticScalar(String operator, String name) {
info = new FunctionInfo(
new FunctionIdent(
name,
List.of(DataTypes.INTERVAL, DataTypes.INTERVAL)),
DataTypes.INTERVAL);

switch (operator) {
case "+":
operation = Period::plus;
break;
case "-":
operation = Period::minus;
break;
default:
operation = (a, b) -> {
throw new IllegalArgumentException("Unsupported operator for interval " + operator);
};
}
}

@Override
public FunctionInfo info() {
return this.info;
}

@Override
public Period evaluate(TransactionContext txnCtx, Input<Object>[] args) {
Period fst = (Period) args[0].value();
Period snd = (Period) args[1].value();

if (fst == null || snd == null) {
return null;
}
return operation.apply(fst, snd);
}
}
Loading

0 comments on commit cd3c158

Please sign in to comment.