Skip to content

Commit

Permalink
fix select subquery wrong result (apache#12658)
Browse files Browse the repository at this point in the history
* fix select subquery wrong result

* fix unit test

* fix pagination route logic

* fix unit test

* fix unit test

* add federate route test

* optimize route logic

* add todo

* add todo
  • Loading branch information
strongduanmu authored Sep 24, 2021
1 parent a199e7a commit ffd8576
Show file tree
Hide file tree
Showing 12 changed files with 195 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.sharding.route.engine.condition;

import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import org.apache.shardingsphere.sharding.route.engine.condition.value.ShardingConditionValue;

Expand All @@ -32,4 +33,7 @@
public class ShardingCondition {

private final List<ShardingConditionValue> values = new LinkedList<>();

@Setter
private int startIndex;
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.shardingsphere.sharding.route.engine.condition;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
Expand All @@ -31,16 +30,21 @@
import org.apache.shardingsphere.sharding.rule.BindingTableRule;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sharding.rule.TableRule;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.sql.common.util.SafeNumberOperationUtil;

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

/**
* Sharding conditions.
*/
@RequiredArgsConstructor
@Getter
@ToString
public final class ShardingConditions {
Expand All @@ -51,6 +55,15 @@ public final class ShardingConditions {

private final ShardingRule rule;

private final boolean subqueryContainsShardingCondition;

public ShardingConditions(final List<ShardingCondition> conditions, final SQLStatementContext<?> sqlStatementContext, final ShardingRule rule) {
this.conditions = conditions;
this.sqlStatementContext = sqlStatementContext;
this.rule = rule;
subqueryContainsShardingCondition = isSubqueryContainsShardingCondition(conditions, sqlStatementContext);
}

/**
* Judge sharding conditions is always false or not.
*
Expand Down Expand Up @@ -98,11 +111,38 @@ private Optional<ShardingCondition> findUniqueShardingCondition(final List<Shard
*/
public boolean isNeedMerge() {
boolean selectContainsSubquery = sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).isContainsSubquery();
boolean insertSelectContainsSubquery = sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()
boolean insertSelectContainsSubquery = sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()
&& ((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext().isContainsSubquery();
return (selectContainsSubquery || insertSelectContainsSubquery) && !rule.getShardingLogicTableNames(sqlStatementContext.getTablesContext().getTableNames()).isEmpty();
}

private boolean isSubqueryContainsShardingCondition(final List<ShardingCondition> conditions, final SQLStatementContext<?> sqlStatementContext) {
Collection<SelectStatement> selectStatements = getSelectStatements(sqlStatementContext);
if (selectStatements.size() > 1) {
Map<Integer, List<ShardingCondition>> startIndexShardingConditions = conditions.stream().collect(Collectors.groupingBy(ShardingCondition::getStartIndex));
for (SelectStatement each : selectStatements) {
if (!each.getWhere().isPresent() || !startIndexShardingConditions.containsKey(each.getWhere().get().getExpr().getStartIndex())) {
return false;
}
}
}
return true;
}

private Collection<SelectStatement> getSelectStatements(final SQLStatementContext<?> sqlStatementContext) {
Collection<SelectStatement> result = new LinkedList<>();
if (sqlStatementContext instanceof SelectStatementContext) {
result.add(((SelectStatementContext) sqlStatementContext).getSqlStatement());
result.addAll(((SelectStatementContext) sqlStatementContext).getSubquerySegments().stream().map(SubquerySegment::getSelect).collect(Collectors.toList()));
}
if (sqlStatementContext instanceof InsertStatementContext && null != ((InsertStatementContext) sqlStatementContext).getInsertSelectContext()) {
SelectStatementContext selectStatementContext = ((InsertStatementContext) sqlStatementContext).getInsertSelectContext().getSelectStatementContext();
result.add(selectStatementContext.getSqlStatement());
result.addAll(selectStatementContext.getSubquerySegments().stream().map(SubquerySegment::getSelect).collect(Collectors.toList()));
}
return result;
}

/**
* Judge whether all sharding conditions are same or not.
*
Expand All @@ -116,7 +156,7 @@ public boolean isSameShardingCondition() {
return false;
}
}
return conditions.size() <= 1;
return subqueryContainsShardingCondition && conditions.size() == 1;
}

private boolean isSameShardingCondition(final ShardingRule shardingRule, final ShardingCondition shardingCondition1, final ShardingCondition shardingCondition2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,19 @@ public List<ShardingCondition> createShardingConditions(final SQLStatementContex
return result;
}

private Collection<ShardingCondition> createShardingConditions(final SQLStatementContext<?> sqlStatementContext, final ExpressionSegment expressionSegment, final List<Object> parameters) {
Collection<AndPredicate> andPredicates = ExpressionExtractUtil.getAndPredicates(expressionSegment);
private Collection<ShardingCondition> createShardingConditions(final SQLStatementContext<?> sqlStatementContext, final ExpressionSegment expression, final List<Object> parameters) {
Collection<AndPredicate> andPredicates = ExpressionExtractUtil.getAndPredicates(expression);
Map<String, String> columnTableNames = getColumnTableNames(sqlStatementContext, andPredicates);
Collection<ShardingCondition> result = new LinkedList<>();
for (AndPredicate each : andPredicates) {
Map<Column, Collection<ShardingConditionValue>> shardingConditionValues = createShardingConditionValueMap(each.getPredicates(), parameters, columnTableNames);
if (shardingConditionValues.isEmpty()) {
return Collections.emptyList();
}
result.add(createShardingCondition(shardingConditionValues));
ShardingCondition shardingCondition = createShardingCondition(shardingConditionValues);
// TODO remove startIndex when federation has perfect support for subquery
shardingCondition.setStartIndex(expression.getStartIndex());
result.add(shardingCondition);
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,18 +179,29 @@ private static ShardingRouteEngine getDQLRouteEngineForShardingTable(final Shard
final ShardingConditions shardingConditions, final Collection<String> tableNames, final ConfigurationProperties props) {
if (isShardingStandardQuery(tableNames, shardingRule)) {
ShardingStandardRoutingEngine result = new ShardingStandardRoutingEngine(getLogicTableName(shardingConditions, tableNames), shardingConditions, props);
boolean needExecuteByCalcite = sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).isNeedExecuteByCalcite();
boolean needExecuteByCalcite = isNeedExecuteByCalcite(sqlStatementContext, shardingConditions);
if (!needExecuteByCalcite || result.route(shardingRule).isSingleRouting()) {
return result;
}
}
if (isShardingFederatedQuery(sqlStatementContext, tableNames, shardingRule)) {
return new ShardingFederatedRoutingEngine(tableNames, shardingConditions, props);
if (isShardingFederatedQuery(sqlStatementContext, tableNames, shardingRule, shardingConditions)) {
return new ShardingFederatedRoutingEngine(tableNames);
}
// TODO config for cartesian set
return new ShardingComplexRoutingEngine(tableNames, shardingConditions, props);
}

private static boolean isNeedExecuteByCalcite(final SQLStatementContext<?> sqlStatementContext, final ShardingConditions shardingConditions) {
if (!(sqlStatementContext instanceof SelectStatementContext)) {
return false;
}
SelectStatementContext selectStatementContext = (SelectStatementContext) sqlStatementContext;
if (selectStatementContext.getPaginationContext().isHasPagination()) {
return false;
}
return selectStatementContext.isNeedExecuteByCalcite() || (shardingConditions.isNeedMerge() && !shardingConditions.isSameShardingCondition());
}

private static String getLogicTableName(final ShardingConditions shardingConditions, final Collection<String> tableNames) {
return shardingConditions.getConditions().stream().flatMap(each -> each.getValues().stream())
.map(ShardingConditionValue::getTableName).findFirst().orElseGet(() -> tableNames.iterator().next());
Expand All @@ -200,12 +211,13 @@ private static boolean isShardingStandardQuery(final Collection<String> tableNam
return 1 == tableNames.size() && shardingRule.isAllShardingTables(tableNames) || shardingRule.isAllBindingTables(tableNames);
}

private static boolean isShardingFederatedQuery(final SQLStatementContext<?> sqlStatementContext, final Collection<String> tableNames, final ShardingRule shardingRule) {
private static boolean isShardingFederatedQuery(final SQLStatementContext<?> sqlStatementContext, final Collection<String> tableNames,
final ShardingRule shardingRule, final ShardingConditions shardingConditions) {
if (!(sqlStatementContext instanceof SelectStatementContext)) {
return false;
}
SelectStatementContext select = (SelectStatementContext) sqlStatementContext;
if (select.isNeedExecuteByCalcite()) {
if (isNeedExecuteByCalcite(sqlStatementContext, shardingConditions)) {
return true;
}
if ((!select.isContainsJoinQuery() && !select.isContainsSubquery()) || shardingRule.isAllTablesInSameDataSource(tableNames)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@
package org.apache.shardingsphere.sharding.route.engine.type.federated;

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.config.properties.ConfigurationProperties;
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.sharding.route.engine.condition.ShardingConditions;
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.sharding.route.engine.type.ShardingRouteEngine;
import org.apache.shardingsphere.sharding.route.engine.type.standard.ShardingStandardRoutingEngine;
import org.apache.shardingsphere.sharding.route.engine.type.unicast.ShardingUnicastRoutingEngine;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sharding.rule.TableRule;

import java.util.Collection;
import java.util.Collections;
Expand All @@ -38,29 +36,22 @@ public final class ShardingFederatedRoutingEngine implements ShardingRouteEngine

private final Collection<String> logicTables;

private final ShardingConditions shardingConditions;

private final ConfigurationProperties properties;

@Override
public RouteContext route(final ShardingRule shardingRule) {
RouteContext result = new RouteContext();
for (String each : logicTables) {
RouteContext newRouteContext;
if (shardingRule.isShardingTable(each)) {
newRouteContext = new ShardingStandardRoutingEngine(each, shardingConditions, properties).route(shardingRule);
} else {
newRouteContext = new ShardingUnicastRoutingEngine(Collections.singletonList(each)).route(shardingRule);
}
fillRouteContext(result, newRouteContext);
fillRouteContext(result, shardingRule, each);
}
result.setFederated(true);
return result;
}

private void fillRouteContext(final RouteContext routeContext, final RouteContext newRouteContext) {
for (RouteUnit each : newRouteContext.getRouteUnits()) {
routeContext.putRouteUnit(each.getDataSourceMapper(), each.getTableMappers());
private void fillRouteContext(final RouteContext routeContext, final ShardingRule shardingRule, final String logicTableName) {
TableRule tableRule = shardingRule.getTableRule(logicTableName);
for (DataNode each : tableRule.getActualDataNodes()) {
RouteMapper dataSource = new RouteMapper(each.getDataSourceName(), each.getDataSourceName());
RouteMapper table = new RouteMapper(logicTableName, each.getTableName());
routeContext.putRouteUnit(dataSource, Collections.singletonList(table));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ private static Optional<ShardingStatementValidator> getDMLStatementValidator(fin
return Optional.of(new ShardingDeleteStatementValidator());
}
if (sqlStatement instanceof SelectStatement) {
return Optional.of(new ShardingSelectStatementValidator(shardingConditions));
return Optional.of(new ShardingSelectStatementValidator());
}
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public void postValidate(final ShardingRule shardingRule, final SQLStatementCont
Optional<SubquerySegment> insertSelect = sqlStatementContext.getSqlStatement().getInsertSelect();
if (insertSelect.isPresent() && shardingConditions.isNeedMerge()) {
boolean singleRoutingOrSameShardingCondition = routeContext.isSingleRouting() || shardingConditions.isSameShardingCondition();
Preconditions.checkState(singleRoutingOrSameShardingCondition, "Sharding conditions must be same with others.");
Preconditions.checkState(singleRoutingOrSameShardingCondition, "Subquery sharding conditions must be same with primary query.");
}
String tableName = sqlStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue();
if (!routeContext.isSingleRouting() && !shardingRule.isBroadcastTable(tableName)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@

package org.apache.shardingsphere.sharding.route.engine.validator.dml.impl;

import com.google.common.base.Preconditions;
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.exception.ShardingSphereException;
import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.sharding.route.engine.condition.ShardingConditions;
import org.apache.shardingsphere.sharding.route.engine.validator.dml.ShardingDMLStatementValidator;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
Expand All @@ -36,8 +34,6 @@
@RequiredArgsConstructor
public final class ShardingSelectStatementValidator extends ShardingDMLStatementValidator<SelectStatement> {

private final ShardingConditions shardingConditions;

@Override
public void preValidate(final ShardingRule shardingRule, final SQLStatementContext<SelectStatement> sqlStatementContext,
final List<Object> parameters, final ShardingSphereSchema schema) {
Expand All @@ -49,9 +45,5 @@ public void preValidate(final ShardingRule shardingRule, final SQLStatementConte
@Override
public void postValidate(final ShardingRule shardingRule, final SQLStatementContext<SelectStatement> sqlStatementContext,
final RouteContext routeContext, final ShardingSphereSchema schema) {
if (!routeContext.isFederated() && shardingConditions.isNeedMerge()) {
boolean singleRoutingOrSameShardingCondition = routeContext.isSingleRouting() || shardingConditions.isSameShardingCondition();
Preconditions.checkState(singleRoutingOrSameShardingCondition, "Sharding conditions must be same with others.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.shardingsphere.infra.binder.segment.table.TablesContext;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dcl.GrantStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.config.properties.ConfigurationProperties;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
Expand All @@ -29,6 +30,7 @@
import org.apache.shardingsphere.sharding.route.engine.type.broadcast.ShardingInstanceBroadcastRoutingEngine;
import org.apache.shardingsphere.sharding.route.engine.type.broadcast.ShardingTableBroadcastRoutingEngine;
import org.apache.shardingsphere.sharding.route.engine.type.complex.ShardingComplexRoutingEngine;
import org.apache.shardingsphere.sharding.route.engine.type.federated.ShardingFederatedRoutingEngine;
import org.apache.shardingsphere.sharding.route.engine.type.ignore.ShardingIgnoreRoutingEngine;
import org.apache.shardingsphere.sharding.route.engine.type.standard.ShardingStandardRoutingEngine;
import org.apache.shardingsphere.sharding.route.engine.type.unicast.ShardingUnicastRoutingEngine;
Expand Down Expand Up @@ -62,11 +64,13 @@
import org.mockito.junit.MockitoJUnitRunner;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Properties;

import static org.hamcrest.CoreMatchers.instanceOf;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -284,4 +288,20 @@ public void assertNewInstanceForShowColumnsWithTableRule() {
ShardingRouteEngine actual = ShardingRouteEngineFactory.newInstance(shardingRule, metaData, sqlStatementContext, shardingConditions, props);
assertThat(actual, instanceOf(ShardingUnicastRoutingEngine.class));
}

@Test
public void assertNewInstanceForSubqueryWithDifferentConditions() {
SelectStatementContext sqlStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
tableNames.add("t_order");
when(sqlStatementContext.getTablesContext().getTableNames()).thenReturn(tableNames);
when(sqlStatementContext.isNeedExecuteByCalcite()).thenReturn(false);
ShardingRule shardingRule = mock(ShardingRule.class, RETURNS_DEEP_STUBS);
when(shardingRule.getShardingRuleTableNames(tableNames)).thenReturn(tableNames);
when(shardingRule.isAllShardingTables(tableNames)).thenReturn(true);
when(shardingRule.getTableRule("t_order").getActualDatasourceNames()).thenReturn(Arrays.asList("ds_0", "ds_1"));
when(shardingConditions.isNeedMerge()).thenReturn(true);
when(shardingConditions.isSameShardingCondition()).thenReturn(false);
ShardingRouteEngine actual = ShardingRouteEngineFactory.newInstance(shardingRule, metaData, sqlStatementContext, shardingConditions, props);
assertThat(actual, instanceOf(ShardingFederatedRoutingEngine.class));
}
}
Loading

0 comments on commit ffd8576

Please sign in to comment.