Skip to content

Commit

Permalink
[CALCITE-3673] ListTransientTable should not leave tables in the schema
Browse files Browse the repository at this point in the history
[CALCITE-4054] RepeatUnion containing a Correlate with a transientScan on its RHS causes NPE
  • Loading branch information
rubenada committed Feb 18, 2022
1 parent e81cd20 commit 9c4f3bb
Show file tree
Hide file tree
Showing 13 changed files with 198 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,23 @@
package org.apache.calcite.adapter.enumerable;

import org.apache.calcite.linq4j.function.Experimental;
import org.apache.calcite.linq4j.function.Function0;
import org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.RepeatUnion;
import org.apache.calcite.schema.TransientTable;
import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.Util;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.List;
import java.util.Objects;

/**
* Implementation of {@link RepeatUnion} in
Expand All @@ -43,14 +49,15 @@ public class EnumerableRepeatUnion extends RepeatUnion implements EnumerableRel
* Creates an EnumerableRepeatUnion.
*/
EnumerableRepeatUnion(RelOptCluster cluster, RelTraitSet traitSet,
RelNode seed, RelNode iterative, boolean all, int iterationLimit) {
super(cluster, traitSet, seed, iterative, all, iterationLimit);
RelNode seed, RelNode iterative, boolean all, int iterationLimit,
@Nullable RelOptTable transientTable) {
super(cluster, traitSet, seed, iterative, all, iterationLimit, transientTable);
}

@Override public EnumerableRepeatUnion copy(RelTraitSet traitSet, List<RelNode> inputs) {
assert inputs.size() == 2;
return new EnumerableRepeatUnion(getCluster(), traitSet,
inputs.get(0), inputs.get(1), all, iterationLimit);
inputs.get(0), inputs.get(1), all, iterationLimit, transientTable);
}

@Override public Result implement(EnumerableRelImplementor implementor, Prefer pref) {
Expand All @@ -61,6 +68,32 @@ public class EnumerableRepeatUnion extends RepeatUnion implements EnumerableRel
RelNode seed = getSeedRel();
RelNode iteration = getIterativeRel();

Expression cleanUpFunctionExp = Expressions.constant(null);
if (transientTable != null) {
// root.getRootSchema().add(tableName, table);
Expression tableExp = implementor.stash(
Objects.requireNonNull(transientTable.unwrap(TransientTable.class)),
TransientTable.class);
String tableName =
transientTable.getQualifiedName().get(transientTable.getQualifiedName().size() - 1);
Expression tableNameExp = Expressions.constant(tableName, String.class);
builder.append(
Expressions.call(
Expressions.call(
implementor.getRootExpression(),
BuiltInMethod.DATA_CONTEXT_GET_ROOT_SCHEMA.method),
BuiltInMethod.SCHEMA_PLUS_ADD_TABLE.method,
tableNameExp,
tableExp));
// root.getRootSchema().removeTable(tableName);
cleanUpFunctionExp = Expressions.lambda(Function0.class,
Expressions.call(
Expressions.call(
implementor.getRootExpression(),
BuiltInMethod.DATA_CONTEXT_GET_ROOT_SCHEMA.method),
BuiltInMethod.SCHEMA_PLUS_REMOVE_TABLE.method, tableNameExp));
}

Result seedResult = implementor.visitChild(this, 0, (EnumerableRel) seed, pref);
Result iterationResult = implementor.visitChild(this, 1, (EnumerableRel) iteration, pref);

Expand All @@ -78,7 +111,8 @@ public class EnumerableRepeatUnion extends RepeatUnion implements EnumerableRel
iterativeExp,
Expressions.constant(iterationLimit, int.class),
Expressions.constant(all, boolean.class),
Util.first(physType.comparer(), Expressions.call(BuiltInMethod.IDENTITY_COMPARER.method)));
Util.first(physType.comparer(), Expressions.call(BuiltInMethod.IDENTITY_COMPARER.method)),
cleanUpFunctionExp);
builder.add(unionExp);

return implementor.result(physType, builder.toBlock());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ protected EnumerableRepeatUnionRule(Config config) {
convert(seedRel, seedRel.getTraitSet().replace(out)),
convert(iterativeRel, iterativeRel.getTraitSet().replace(out)),
union.all,
union.iterationLimit);
union.iterationLimit,
union.getTransientTable());
}
}
4 changes: 4 additions & 0 deletions core/src/main/java/org/apache/calcite/jdbc/CalciteSchema.java
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,10 @@ CalciteSchema calciteSchema() {
CalciteSchema.this.add(name, table);
}

@Override public boolean removeTable(String name) {
return CalciteSchema.this.removeTable(name);
}

