Skip to content

Commit

Permalink
revise InsertStatement
Browse files Browse the repository at this point in the history
  • Loading branch information
tristaZero committed Jun 4, 2019
1 parent dacac50 commit fff7f86
Showing 1 changed file with 43 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,12 @@
import org.apache.shardingsphere.core.parse.sql.segment.dml.predicate.PredicateSegment;
import org.apache.shardingsphere.core.parse.sql.statement.SQLStatement;
import org.apache.shardingsphere.core.parse.sql.statement.dml.InsertStatement;
import org.apache.shardingsphere.core.parse.sql.token.impl.InsertColumnsToken;
import org.apache.shardingsphere.core.parse.sql.token.impl.InsertValuesToken;
import org.apache.shardingsphere.core.rule.ShardingRule;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;

/**
* Insert values filler for sharding.
Expand All @@ -65,8 +63,7 @@ public void fill(final InsertValuesSegment sqlSegment, final SQLStatement sqlSta
InsertValue insertValue = new InsertValue(sqlSegment.getValues());
insertStatement.getValues().add(insertValue);
insertStatement.setParametersIndex(insertStatement.getParametersIndex() + insertValue.getParametersCount());
fillWithInsertValuesToken(sqlSegment, insertStatement);
reviseInsertColumnNames(sqlSegment, insertStatement);
reviseInsertStatement(insertStatement, sqlSegment);
}

private Iterator<String> getColumnNames(final InsertValuesSegment sqlSegment, final InsertStatement insertStatement) {
Expand All @@ -86,7 +83,48 @@ private void fillShardingCondition(final AndCondition andCondition, final String
}
}

private void fillWithInsertValuesToken(final InsertValuesSegment sqlSegment, final InsertStatement insertStatement) {
private void reviseInsertStatement(final InsertStatement insertStatement, final InsertValuesSegment sqlSegment) {
reviseInsertColumnNames(insertStatement, sqlSegment);
setNeededToAppendGeneratedKey(insertStatement);
setNeededToAppendAssistedColumns(insertStatement);
fillWithInsertValuesToken(insertStatement, sqlSegment);
}

private void reviseInsertColumnNames(final InsertStatement insertStatement, final InsertValuesSegment sqlSegment) {
Collection<String> insertColumns = new ArrayList<>(insertStatement.getColumnNames());
insertColumns.removeAll(getAssistedQueryColumns(insertStatement));
Optional<String> generateKeyColumnName = shardingRule.findGenerateKeyColumnName(insertStatement.getTables().getSingleTableName());
if (insertStatement.getColumnNames().size() != sqlSegment.getValues().size() && generateKeyColumnName.isPresent()) {
insertColumns.remove(generateKeyColumnName.get());
}
insertStatement.getColumnNames().clear();
insertStatement.getColumnNames().addAll(insertColumns);
}

private void setNeededToAppendGeneratedKey(final InsertStatement insertStatement) {
Optional<String> generateKeyColumnName = shardingRule.findGenerateKeyColumnName(insertStatement.getTables().getSingleTableName());
if (generateKeyColumnName.isPresent() && !insertStatement.getColumnNames().contains(generateKeyColumnName.get())) {
insertStatement.setNeededToAppendGeneratedKey(true);
}
}

private void setNeededToAppendAssistedColumns(final InsertStatement insertStatement) {
Collection<String> assistedQueryColumns = getAssistedQueryColumns(insertStatement);
if (!assistedQueryColumns.isEmpty()) {
insertStatement.setNeededToAppendAssistedColumns(true);
}
}

private Collection<String> getAssistedQueryColumns(final InsertStatement insertStatement) {
Collection<String> result = new ArrayList<>();
Collection<String> assistedQueryColumns = shardingRule.getEncryptRule().getEncryptorEngine().getAssistedQueryColumns(insertStatement.getTables().getSingleTableName());
if (!assistedQueryColumns.isEmpty()) {
result.addAll(assistedQueryColumns);
}
return result;
}

private void fillWithInsertValuesToken(final InsertStatement insertStatement, final InsertValuesSegment sqlSegment) {
Optional<InsertValuesToken> insertValuesToken = insertStatement.findSQLToken(InsertValuesToken.class);
if (insertValuesToken.isPresent()) {
int startIndex = insertValuesToken.get().getStartIndex() < sqlSegment.getStartIndex() ? insertValuesToken.get().getStartIndex() : sqlSegment.getStartIndex();
Expand All @@ -97,27 +135,4 @@ private void fillWithInsertValuesToken(final InsertValuesSegment sqlSegment, fin
insertStatement.getSQLTokens().add(new InsertValuesToken(sqlSegment.getStartIndex(), sqlSegment.getStopIndex()));
}
}

private void reviseInsertColumnNames(final InsertValuesSegment sqlSegment, final InsertStatement insertStatement) {
Collection<String> result = new ArrayList<>(insertStatement.getColumnNames());
result.removeAll(shardingRule.getEncryptRule().getEncryptorEngine().getAssistedQueryColumns(insertStatement.getTables().getSingleTableName()));
Optional<String> generateKeyColumnName = shardingRule.findGenerateKeyColumnName(insertStatement.getTables().getSingleTableName());
if (insertStatement.getColumnNames().size() != sqlSegment.getValues().size() && generateKeyColumnName.isPresent()) {
result.remove(generateKeyColumnName.get());
reviseInsertColumnsToken(insertStatement, generateKeyColumnName.get(), result);
}
insertStatement.getColumnNames().clear();
insertStatement.getColumnNames().addAll(result);
}

private void reviseInsertColumnsToken(final InsertStatement insertStatement, final String generateKeyColumnName, final Collection<String> columnNames) {
Optional<InsertColumnsToken> insertColumnsToken = insertStatement.findSQLToken(InsertColumnsToken.class);
Collection<String> assistedColumns = new LinkedList<>(insertColumnsToken.get().getColumns());
assistedColumns.removeAll(columnNames);
assistedColumns.remove(generateKeyColumnName);
insertColumnsToken.get().getColumns().clear();
insertColumnsToken.get().getColumns().addAll(columnNames);
insertColumnsToken.get().getColumns().add(generateKeyColumnName);
insertColumnsToken.get().getColumns().addAll(assistedColumns);
}
}

0 comments on commit fff7f86

Please sign in to comment.