Skip to content

Commit

Permalink
allow to order by distance
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasender committed Jul 1, 2014
1 parent 8704a16 commit f156649
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 46 deletions.
2 changes: 1 addition & 1 deletion core/src/main/java/io/crate/types/DataTypes.java
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,4 @@ public static DataType ofName(String name) {
}
return dataType;
}
}
}
63 changes: 41 additions & 22 deletions sql/src/main/java/io/crate/analyze/SelectStatementAnalyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,7 @@ private void addSorting(List<SortItem> orderBy, SelectAnalysis context) {

int i = 0;
for (SortItem sortItem : orderBy) {
Symbol s = process(sortItem, context);
Symbol refOutput = ordinalOutputReference(context.outputSymbols(), s, "ORDER BY");
s = Objects.firstNonNull(refOutput, s);
if (s.symbolType() == SymbolType.REFERENCE && !DataTypes.PRIMITIVE_TYPES.contains(((Reference)s).valueType())) {
throw new IllegalArgumentException(
String.format("Cannot order by '%s': invalid data type '%s'",
SymbolFormatter.format(s),
((Reference) s).valueType()));
}
sortSymbols.add(s);
sortSymbols.add(process(sortItem, context));
switch (sortItem.getNullOrdering()) {
case FIRST:
context.nullsFirst()[i] = true;
Expand Down Expand Up @@ -283,8 +274,12 @@ protected Symbol visitSortItem(SortItem node, SelectAnalysis context) {
if (sortSymbol.symbolType() == SymbolType.PARAMETER) {
sortSymbol = Literal.fromParameter((Parameter)sortSymbol);
}
if (sortSymbol.symbolType() == SymbolType.LITERAL && DataTypes.NUMERIC_PRIMITIVE_TYPES.contains(((Literal)sortSymbol).valueType())) {
// deref
sortSymbol = ordinalOutputReference(context.outputSymbols(), sortSymbol, "ORDER BY");
}
// validate sortSymbol
sortSymbolValidator.process(sortSymbol, context.table);
sortSymbolValidator.process(sortSymbol, new SortSymbolValidator.SortContext(context.table));
return sortSymbol;
}

Expand Down Expand Up @@ -322,37 +317,61 @@ public Void visitAggregation(Aggregation symbol, AggregationSearcherContext cont
/**
* validate that sortSymbols don't contain partition by columns
*/
static class SortSymbolValidator extends SymbolVisitor<TableInfo, Void> {
static class SortSymbolValidator extends SymbolVisitor<SortSymbolValidator.SortContext, Void> {

static class SortContext {
private final TableInfo tableInfo;
private boolean inFunction;
public SortContext(TableInfo tableInfo) {
this.tableInfo = tableInfo;
this.inFunction = false;
}
}

@Override
public Void visitFunction(Function symbol, TableInfo context) {
for (Symbol arg : symbol.arguments()) {
process(arg, context);

public Void visitFunction(Function symbol, SortContext context) {
try {
context.inFunction = true;
if (!DataTypes.PRIMITIVE_TYPES.contains(symbol.valueType())) {
throw new UnsupportedOperationException(
String.format(Locale.ENGLISH,
"Cannot ORDER BY '%s': invalid return type '%s'.",
SymbolFormatter.format(symbol),
symbol.valueType())
);
}
for (Symbol arg : symbol.arguments()) {
process(arg, context);
}
} finally {
context.inFunction = false;
}
return null;
}

@Override
public Void visitReference(Reference symbol, TableInfo context) {
if (context.partitionedBy().contains(symbol.info().ident().columnIdent())) {
public Void visitReference(Reference symbol, SortContext context) {
if (context.tableInfo.partitionedBy().contains(symbol.info().ident().columnIdent())) {
throw new UnsupportedOperationException(
SymbolFormatter.format(
"cannot use partitioned column %s in ORDER BY clause",
symbol));
}
if (!DataTypes.PRIMITIVE_TYPES.contains(symbol.info().type())) {
// if we are in a function, we do not need to check the data type.
// the function will do that for us.
if (!context.inFunction && !DataTypes.PRIMITIVE_TYPES.contains(symbol.info().type())) {
throw new UnsupportedOperationException(
String.format(Locale.ENGLISH,
"cannot sort on columns of type '%s'",
symbol.info().type().getName())
"Cannot ORDER BY '%s': invalid data type '%s'.",
SymbolFormatter.format(symbol),
symbol.valueType())
);
}
return null;
}

@Override
public Void visitSymbol(Symbol symbol, TableInfo context) {
public Void visitSymbol(Symbol symbol, SortContext context) {
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,27 +289,16 @@ public Void visitFunction(Function symbol, OrderByContext context) {
if (symbol.info().ident().name().equals(DistanceFunction.NAME)) {
Symbol referenceSymbol = symbol.arguments().get(0);
Symbol valueSymbol = symbol.arguments().get(1);
if (referenceSymbol.symbolType().isValueSymbol()) {
if (!valueSymbol.symbolType().isValueSymbol()) {
throw new IllegalArgumentException(SymbolFormatter.format(
"Can't use \"%s\" in the ORDER BY clause. Requires one column reference and one literal", symbol));
}
Symbol tmp = referenceSymbol;
referenceSymbol = valueSymbol;
valueSymbol = tmp;
}

SortOrder sortOrder = new SortOrder(context.reverseFlag(), context.nullFirst());
Reference reference;
Input input;
try {
reference = (Reference) referenceSymbol;
input = (Input) valueSymbol;
} catch (ClassCastException e) {
if (referenceSymbol.symbolType() != SymbolType.REFERENCE || !(valueSymbol instanceof Input)) {
throw new IllegalArgumentException(SymbolFormatter.format(
"Can't use \"%s\" in the ORDER BY clause. Requires one column reference and one literal", symbol), e);
"Can't use \"%s\" in the ORDER BY clause. " +
"Requires one column reference and one literal, in that order.", symbol));
}

SortOrder sortOrder = new SortOrder(context.reverseFlag(), context.nullFirst());
Reference reference = (Reference) referenceSymbol;
Input input = (Input) valueSymbol;
try {
context.builder.startObject().startObject("_geo_distance")
.field(reference.info().ident().columnIdent().fqn(), input.value())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,13 @@ public Symbol normalizeSymbol(Function symbol) {
DataType arg1Type = DataTypeVisitor.fromSymbol(arg1);
DataType arg2Type = DataTypeVisitor.fromSymbol(arg2);

boolean arg1IsReference = true;
boolean literalConverted = false;
short numLiterals = 0;

if (arg1.symbolType().isValueSymbol()) {
numLiterals++;
arg1IsReference = false;
if (!arg1Type.equals(DataTypes.GEO_POINT)) {
literalConverted = true;
arg1 = Literal.toLiteral(arg1, DataTypes.GEO_POINT);
Expand All @@ -138,6 +140,10 @@ public Symbol normalizeSymbol(Function symbol) {
return Literal.newLiteral(evaluate((Input) arg1, (Input) arg2));
}

// ensure reference is the first argument.
if (!arg1IsReference) {
return new Function(geoPointInfo, Arrays.asList(arg2, arg1));
}
if (literalConverted) {
return new Function(geoPointInfo, Arrays.asList(arg1, arg2));
}
Expand Down
6 changes: 6 additions & 0 deletions sql/src/test/java/io/crate/analyze/BaseAnalyzerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,12 @@ TEST_DOC_TRANSACTIONS_TABLE_IDENT, RowGranularity.DOC, new Routing())
.add("timestamp", DataTypes.TIMESTAMP, null)
.build();

static final TableIdent TEST_DOC_LOCATIONS_TABLE_IDENT = new TableIdent(null, "locations");
static final TableInfo TEST_DOC_LOCATIONS_TABLE_INFO = TestingTableInfo.builder(TEST_DOC_LOCATIONS_TABLE_IDENT, RowGranularity.DOC, shardRouting)
.add("id", DataTypes.LONG, null)
.add("loc", DataTypes.GEO_POINT, null)
.build();

static final FunctionInfo ABS_FUNCTION_INFO = new FunctionInfo(
new FunctionIdent("abs", Arrays.<DataType>asList(DataTypes.LONG)),
DataTypes.LONG);
Expand Down
44 changes: 43 additions & 1 deletion sql/src/test/java/io/crate/analyze/SelectAnalyzerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import io.crate.operation.scalar.CollectionCountFunction;
import io.crate.operation.scalar.ScalarFunctionModule;
import io.crate.operation.scalar.arithmetic.AddFunction;
import io.crate.operation.scalar.geo.DistanceFunction;
import io.crate.planner.RowGranularity;
import io.crate.planner.symbol.*;
import io.crate.types.ArrayType;
Expand Down Expand Up @@ -102,6 +103,8 @@ protected void bindSchemas() {
.thenReturn(TEST_MULTIPLE_PARTITIONED_TABLE_INFO);
when(schemaInfo.getTableInfo(TEST_DOC_TRANSACTIONS_TABLE_IDENT.name()))
.thenReturn(TEST_DOC_TRANSACTIONS_TABLE_INFO);
when(schemaInfo.getTableInfo(TEST_DOC_LOCATIONS_TABLE_IDENT.name()))
.thenReturn(TEST_DOC_LOCATIONS_TABLE_INFO);
schemaBinder.addBinding(DocSchemaInfo.NAME).toInstance(schemaInfo);
}

Expand Down Expand Up @@ -1271,8 +1274,47 @@ public void testPositionalArgumentGroupByArrayType() throws Exception {
analyze("SELECT sum(id), friends FROM users GROUP BY 2");
}

@Test(expected = IllegalArgumentException.class)
@Test(expected = UnsupportedOperationException.class)
public void testPositionalArgumentOrderByArrayType() throws Exception {
analyze("SELECT id, friends FROM users ORDER BY 2");
}

@Test
public void testOrderByDistanceAlias() throws Exception {
String stmt = "SELECT distance(loc, 'POINT(-0.1275 51.507222)') AS distance_to_london " +
"FROM locations " +
"ORDER BY distance_to_london";
testDistanceOrderBy(stmt);
}

@Test
public void testOrderByDistancePositionalArgument() throws Exception {
String stmt = "SELECT distance(loc, 'POINT(-0.1275 51.507222)') " +
"FROM locations " +
"ORDER BY 1";
testDistanceOrderBy(stmt);
}

@Test
public void testOrderByDistanceExplicitly() throws Exception {
String stmt = "SELECT distance(loc, 'POINT(-0.1275 51.507222)') " +
"FROM locations " +
"ORDER BY distance(loc, 'POINT(-0.1275 51.507222)')";
testDistanceOrderBy(stmt);
}

@Test
public void testOrderByDistancePermutatedExplicitly() throws Exception {
String stmt = "SELECT distance('POINT(-0.1275 51.507222)', loc) " +
"FROM locations " +
"ORDER BY distance('POINT(-0.1275 51.507222)', loc)";
testDistanceOrderBy(stmt);
}

private void testDistanceOrderBy(String stmt) throws Exception{
SelectAnalysis analysis = (SelectAnalysis) analyze(stmt);
assertTrue(analysis.isSorted());
assertEquals(DistanceFunction.NAME, ((Function)analysis.sortSymbols().get(0)).info().ident().name());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,8 @@ public void testDistanceGteQuerySwappedArgs() throws Exception {
new FunctionIdent(DistanceFunction.NAME, Arrays.<DataType>asList(DataTypes.GEO_POINT, DataTypes.GEO_POINT)),
DataTypes.DOUBLE),
Arrays.<Symbol>asList(
createReference("location", DataTypes.GEO_POINT),
Literal.newLiteral(DataTypes.GEO_POINT, DataTypes.GEO_POINT.value("POINT (10 20)"))
Literal.newLiteral(DataTypes.GEO_POINT, DataTypes.GEO_POINT.value("POINT (10 20)")),
createReference("location", DataTypes.GEO_POINT)
)
);
Function whereClause = new Function(
Expand Down Expand Up @@ -568,6 +568,31 @@ public void testConvertESSearchNodeWithOrderByDistance() throws Exception {
"{\"_source\":{\"include\":[\"name\"]},\"query\":{\"match_all\":{}},\"sort\":[{\"_geo_distance\":{\"location\":[10.0,20.0],\"order\":\"asc\"}}],\"from\":0,\"size\":10000}"));
}

@Test(expected = IllegalArgumentException.class)
public void testConvertESSearchNodeWithOrderByDistanceSwappedArgs() throws Exception {
// ref must be the first argument
Function distanceFunction = new Function(
new FunctionInfo(
new FunctionIdent(DistanceFunction.NAME, Arrays.<DataType>asList(DataTypes.GEO_POINT, DataTypes.GEO_POINT)),
DataTypes.DOUBLE),
Arrays.<Symbol>asList(
Literal.newLiteral(DataTypes.GEO_POINT, DataTypes.GEO_POINT.value("POINT (10 20)")),
createReference("location", DataTypes.GEO_POINT)
)
);
ESSearchNode searchNode = new ESSearchNode(
new String[]{characters.name()},
ImmutableList.<Symbol>of(name_ref),
ImmutableList.<Symbol>of(distanceFunction),
new boolean[] { false },
new Boolean[] { null },
null,
null,
WhereClause.MATCH_ALL,
null);
BytesReference reference = generator.convert(searchNode);
}

@Test (expected = IllegalArgumentException.class)
public void testConvertESSearchNodeWithOrderByDistanceTwoReferences() throws Exception {
Function distanceFunction = new Function(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.crate.action.sql.SQLActionException;
import io.crate.action.sql.SQLResponse;
import io.crate.test.integration.CrateIntegrationTest;
import io.crate.testing.TestingHelpers;
import org.elasticsearch.action.admin.cluster.health.ClusterHealthStatus;
import org.elasticsearch.action.admin.indices.alias.exists.AliasesExistResponse;
import org.elasticsearch.action.admin.indices.exists.indices.IndicesExistsRequest;
Expand Down Expand Up @@ -4142,6 +4143,16 @@ public void testGeoTypeQueries() throws Exception {
assertThat(result1, is(0.0d));
assertThat(result2, is(156098.81231186818D));

String stmt = "SELECT id " +
"FROM t " +
"ORDER BY distance(p, 'POINT(30.0 30.0)')";
execute(stmt);
assertThat(response.rowCount(), is(2L));
String expected =
"2\n" +
"1\n";
assertEquals(expected, TestingHelpers.printedTable(response.rows()));

execute("select p from t where distance(p, 'POINT (11 21)') > 0.0");
List<Double> row = (List<Double>) response.rows()[0][0];
assertThat(row.get(0), is(10.0d));
Expand Down Expand Up @@ -4280,4 +4291,5 @@ public void testInsertFormQueryWithGeoType() throws Exception {
execute("insert into t2 (p) (select p from t)");
assertThat(response.rowCount(), is(1L));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

import static io.crate.testing.TestingHelpers.assertLiteralSymbol;
import static io.crate.testing.TestingHelpers.createReference;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.*;

Expand Down Expand Up @@ -151,7 +152,6 @@ public void testNormalizeWithValidRefAndStringLiteral() throws Exception {
createReference("foo", DataTypes.GEO_POINT),
Literal.newLiteral("POINT(10 20)")
));

assertLiteralSymbol(symbol.arguments().get(1),
new Double[]{10.0d, 20.0d}, DataTypes.GEO_POINT);

Expand All @@ -160,10 +160,44 @@ public void testNormalizeWithValidRefAndStringLiteral() throws Exception {
Literal.newLiteral("POINT(10 20)"),
createReference("foo", DataTypes.GEO_POINT)
));
assertLiteralSymbol(symbol.arguments().get(0),
assertLiteralSymbol(symbol.arguments().get(1),
new Double[] { 10.0d, 20.0d }, DataTypes.GEO_POINT);
}

@Test
public void testNormalizeWithValidRefAndGeoPointLiteral() throws Exception {
Function symbol = (Function) normalize(Arrays.<Symbol>asList(
createReference("foo", DataTypes.GEO_POINT),
Literal.newLiteral(DataTypes.GEO_POINT, DataTypes.GEO_POINT.value("POINT (10 20)"))
));
assertLiteralSymbol(symbol.arguments().get(1),
new Double[]{10.0d, 20.0d}, DataTypes.GEO_POINT);

// args reversed
symbol = (Function) normalize(Arrays.<Symbol>asList(
Literal.newLiteral(DataTypes.GEO_POINT, DataTypes.GEO_POINT.value("POINT (10 20)")),
createReference("foo", DataTypes.GEO_POINT)
));
assertLiteralSymbol(symbol.arguments().get(1),
new Double[] { 10.0d, 20.0d }, DataTypes.GEO_POINT);
}

@Test
public void testNormalizeWithValidGeoPointLiterals() throws Exception {
Literal symbol = (Literal) normalize(Arrays.<Symbol>asList(
Literal.newLiteral(DataTypes.GEO_POINT, DataTypes.GEO_POINT.value("POINT (10 20)")),
Literal.newLiteral(DataTypes.GEO_POINT, DataTypes.GEO_POINT.value("POINT (30 40)"))
));
assertThat(symbol.value(), instanceOf(Double.class));

// args reversed
symbol = (Literal) normalize(Arrays.<Symbol>asList(
Literal.newLiteral(DataTypes.GEO_POINT, DataTypes.GEO_POINT.value("POINT (30 40)")),
Literal.newLiteral(DataTypes.GEO_POINT, DataTypes.GEO_POINT.value("POINT (10 20)"))
));
assertThat(symbol.value(), instanceOf(Double.class));
}

@Test
public void testNormalizeWithTwoValidRefs() throws Exception {
List<Symbol> args = Arrays.<Symbol>asList(
Expand All @@ -189,4 +223,4 @@ public void testWithNullValue() throws Exception {
distanceSymbol = normalize(Lists.reverse(args));
assertNull(((Literal) distanceSymbol).value());
}
}
}

0 comments on commit f156649

Please sign in to comment.