-
-
Save kokosing/28328984617ce0ad376d0c7338f40203 to your computer and use it in GitHub Desktop.
reorder joins debug hack utility
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
commit c158afc2ce2e1ac05b0fbe4f82edc9011e37057c | |
Author: Grzegorz Kokosiński <[email protected]> | |
Date: Fri Jul 14 12:33:27 2017 +0200 | |
Reorder joins debug infra | |
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java | |
index 9e1f2ad..1f25125 100644 | |
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java | |
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java | |
@@ -22,6 +22,7 @@ import com.facebook.presto.cost.CostComparator; | |
import com.facebook.presto.cost.JoinNodeCachingStatsCalculator; | |
import com.facebook.presto.cost.PlanNodeCostEstimate; | |
import com.facebook.presto.cost.StatsCalculator; | |
+import com.facebook.presto.spi.type.Type; | |
import com.facebook.presto.sql.analyzer.FeaturesConfig; | |
import com.facebook.presto.sql.planner.EqualityInference; | |
import com.facebook.presto.sql.planner.PlanNodeIdAllocator; | |
@@ -33,16 +34,25 @@ import com.facebook.presto.sql.planner.plan.FilterNode; | |
import com.facebook.presto.sql.planner.plan.JoinNode; | |
import com.facebook.presto.sql.planner.plan.PlanNode; | |
import com.facebook.presto.sql.planner.plan.ProjectNode; | |
+import com.facebook.presto.sql.planner.plan.TableScanNode; | |
import com.facebook.presto.sql.tree.ComparisonExpression; | |
import com.facebook.presto.sql.tree.Expression; | |
import com.facebook.presto.sql.tree.SymbolReference; | |
import com.google.common.annotations.VisibleForTesting; | |
+import com.google.common.base.Charsets; | |
import com.google.common.collect.ImmutableList; | |
+import com.google.common.collect.ImmutableMap; | |
import com.google.common.collect.ImmutableSet; | |
import com.google.common.collect.Ordering; | |
import com.google.common.collect.Sets; | |
+import com.google.common.collect.TreeTraverser; | |
+import com.google.common.io.Files; | |
import io.airlift.log.Logger; | |
+import java.io.BufferedWriter; | |
+import java.io.File; | |
+import java.io.FileNotFoundException; | |
+import java.io.IOException; | |
import java.util.ArrayList; | |
import java.util.HashMap; | |
import java.util.List; | |
@@ -50,6 +60,8 @@ import java.util.Map; | |
import java.util.Objects; | |
import java.util.Optional; | |
import java.util.Set; | |
+import java.util.concurrent.atomic.AtomicInteger; | |
+import java.util.function.Function; | |
import java.util.stream.IntStream; | |
import java.util.stream.Stream; | |
@@ -65,6 +77,7 @@ import static com.facebook.presto.sql.planner.EqualityInference.createEqualityIn | |
import static com.facebook.presto.sql.planner.iterative.rule.MultiJoinNode.toMultiJoinNode; | |
import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.INFINITE_COST_RESULT; | |
import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.UNKNOWN_COST_RESULT; | |
+import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; | |
import static com.facebook.presto.sql.planner.plan.Assignments.identity; | |
import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; | |
import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; | |
@@ -77,7 +90,9 @@ import static com.google.common.base.Predicates.in; | |
import static com.google.common.collect.ImmutableList.toImmutableList; | |
import static com.google.common.collect.ImmutableSet.toImmutableSet; | |
import static com.google.common.collect.Iterables.getOnlyElement; | |
+import static java.lang.String.format; | |
import static java.util.Objects.requireNonNull; | |
+import static java.util.stream.Collectors.joining; | |
import static java.util.stream.StreamSupport.stream; | |
public class ReorderJoins | |
@@ -115,8 +130,15 @@ public class ReorderJoins | |
return Optional.empty(); | |
} | |
- Lookup joinCachingStatsLookup = Lookup.from(lookup::resolve, new JoinNodeCachingStatsCalculator(new CachingStatsCalculator(statsCalculator)), new CachingCostCalculator(costCalculator)); | |
- JoinEnumerationResult result = new JoinEnumerator(idAllocator, symbolAllocator, session, joinCachingStatsLookup, multiJoinNode.getFilter(), costComparator).chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols()); | |
+ StatsCalculator statsCalculator = new JoinNodeCachingStatsCalculator(new CachingStatsCalculator(this.statsCalculator)); | |
+ CachingCostCalculator costCalculator = new CachingCostCalculator(this.costCalculator); | |
+ Lookup joinCachingStatsLookup = Lookup.from(lookup::resolve, statsCalculator, costCalculator); | |
+ Debugger debugger = new Debugger(multiJoinNode, joinCachingStatsLookup, session, symbolAllocator.getTypes()); | |
+ JoinEnumerator joinEnumerator = new JoinEnumerator(debugger, idAllocator, symbolAllocator, session, joinCachingStatsLookup, multiJoinNode.getFilter(), costComparator); | |
+ JoinEnumerationResult result = joinEnumerator.chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols()); | |
+ | |
+ debugger.debug(result.cost, result.getPlanNode().get(), "best of the best"); | |
+ debugger.close(); | |
return result.getCost().isUnknown() || result.getCost().equals(INFINITE_COST) ? Optional.empty() : result.getPlanNode(); | |
} | |
@@ -131,10 +153,12 @@ public class ReorderJoins | |
private final Expression allFilter; | |
private final SymbolAllocator symbolAllocator; | |
private final Lookup lookup; | |
+ private final Debugger debugger; | |
@VisibleForTesting | |
- JoinEnumerator(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session, Lookup lookup, Expression filter, CostComparator costComparator) | |
+ JoinEnumerator(Debugger debugger, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session, Lookup lookup, Expression filter, CostComparator costComparator) | |
{ | |
+ this.debugger = debugger; | |
requireNonNull(idAllocator, "idAllocator is null"); | |
requireNonNull(symbolAllocator, "symbolAllocator is null"); | |
requireNonNull(session, "session is null"); | |
@@ -248,7 +272,8 @@ public class ReorderJoins | |
// create equality inference on available symbols | |
// TODO: make generateEqualitiesPartitionedBy take left and right scope | |
- List<Expression> joinEqualities = allInference.generateEqualitiesPartitionedBy(symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol)).getScopeEqualities(); | |
+ List<Expression> joinEqualities = allInference.generateEqualitiesPartitionedBy(symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol)) | |
+ .getScopeEqualities(); | |
EqualityInference joinInference = createEqualityInference(joinEqualities.toArray(new Expression[joinEqualities.size()])); | |
joinPredicatesBuilder.addAll(joinInference.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeStraddlingEqualities()); | |
@@ -318,7 +343,7 @@ public class ReorderJoins | |
joinFilters.isEmpty() ? Optional.empty() : Optional.of(and(joinFilters)), | |
Optional.empty(), | |
Optional.empty(), | |
- Optional.empty())); | |
+ Optional.empty()), leftSources, rightSources); | |
if (!joinOutputSymbols.equals(sortedOutputSymbols)) { | |
PlanNode resultNode = new ProjectNode(idAllocator.getNextId(), result.planNode.get(), identity(sortedOutputSymbols)); | |
@@ -364,27 +389,44 @@ public class ReorderJoins | |
return leftSymbols.contains(leftSymbol) ? equiJoinClause : equiJoinClause.flip(); | |
} | |
- private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode) | |
+ private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode, Set<PlanNode> leftSources, Set<PlanNode> rightSources) | |
{ | |
+ String leftId = debugger.id(leftSources); | |
+ String rightId = debugger.id(rightSources); | |
+ | |
List<JoinEnumerationResult> possibleJoinNodes = new ArrayList<>(); | |
FeaturesConfig.JoinDistributionType joinDistributionType = getJoinDistributionType(session); | |
if (joinDistributionType.canRepartition() && !joinNode.isCrossJoin()) { | |
JoinNode node = joinNode.withDistributionType(PARTITIONED); | |
- possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node))); | |
+ JoinEnumerationResult e = new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node)); | |
+ possibleJoinNodes.add(e); | |
+ debugger.debug(leftId, rightId, e.cost, e.planNode.get(), "repartition"); | |
+ | |
node = node.flipChildren(); | |
- possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node))); | |
+ e = new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node)); | |
+ possibleJoinNodes.add(e); | |
+ debugger.debug(leftId, rightId, e.cost, e.planNode.get(), "repartition flipped"); | |
} | |
if (joinDistributionType.canReplicate()) { | |
JoinNode node = joinNode.withDistributionType(REPLICATED); | |
- possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node))); | |
+ JoinEnumerationResult e = new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node)); | |
+ possibleJoinNodes.add(e); | |
node = node.flipChildren(); | |
- possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node))); | |
+ debugger.debug(leftId, rightId, e.cost, e.planNode.get(), "replicated"); | |
+ | |
+ e = new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node)); | |
+ possibleJoinNodes.add(e); | |
+ debugger.debug(leftId, rightId, e.cost, e.planNode.get(), "replicated flipped"); | |
} | |
if (possibleJoinNodes.stream().anyMatch(result -> result.cost.isUnknown())) { | |
return UNKNOWN_COST_RESULT; | |
} | |
- return resultOrdering.min(possibleJoinNodes); | |
+ JoinEnumerationResult best = resultOrdering.min(possibleJoinNodes); | |
+ debugger.debug(leftId, rightId, best.cost, best.planNode.get(), "best"); | |
+ | |
+ return best; | |
} | |
+ | |
} | |
@VisibleForTesting | |
@@ -413,4 +455,168 @@ public class ReorderJoins | |
return cost; | |
} | |
} | |
+ | |
+    /** | |
+     * Debugging aid that records every join order costed by {@link JoinEnumerator} in a per-instance | |
+     * log file, {@code /tmp/reorder_joins.log_<n>}. | |
+     * | |
+     * <p>Plan nodes are numbered sequentially the first time they are logged, so an entry ending in | |
+     * {@code (7 from 3,5)} means that node 7 was built directly on top of nodes 3 and 5. | |
+     * | |
+     * <p>Node numbering is kept in a plain {@link HashMap}, so use an instance from a single thread, and | |
+     * close it when done so the log is flushed and the file handle released. | |
+     */ | |
+    private static class Debugger implements AutoCloseable | |
+    { | |
+        private static final AtomicInteger debuggerIdCounter = new AtomicInteger(); | |
+ | |
+        private final BufferedWriter writer; | |
+        private final Map<PlanNode, String> nodeNames; | |
+        private final Map<PlanNode, Integer> nodeIds = new HashMap<>(); | |
+        private final AtomicInteger nodeIdCounter = new AtomicInteger(); | |
+ | |
+        /** | |
+         * Opens the log file, derives a display name for every join source and logs the cumulative | |
+         * cost of every node reachable from those sources. | |
+         * | |
+         * @throws RuntimeException if the log file cannot be opened or written to | |
+         */ | |
+        public Debugger(MultiJoinNode multiJoinNode, Lookup lookup, Session session, Map<Symbol, Type> types) | |
+        { | |
+            try { | |
+                // UTF-8 so that non-ASCII characters in table names survive; an ASCII encoder would write '?' | |
+                this.writer = Files.newWriter(new File("/tmp/reorder_joins.log_" + debuggerIdCounter.incrementAndGet()), Charsets.UTF_8); | |
+            } | |
+            catch (FileNotFoundException e) { | |
+                throw new RuntimeException(e); | |
+            } | |
+ | |
+            try { | |
+                // Every source is registered both as given and as resolved through the lookup. | |
+                // distinct() covers the case where both are the same node, which toImmutableMap | |
+                // would otherwise reject as a duplicate key. | |
+                nodeNames = multiJoinNode.getSources().stream() | |
+                        .flatMap(node -> Stream.of(node, lookup.resolve(node))) | |
+                        .distinct() | |
+                        .collect(ImmutableMap.toImmutableMap(Function.identity(), source -> describeSource(source, lookup))); | |
+ | |
+                multiJoinNode.getSources().stream() | |
+                        .flatMap(source -> TreeTraverser.<PlanNode>using(node -> lookup.resolve(node).getSources()).preOrderTraversal(source).stream()) | |
+                        .forEach(node -> debug(lookup.getCumulativeCost(node, session, types), node, nodeNames.get(node))); | |
+            } | |
+            catch (RuntimeException e) { | |
+                // The file is already open here; release it instead of leaking the handle. | |
+                try { | |
+                    writer.close(); | |
+                } | |
+                catch (IOException closeFailure) { | |
+                    e.addSuppressed(closeFailure); | |
+                } | |
+                throw e; | |
+            } | |
+        } | |
+ | |
+        /** | |
+         * Names a join source after the first table scan found beneath it, or after the node's own | |
+         * class name and plan node id when there is none. | |
+         */ | |
+        private static String describeSource(PlanNode source, Lookup lookup) | |
+        { | |
+            Optional<PlanNode> tableScan = searchFrom(source, lookup) | |
+                    .where(TableScanNode.class::isInstance) | |
+                    .findFirst(); | |
+            return tableScan | |
+                    .map(node -> ((TableScanNode) node).getTable().toString()) | |
+                    .orElseGet(() -> source.getClass().getSimpleName() + "#" + source.getId()); | |
+        } | |
+ | |
+        /** | |
+         * Builds a label for a set of join sources: their display names with any {@code tpch:} and | |
+         * {@code :sf...} parts removed, sorted and joined with commas. | |
+         */ | |
+        private String id(Set<PlanNode> sources) | |
+        { | |
+            return sources.stream() | |
+                    // nodes that were never registered are labelled by plan node id instead of failing on null | |
+                    .map(node -> nodeNames.getOrDefault(node, String.valueOf(node.getId()))) | |
+                    .map(s -> s.replaceAll("tpch:", "")) | |
+                    .map(s -> s.replaceAll(":sf.*", "")) | |
+                    .sorted() | |
+                    .collect(joining(",")); | |
+        } | |
+ | |
+        /** | |
+         * Logs one costed join candidate, identified by the labels of its left and right source sets. | |
+         */ | |
+        private void debug(String leftId, String rightId, PlanNodeCostEstimate cost, PlanNode planNode, String comment) | |
+        { | |
+            writeLine(format("%s - %s [%s] is (cpu: %.0f, mem: %.0f, net: %.0f) (%s)", | |
+                    leftId, rightId, comment, cost.getCpuCost(), cost.getMemoryCost(), cost.getNetworkCost(), resultHierarchy(planNode))); | |
+        } | |
+ | |
+        /** | |
+         * Logs the cost of a single plan node under the given comment. | |
+         */ | |
+        public void debug(PlanNodeCostEstimate cost, PlanNode planNode, String comment) | |
+        { | |
+            writeLine(format("[%s] is (cpu: %.0f, mem: %.0f, net: %.0f) (%s) ", | |
+                    comment, cost.getCpuCost(), cost.getMemoryCost(), cost.getNetworkCost(), resultHierarchy(planNode))); | |
+        } | |
+ | |
+        /** | |
+         * Appends one line to the log, rethrowing I/O failures unchecked. | |
+         */ | |
+        private void writeLine(String line) | |
+        { | |
+            try { | |
+                writer.append(line); | |
+                writer.newLine(); | |
+            } | |
+            catch (IOException e) { | |
+                throw new RuntimeException(e); | |
+            } | |
+        } | |
+ | |
+        /** | |
+         * Renders a node as {@code <id>} or {@code <id> from <source ids>}, numbering nodes on first use. | |
+         */ | |
+        private String resultHierarchy(PlanNode planNode) | |
+        { | |
+            // the node itself is numbered before its sources | |
+            String resultId = "" + id(planNode); | |
+            String from = planNode.getSources().stream() | |
+                    .mapToInt(this::id) | |
+                    .mapToObj(Integer::toString) | |
+                    .collect(joining(",")); | |
+            if (from.isEmpty()) { | |
+                return resultId; | |
+            } | |
+            return resultId + " from " + from; | |
+        } | |
+ | |
+        /** | |
+         * Returns the sequential number of a plan node, assigning the next free one on first sight. | |
+         */ | |
+        private Integer id(PlanNode planNode) | |
+        { | |
+            return nodeIds.computeIfAbsent(planNode, node -> nodeIdCounter.incrementAndGet()); | |
+        } | |
+ | |
+        /** | |
+         * Flushes and closes the log file. | |
+         * | |
+         * @throws RuntimeException if the file cannot be flushed or closed | |
+         */ | |
+        @Override | |
+        public void close() | |
+        { | |
+            try { | |
+                // close() flushes the buffered output itself, so no separate flush() is needed | |
+                writer.close(); | |
+            } | |
+            catch (IOException e) { | |
+                throw new RuntimeException(e); | |
+            } | |
+        } | |
+    } | |
} | |
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java | |
index c244977..11d5a5f 100644 | |
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java | |
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java | |
@@ -89,7 +89,7 @@ public class TestJoinEnumerator | |
TRUE_LITERAL, | |
ImmutableList.of(a1, b1)); | |
ReorderJoins.JoinEnumerator joinEnumerator = new ReorderJoins.JoinEnumerator( | |
- idAllocator, | |
+ null, idAllocator, | |
new SymbolAllocator(), | |
queryRunner.getDefaultSession(), | |
queryRunner.getLookup(), |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment