Skip to content

Commit

Permalink
fix: 修复无法在having中使用预编译参数
Browse files Browse the repository at this point in the history
  • Loading branch information
zhou-hao committed May 5, 2023
1 parent 8f0326f commit 4717dca
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package org.hswebframework.web.crud.query;

import lombok.Getter;
import lombok.SneakyThrows;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.values.ValuesStatement;
Expand All @@ -18,6 +18,7 @@
import org.hswebframework.ezorm.rdb.operator.builder.fragments.PrepareSqlFragments;
import org.hswebframework.ezorm.rdb.operator.builder.fragments.SqlFragments;
import org.hswebframework.web.api.crud.entity.QueryParamEntity;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.util.*;
Expand Down Expand Up @@ -435,10 +436,11 @@ class SimpleQueryRefactor implements QueryRefactor, SelectVisitor {
private String columns;

private String where;

private int prefixParameters;
private String orderBy;

private String suffix;
private int suffixParameters;

private boolean fastCount = true;

Expand All @@ -451,7 +453,6 @@ private void initColumns(StringBuilder columns) {
int idx = 0;
Dialect dialect = database.getMetadata().getDialect();


for (Column column : select.columnList) {
if (idx++ > 0) {
columns.append(",");
Expand Down Expand Up @@ -487,49 +488,53 @@ public void visit(PlainSelect plainSelect) {

initColumns(columns);

// prefix.append(getStringList(plainSelect.getSelectItems()));

if (null != plainSelect.getFromItem()) {
if (plainSelect.getFromItem() != null) {
from.append("FROM ");

from.append(plainSelect.getFromItem());
}

if (plainSelect.getJoins() != null) {
PrepareStatementVisitor visitor = new PrepareStatementVisitor();
for (net.sf.jsqlparser.statement.select.Join join : plainSelect.getJoins()) {
if (join.isSimple()) {
from.append(", ").append(join);
} else {
from.append(" ").append(join);
}
if (null != join.getOnExpressions()) {
for (Expression onExpression : join.getOnExpressions()) {
onExpression.accept(visitor);
}
}
}
prefixParameters += visitor.parameterSize;
}

if (null != plainSelect.getWhere()) {
if (plainSelect.getWhere() != null) {
PrepareStatementVisitor visitor = new PrepareStatementVisitor();
plainSelect.getWhere().accept(visitor);
prefixParameters += visitor.parameterSize;
where = plainSelect.getWhere().toString();
}

if (plainSelect.getOrderByElements() != null) {
orderBy = orderByToString(plainSelect.isOracleSiblings(), plainSelect.getOrderByElements());
}

if (null != plainSelect.getGroupBy()) {
if (plainSelect.getGroupBy() != null) {
fastCount = false;
suffix.append(' ').append(plainSelect.getGroupBy());
}

suffix.append(' ');
if (null != plainSelect.getHaving()) {

if (plainSelect.getHaving() != null) {
PrepareStatementVisitor visitor = new PrepareStatementVisitor();
plainSelect.getHaving().accept(visitor);
suffixParameters = visitor.parameterSize;
suffix.append(" HAVING ").append(plainSelect.getHaving());
}

// if (plainSelect.getLimit() != null) {
// suffix.append(plainSelect.getLimit());
// }
// if (plainSelect.getOffset() != null) {
// suffix.append(plainSelect.getOffset());
// }

this.columns = columns.toString();
this.from = from.toString();
this.suffix = suffix.toString();
Expand Down Expand Up @@ -564,26 +569,48 @@ public void visit(ValuesStatement aThis) {

}

public Object[] getPrefixParameters(Object... args) {
if (prefixParameters == 0) {
return new Object[0];
}
Assert.isTrue(args.length >= prefixParameters,
"Illegal prepare statement parameter size, expect: " + prefixParameters + ", actual: " + args.length);

return Arrays.copyOfRange(args, 0, prefixParameters);
}

public Object[] getSuffixParameters(Object... args) {
if (suffixParameters == 0) {
return new Object[0];
}
Assert.isTrue(args.length >= suffixParameters + prefixParameters,
"Illegal prepare statement parameter size, expect: " + suffixParameters + prefixParameters + ", actual: " + args.length);

return Arrays.copyOfRange(args, prefixParameters, suffixParameters + prefixParameters);
}

@Override
public SqlRequest refactor(QueryParamEntity param, Object... args) {
PrepareSqlFragments sql = PrepareSqlFragments
.of("SELECT", args)
.of("SELECT")
.addSql(columns)
.addSql(from);
.addSql(from)
.addParameter(getPrefixParameters(args));

appendWhere(sql, param);

appendOrderBy(sql, param);

sql.addSql(suffix);
sql.addSql(suffix)
.addParameter(getSuffixParameters(args));

return sql.toRequest();
}


@Override
public SqlRequest refactorCount(QueryParamEntity param, Object... args) {
PrepareSqlFragments sql = PrepareSqlFragments.of("SELECT", args);
PrepareSqlFragments sql = PrepareSqlFragments
.of("SELECT", getPrefixParameters(args));

if (fastCount) {
sql.addSql("count(1) as _total");
Expand All @@ -603,8 +630,9 @@ public SqlRequest refactorCount(QueryParamEntity param, Object... args) {
.addSql(")");
}


return sql.toRequest();
return sql
.addParameter(getSuffixParameters(args))
.toRequest();
}


Expand Down Expand Up @@ -639,6 +667,18 @@ private void appendWhere(PrepareSqlFragments sql, QueryParamEntity param) {

}


@Getter
static class PrepareStatementVisitor extends ExpressionVisitorAdapter {
private int parameterSize;

@Override
public void visit(JdbcParameter parameter) {
parameterSize++;
super.visit(parameter);
}
}

private interface QueryRefactor {

SqlRequest refactor(QueryParamEntity param, Object... args);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@ class DefaultQueryHelperTest {
private DatabaseOperator database;


@Test
public void testGroup() {
DefaultQueryHelper helper = new DefaultQueryHelper(database);

helper.select("select name,count(1) _total from s_test group by name having count(1) > ? ", 0)
.where(dsl -> dsl
.is("age", "31"))
.fetchPaged()
.doOnNext(v -> System.out.println(JSON.toJSONString(v, SerializerFeature.PrettyFormat)))
.as(StepVerifier::create)
.expectNextCount(1)
.verifyComplete();
}

@Test
public void testNative() {
database.dml()
Expand All @@ -49,8 +63,8 @@ public void testNative() {
DefaultQueryHelper helper = new DefaultQueryHelper(database);

helper.select("select e.*,t.id as \"id\" from s_test t " +
"left join s_test_event e on e.id = t.id" +
" where t.age = ? order by t.age desc", 20)
"left join s_test_event e on e.id = t.id " +
"where t.age = ? order by t.age desc", 20)
.where(dsl -> dsl
.is("e.id", "helper_testNative")
.is("t.age", "20"))
Expand Down

0 comments on commit 4717dca

Please sign in to comment.