@Override public void add(String name, Function function) {
CalciteSchema.this.add(name, function);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,10 @@ public static MySchemaPlus create(Path path) {
throw new UnsupportedOperationException();
}

@Override public boolean removeTable(String name) {
throw new UnsupportedOperationException();
}

@Override public void add(String name,
org.apache.calcite.schema.Function function) {
throw new UnsupportedOperationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ private static class SpoolFactoryImpl implements SpoolFactory {
public interface RepeatUnionFactory {
/** Creates a {@link RepeatUnion}. */
RelNode createRepeatUnion(RelNode seed, RelNode iterative, boolean all,
int iterationLimit);
int iterationLimit, RelOptTable table);
}

/**
Expand All @@ -619,8 +619,8 @@ RelNode createRepeatUnion(RelNode seed, RelNode iterative, boolean all,
*/
private static class RepeatUnionFactoryImpl implements RepeatUnionFactory {
@Override public RelNode createRepeatUnion(RelNode seed, RelNode iterative,
boolean all, int iterationLimit) {
return LogicalRepeatUnion.create(seed, iterative, all, iterationLimit);
boolean all, int iterationLimit, RelOptTable table) {
return LogicalRepeatUnion.create(seed, iterative, all, iterationLimit, table);
}
}

Expand Down
23 changes: 21 additions & 2 deletions core/src/main/java/org/apache/calcite/rel/core/RepeatUnion.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,20 @@

import org.apache.calcite.linq4j.function.Experimental;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.BiRel;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.schema.TransientTable;
import org.apache.calcite.util.Util;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.List;
import java.util.Objects;

/**
* Relational expression that computes a repeat union (recursive union in SQL
Expand All @@ -40,7 +45,7 @@
*
* <li>Evaluate the right input (i.e., iterative relational expression) over and
* over until it produces no more results (or until an optional maximum number
* of iterations is reached). For UNION (but not UNION ALL), discard
* of iterations is reached). For UNION (but not UNION ALL), discard
* duplicated results.
* </ul>
*
Expand All @@ -61,12 +66,22 @@ public abstract class RepeatUnion extends BiRel {
*/
public final int iterationLimit;

/**
* Transient table where repeat union's intermediate results will be stored (optional).
*/
protected final @Nullable RelOptTable transientTable;

//~ Constructors -----------------------------------------------------------
protected RepeatUnion(RelOptCluster cluster, RelTraitSet traitSet,
RelNode seed, RelNode iterative, boolean all, int iterationLimit) {
RelNode seed, RelNode iterative, boolean all, int iterationLimit,
@Nullable RelOptTable transientTable) {
super(cluster, traitSet, seed, iterative);
this.iterationLimit = iterationLimit;
this.all = all;
this.transientTable = transientTable;
if (transientTable != null) {
Objects.requireNonNull(transientTable.unwrap(TransientTable.class));
}
}

@Override public double estimateRowCount(RelMetadataQuery mq) {
Expand Down Expand Up @@ -95,6 +110,10 @@ public RelNode getIterativeRel() {
return right;
}

public @Nullable RelOptTable getTransientTable() {
return transientTable;
}

@Override protected RelDataType deriveRowType() {
final List<RelDataType> inputRowTypes =
Util.transform(getInputs(), RelNode::getRowType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
import org.apache.calcite.linq4j.function.Experimental;
import org.apache.calcite.plan.Convention;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.RepeatUnion;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.List;

/**
Expand All @@ -37,22 +40,24 @@ public class LogicalRepeatUnion extends RepeatUnion {

//~ Constructors -----------------------------------------------------------
private LogicalRepeatUnion(RelOptCluster cluster, RelTraitSet traitSet,
RelNode seed, RelNode iterative, boolean all, int iterationLimit) {
super(cluster, traitSet, seed, iterative, all, iterationLimit);
RelNode seed, RelNode iterative, boolean all, int iterationLimit,
@Nullable RelOptTable transientTable) {
super(cluster, traitSet, seed, iterative, all, iterationLimit, transientTable);
}

/** Creates a LogicalRepeatUnion. */
public static LogicalRepeatUnion create(RelNode seed, RelNode iterative,
boolean all) {
return create(seed, iterative, all, -1);
boolean all, @Nullable RelOptTable transientTable) {
return create(seed, iterative, all, -1, transientTable);
}

/** Creates a LogicalRepeatUnion. */
public static LogicalRepeatUnion create(RelNode seed, RelNode iterative,
boolean all, int iterationLimit) {
boolean all, int iterationLimit, @Nullable RelOptTable transientTable) {
RelOptCluster cluster = seed.getCluster();
RelTraitSet traitSet = cluster.traitSetOf(Convention.NONE);
return new LogicalRepeatUnion(cluster, traitSet, seed, iterative, all, iterationLimit);
return new LogicalRepeatUnion(cluster, traitSet, seed, iterative, all, iterationLimit,
transientTable);
}

//~ Methods ----------------------------------------------------------------
Expand All @@ -62,6 +67,6 @@ public static LogicalRepeatUnion create(RelNode seed, RelNode iterative,
assert traitSet.containsIfApplicable(Convention.NONE);
assert inputs.size() == 2;
return new LogicalRepeatUnion(getCluster(), traitSet,
inputs.get(0), inputs.get(1), all, iterationLimit);
inputs.get(0), inputs.get(1), all, iterationLimit, transientTable);
}
}
7 changes: 7 additions & 0 deletions core/src/main/java/org/apache/calcite/schema/SchemaPlus.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ public interface SchemaPlus extends Schema {
/** Adds a table to this schema. */
void add(String name, Table table);

/** Removes a table from this schema, used e.g. to clean-up temporary tables. */
default boolean removeTable(String name) {
// Default implementation provided for backwards compatibility, to be removed before 2.0
return false;
}


/** Adds a function to this schema. */
void add(String name, Function function);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

import static java.util.Objects.requireNonNull;

/**
* {@link TransientTable} backed by a Java list. It will be automatically added to the
* current schema when {@link #scan(DataContext)} method gets called.
Expand All @@ -61,7 +59,9 @@
public class ListTransientTable extends AbstractQueryableTable
implements TransientTable, ModifiableTable, ScannableTable {
private static final Type TYPE = Object[].class;
@SuppressWarnings("rawtypes")
private final List rows = new ArrayList();
@SuppressWarnings({"unused", "FieldCanBeLocal"})
private final String name;
private final RelDataType protoRowType;

Expand All @@ -84,20 +84,19 @@ public ListTransientTable(String name, RelDataType rowType) {
updateColumnList, sourceExpressionList, flattened);
}

@SuppressWarnings("rawtypes")
@Override public Collection getModifiableCollection() {
return rows;
}

@Override public Enumerable<@Nullable Object[]> scan(DataContext root) {
// add the table into the schema, so that it is accessible by any potential operator
requireNonNull(root.getRootSchema(), "root.getRootSchema()")
.add(name, this);

final AtomicBoolean cancelFlag = DataContext.Variable.CANCEL_FLAG.get(root);

return new AbstractEnumerable<@Nullable Object[]>() {
@Override public Enumerator<@Nullable Object[]> enumerator() {
return new Enumerator<@Nullable Object[]>() {
@SuppressWarnings({"rawtypes", "unchecked"})
private final List list = new ArrayList(rows);
private int i = -1;

Expand Down Expand Up @@ -129,7 +128,7 @@ public ListTransientTable(String name, RelDataType rowType) {
}

@Override public Expression getExpression(SchemaPlus schema, String tableName,
Class clazz) {
@SuppressWarnings("rawtypes") Class clazz) {
return Schemas.tableExpression(schema, elementType, tableName, clazz);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2709,7 +2709,7 @@ public RelBuilder repeatUnion(String tableName, boolean all, int iterationLimit)
RelNode seed = tableSpool(Spool.Type.LAZY, Spool.Type.LAZY, finder.relOptTable).build();
RelNode repeatUnion =
struct.repeatUnionFactory.createRepeatUnion(seed, iterative, all,
iterationLimit);
iterationLimit, finder.relOptTable);
return push(repeatUnion);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
import org.apache.calcite.schema.Schema;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.schema.Schemas;
import org.apache.calcite.schema.Table;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlJsonConstructorNullClause;
import org.apache.calcite.sql.SqlJsonQueryEmptyOrErrorBehavior;
Expand Down Expand Up @@ -156,6 +157,8 @@ public enum BuiltInMethod {
REMOVE_ALL(ExtendedEnumerable.class, "removeAll", Collection.class),
SCHEMA_GET_SUB_SCHEMA(Schema.class, "getSubSchema", String.class),
SCHEMA_GET_TABLE(Schema.class, "getTable", String.class),
SCHEMA_PLUS_ADD_TABLE(SchemaPlus.class, "add", String.class, Table.class),
SCHEMA_PLUS_REMOVE_TABLE(SchemaPlus.class, "removeTable", String.class),
SCHEMA_PLUS_UNWRAP(SchemaPlus.class, "unwrap", Class.class),
SCHEMAS_ENUMERABLE_SCANNABLE(Schemas.class, "enumerable",
ScannableTable.class, DataContext.class),
Expand Down Expand Up @@ -240,7 +243,7 @@ public enum BuiltInMethod {
UNION(ExtendedEnumerable.class, "union", Enumerable.class),
CONCAT(ExtendedEnumerable.class, "concat", Enumerable.class),
REPEAT_UNION(EnumerableDefaults.class, "repeatUnion", Enumerable.class,
Enumerable.class, int.class, boolean.class, EqualityComparer.class),
Enumerable.class, int.class, boolean.class, EqualityComparer.class, Function0.class),
MERGE_UNION(EnumerableDefaults.class, "mergeUnion", List.class, Function1.class,
Comparator.class, boolean.class, EqualityComparer.class),
LAZY_COLLECTION_SPOOL(EnumerableDefaults.class, "lazyCollectionSpool", Collection.class,
Expand Down
Loading

0 comments on commit 9c4f3bb

Please sign in to comment.