Skip to content

Commit

Permalink
Calculate statistics for TABLESAMPLE
Browse files Browse the repository at this point in the history
  • Loading branch information
findepi committed May 27, 2021
1 parent 03f3ad5 commit 96c7b7a
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 0 deletions.
49 changes: 49 additions & 0 deletions core/trino-main/src/main/java/io/trino/cost/SampleStatsRule.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.cost;

import io.trino.Session;
import io.trino.matching.Pattern;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.plan.SampleNode;

import java.util.Optional;

import static io.trino.sql.planner.plan.Patterns.sample;

/**
 * Stats rule for {@link SampleNode}: estimates the output of a sample by scaling the
 * source's estimated output row count by the node's sample ratio.
 */
public class SampleStatsRule
        extends SimpleStatsRule<SampleNode>
{
    // Matches every SampleNode; built once and shared by all instances.
    private static final Pattern<SampleNode> PATTERN = sample();

    /**
     * @param normalizer handed to the {@link SimpleStatsRule} base class
     */
    public SampleStatsRule(StatsNormalizer normalizer)
    {
        super(normalizer);
    }

    @Override
    public Pattern<SampleNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * Takes the stats of the sampled source and multiplies the output row count by
     * {@code node.getSampleRatio()}. Only the row count is mapped here.
     *
     * @return always a present estimate derived from the source stats
     */
    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(SampleNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
    {
        return Optional.of(
                statsProvider.getStats(node.getSource())
                        .mapOutputRowCount(sourceRowCount -> sourceRowCount * node.getSampleRatio()));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public static StatsCalculator createNewStatsCalculator(Metadata metadata, TypeAn
rules.add(new AssignUniqueIdStatsRule());
rules.add(new SemiJoinStatsRule());
rules.add(new RowNumberStatsRule(normalizer));
rules.add(new SampleStatsRule(normalizer));
rules.add(new SortStatsRule());

return new ComposableStatsCalculator(rules.build());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.cost;

import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.SampleNode;
import org.testng.annotations.Test;

import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static java.lang.Double.POSITIVE_INFINITY;

/**
 * Verifies the statistics estimated for a {@link SampleNode}: the output row count is
 * the source row count scaled by the sample ratio.
 */
public class TestSampleStatsRule
        extends BaseStatsCalculatorTest
{
    @Test
    public void testStatsForSampleNode()
    {
        tester()
                .assertStatsFor(pb -> {
                    Symbol a = pb.symbol("a", BIGINT);
                    Symbol b = pb.symbol("b", DOUBLE);
                    // 33% sample over a two-column source
                    return pb.sample(.33, SampleNode.Type.BERNOULLI, pb.values(a, b));
                })
                .withSourceStats(PlanNodeStatsEstimate.builder()
                        .setOutputRowCount(100)
                        .addSymbolStatistics(
                                new Symbol("a"),
                                SymbolStatsEstimate.builder()
                                        .setDistinctValuesCount(20)
                                        .setNullsFraction(0.3)
                                        .setLowValue(1)
                                        .setHighValue(30)
                                        .build())
                        .addSymbolStatistics(
                                new Symbol("b"),
                                SymbolStatsEstimate.builder()
                                        .setDistinctValuesCount(40)
                                        // Must agree with the nullsFraction(0.3) asserted below:
                                        // the rule under test only rescales the row count.
                                        .setNullsFraction(0.3)
                                        .setLowValue(13.5)
                                        .setHighValue(POSITIVE_INFINITY)
                                        .build())
                        .build())
                .check(check -> check
                        // 100 source rows * 0.33 sample ratio
                        .outputRowsCount(33)
                        .symbolStats("a", assertion -> assertion
                                .dataSizeUnknown()
                                .distinctValuesCount(20)
                                .nullsFraction(0.3)
                                .lowValue(1)
                                .highValue(30))
                        .symbolStats("b", assertion -> assertion
                                .dataSizeUnknown()
                                // 23.1 = 33 * (1 - 0.3); presumably the normalizer caps the distinct
                                // count at the non-null row count -- confirm against StatsNormalizer
                                .distinctValuesCount(23.1)
                                .nullsFraction(0.3)
                                .lowValue(13.5)
                                .highValueUnknown()));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import static io.trino.SystemSessionProperties.COLLECT_PLAN_STATISTICS_FOR_ALL_QUERIES;
import static io.trino.SystemSessionProperties.PREFER_PARTIAL_AGGREGATION;
import static io.trino.plugin.tpch.TpchConnectorFactory.TPCH_COLUMN_NAMING_PROPERTY;
import static io.trino.testing.assertions.Assert.assertEventually;
import static io.trino.testing.statistics.MetricComparisonStrategies.absoluteError;
import static io.trino.testing.statistics.MetricComparisonStrategies.defaultTolerance;
import static io.trino.testing.statistics.MetricComparisonStrategies.noError;
Expand Down Expand Up @@ -176,4 +177,17 @@ public void testSort()
statisticsAssertion.check("SELECT * FROM nation ORDER BY n_nationkey",
checks -> checks.estimate(OUTPUT_ROW_COUNT, noError()));
}

@Test
public void testTablesample()
{
    // BERNOULLI sample gets converted to a `rand() < 0.42` filter and does not get estimated currently
    statisticsAssertion.check(
            "SELECT * FROM orders TABLESAMPLE BERNOULLI (42)",
            checks -> checks.noEstimate(OUTPUT_ROW_COUNT));

    // TABLESAMPLE SYSTEM has high variance in the number of result rows actually returned,
    // so the comparison against the estimate is wrapped in an eventual assertion.
    assertEventually(() -> statisticsAssertion.check(
            "SELECT * FROM orders TABLESAMPLE SYSTEM (42)",
            checks -> checks.estimate(OUTPUT_ROW_COUNT, relativeError(.3))));
}
}

0 comments on commit 96c7b7a

Please sign in to comment.