Skip to content

Commit

Permalink
Rewrite -dot_product to pgvector <#> operator in PostgreSQL
Browse files Browse the repository at this point in the history
  • Loading branch information
ebyhr committed Aug 13, 2024
1 parent 8ec18d2 commit d695773
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import io.trino.plugin.jdbc.expression.RewriteIn;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping;
import io.trino.plugin.postgresql.rule.RewriteDotProductFunction;
import io.trino.plugin.postgresql.rule.RewriteStringReverseFunction;
import io.trino.plugin.postgresql.rule.RewriteVectorDistanceFunction;
import io.trino.spi.TrinoException;
Expand Down Expand Up @@ -348,7 +349,7 @@ public PostgreSqlClient(
.add(new RewriteStringReverseFunction())
.add(new RewriteVectorDistanceFunction("euclidean_distance", "<->"))
.add(new RewriteVectorDistanceFunction("cosine_distance", "<=>"))
// TODO Rewrite Trino -dot_product to pgvector <#> operator
.add(new RewriteDotProductFunction())
.build());

JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.postgresql.rule;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.projection.ProjectFunctionRule;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.QueryParameter;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.FunctionName;

import java.sql.Types;
import java.util.Optional;

import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName;
import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type;
import static io.trino.plugin.postgresql.rule.RewriteVectorDistanceFunction.isArrayTypeWithRealOrDouble;
import static io.trino.spi.type.DoubleType.DOUBLE;

public final class RewriteDotProductFunction
implements ProjectFunctionRule<JdbcExpression, ParameterizedExpression>
{
private static final Capture<ConnectorExpression> CALL = newCapture();

private static final Pattern<Call> PATTERN = call()
.with(functionName().equalTo(new FunctionName("$negate")))
.with(type().matching(type -> type == DOUBLE))
.with(argumentCount().equalTo(1))
.with(argument(0).matching(expression().capturedAs(CALL).matching(expression -> expression instanceof Call call
&& call.getFunctionName().equals(new FunctionName("dot_product"))
&& call.getArguments().size() == 2
&& call.getArguments().stream().allMatch(argument -> isArrayTypeWithRealOrDouble(argument.getType())))));

@Override
public Pattern<? extends ConnectorExpression> getPattern()
{
return PATTERN;
}

@Override
public Optional<JdbcExpression> rewrite(ConnectorExpression projectionExpression, Captures captures, RewriteContext<ParameterizedExpression> context)
{
ConnectorExpression call = captures.get(CALL);

Optional<ParameterizedExpression> leftExpression = RewriteVectorDistanceFunction.rewrite(call.getChildren().getFirst(), context);
if (leftExpression.isEmpty()) {
return Optional.empty();
}

Optional<ParameterizedExpression> rightExpression = RewriteVectorDistanceFunction.rewrite(call.getChildren().get(1), context);
if (rightExpression.isEmpty()) {
return Optional.empty();
}

return Optional.of(new JdbcExpression(
"%s <#> %s".formatted(leftExpression.get().expression(), rightExpression.get().expression()),
ImmutableList.<QueryParameter>builder()
.addAll(leftExpression.get().parameters())
.addAll(rightExpression.get().parameters())
.build(),
new JdbcTypeHandle(
Types.DOUBLE,
Optional.of("double"),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty())));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public Optional<JdbcExpression> rewrite(ConnectorExpression projectionExpression
Optional.empty())));
}

private static Optional<ParameterizedExpression> rewrite(ConnectorExpression expression, RewriteContext<ParameterizedExpression> context)
public static Optional<ParameterizedExpression> rewrite(ConnectorExpression expression, RewriteContext<ParameterizedExpression> context)
{
if (expression instanceof Constant constant) {
Type elementType = ((ArrayType) constant.getType()).getElementType();
Expand Down Expand Up @@ -140,7 +140,7 @@ private static Optional<ParameterizedExpression> rewrite(ConnectorExpression exp
return Optional.of(translatedArgument.orElseThrow());
}

private static boolean isArrayTypeWithRealOrDouble(Type type)
public static boolean isArrayTypeWithRealOrDouble(Type type)
{
return type instanceof ArrayType arrayType && (arrayType.getElementType() == REAL || arrayType.getElementType() == DOUBLE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,13 @@ void testDotProductCompatibility()
TestView view = new TestView(postgreSqlServer::execute, "test_dot_product", "SELECT v <#> '[7,8,9]' FROM " + table.getName())) {
postgreSqlServer.execute("INSERT INTO " + table.getName() + " VALUES (1, '[1,2,3]'), (2, '[4,5,6]')");

// TODO Add support for projection pushdown with dot_product function
// The minus sign is needed because <#> returns the negative inner product. Postgres only supports ASC order index scans on operators.
assertThat(query("SELECT -dot_product(v, ARRAY[7,8,9]) FROM " + table.getName()))
.matches("SELECT * FROM tpch." + view.getName())
.isNotFullyPushedDown(ProjectNode.class);
.isFullyPushedDown();

assertThat(query("SELECT id FROM " + table.getName() + " ORDER BY -dot_product(v, ARRAY[7,8,9]) LIMIT 1"))
.isNotFullyPushedDown(ProjectNode.class);
.isFullyPushedDown();
}
}

Expand Down

0 comments on commit d695773

Please sign in to comment.