Skip to content

Instantly share code, notes, and snippets.

@kokosing
Last active July 14, 2017 12:55
Show Gist options
  • Save kokosing/28328984617ce0ad376d0c7338f40203 to your computer and use it in GitHub Desktop.
Save kokosing/28328984617ce0ad376d0c7338f40203 to your computer and use it in GitHub Desktop.
reorder joins debug hack utility
commit c158afc2ce2e1ac05b0fbe4f82edc9011e37057c
Author: Grzegorz Kokosiński <[email protected]>
Date: Fri Jul 14 12:33:27 2017 +0200
Reorder joins debug infra
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java
index 9e1f2ad..1f25125 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java
@@ -22,6 +22,7 @@ import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.JoinNodeCachingStatsCalculator;
import com.facebook.presto.cost.PlanNodeCostEstimate;
import com.facebook.presto.cost.StatsCalculator;
+import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.EqualityInference;
import com.facebook.presto.sql.planner.PlanNodeIdAllocator;
@@ -33,16 +34,25 @@ import com.facebook.presto.sql.planner.plan.FilterNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.ProjectNode;
+import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Charsets;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
+import com.google.common.collect.TreeTraverser;
+import com.google.common.io.Files;
import io.airlift.log.Logger;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -50,6 +60,8 @@ import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.Stream;
@@ -65,6 +77,7 @@ import static com.facebook.presto.sql.planner.EqualityInference.createEqualityIn
import static com.facebook.presto.sql.planner.iterative.rule.MultiJoinNode.toMultiJoinNode;
import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.INFINITE_COST_RESULT;
import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.UNKNOWN_COST_RESULT;
+import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static com.facebook.presto.sql.planner.plan.Assignments.identity;
import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED;
import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED;
@@ -77,7 +90,9 @@ import static com.google.common.base.Predicates.in;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
+import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
+import static java.util.stream.Collectors.joining;
import static java.util.stream.StreamSupport.stream;
public class ReorderJoins
@@ -115,8 +130,15 @@ public class ReorderJoins
return Optional.empty();
}
- Lookup joinCachingStatsLookup = Lookup.from(lookup::resolve, new JoinNodeCachingStatsCalculator(new CachingStatsCalculator(statsCalculator)), new CachingCostCalculator(costCalculator));
- JoinEnumerationResult result = new JoinEnumerator(idAllocator, symbolAllocator, session, joinCachingStatsLookup, multiJoinNode.getFilter(), costComparator).chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols());
+ StatsCalculator statsCalculator = new JoinNodeCachingStatsCalculator(new CachingStatsCalculator(this.statsCalculator));
+ CachingCostCalculator costCalculator = new CachingCostCalculator(this.costCalculator);
+ Lookup joinCachingStatsLookup = Lookup.from(lookup::resolve, statsCalculator, costCalculator);
+ JoinEnumerationResult result;
+ // try-with-resources closes the debug log file even when join enumeration throws
+ try (Debugger debugger = new Debugger(multiJoinNode, joinCachingStatsLookup, session, symbolAllocator.getTypes())) {
+ JoinEnumerator joinEnumerator = new JoinEnumerator(debugger, idAllocator, symbolAllocator, session, joinCachingStatsLookup, multiJoinNode.getFilter(), costComparator);
+ result = joinEnumerator.chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols());
+ // getPlanNode() is an Optional and may be empty (presumably for unknown/infinite-cost results);
+ // calling get() unconditionally would abort the rule instead of returning Optional.empty() below
+ if (result.getPlanNode().isPresent()) {
+ debugger.debug(result.getCost(), result.getPlanNode().get(), "best of the best");
+ }
+ }
+
return result.getCost().isUnknown() || result.getCost().equals(INFINITE_COST) ? Optional.empty() : result.getPlanNode();
}
@@ -131,10 +153,12 @@ public class ReorderJoins
private final Expression allFilter;
private final SymbolAllocator symbolAllocator;
private final Lookup lookup;
+ private final Debugger debugger;
@VisibleForTesting
- JoinEnumerator(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session, Lookup lookup, Expression filter, CostComparator costComparator)
+ JoinEnumerator(Debugger debugger, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session, Lookup lookup, Expression filter, CostComparator costComparator)
{
+ this.debugger = debugger;
requireNonNull(idAllocator, "idAllocator is null");
requireNonNull(symbolAllocator, "symbolAllocator is null");
requireNonNull(session, "session is null");
@@ -248,7 +272,8 @@ public class ReorderJoins
// create equality inference on available symbols
// TODO: make generateEqualitiesPartitionedBy take left and right scope
- List<Expression> joinEqualities = allInference.generateEqualitiesPartitionedBy(symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol)).getScopeEqualities();
+ List<Expression> joinEqualities = allInference.generateEqualitiesPartitionedBy(symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol))
+ .getScopeEqualities();
EqualityInference joinInference = createEqualityInference(joinEqualities.toArray(new Expression[joinEqualities.size()]));
joinPredicatesBuilder.addAll(joinInference.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeStraddlingEqualities());
@@ -318,7 +343,7 @@ public class ReorderJoins
joinFilters.isEmpty() ? Optional.empty() : Optional.of(and(joinFilters)),
Optional.empty(),
Optional.empty(),
- Optional.empty()));
+ Optional.empty()), leftSources, rightSources);
if (!joinOutputSymbols.equals(sortedOutputSymbols)) {
PlanNode resultNode = new ProjectNode(idAllocator.getNextId(), result.planNode.get(), identity(sortedOutputSymbols));
@@ -364,27 +389,44 @@ public class ReorderJoins
return leftSymbols.contains(leftSymbol) ? equiJoinClause : equiJoinClause.flip();
}
- private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode)
+ /**
+ * Costs every allowed distribution of the join (partitioned, replicated, each also with flipped children)
+ * and returns the cheapest candidate, or UNKNOWN_COST_RESULT if any candidate cost is unknown.
+ * When a debugger is present every candidate and the winner are logged; leftSources and rightSources
+ * are only used to label those log lines.
+ */
+ private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode, Set<PlanNode> leftSources, Set<PlanNode> rightSources)
{
+ // the debugger may be null (TestJoinEnumerator passes null), so every use of it is guarded
+ String leftId = debugger == null ? null : debugger.id(leftSources);
+ String rightId = debugger == null ? null : debugger.id(rightSources);
+
List<JoinEnumerationResult> possibleJoinNodes = new ArrayList<>();
FeaturesConfig.JoinDistributionType joinDistributionType = getJoinDistributionType(session);
if (joinDistributionType.canRepartition() && !joinNode.isCrossJoin()) {
JoinNode node = joinNode.withDistributionType(PARTITIONED);
- possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node)));
+ JoinEnumerationResult candidate = new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node));
+ possibleJoinNodes.add(candidate);
+ debugCandidate(leftId, rightId, candidate, "repartition");
+
node = node.flipChildren();
- possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node)));
+ candidate = new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node));
+ possibleJoinNodes.add(candidate);
+ debugCandidate(leftId, rightId, candidate, "repartition flipped");
}
if (joinDistributionType.canReplicate()) {
JoinNode node = joinNode.withDistributionType(REPLICATED);
- possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node)));
+ JoinEnumerationResult candidate = new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node));
+ possibleJoinNodes.add(candidate);
+ debugCandidate(leftId, rightId, candidate, "replicated");
+
node = node.flipChildren();
- possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node)));
+ candidate = new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node));
+ possibleJoinNodes.add(candidate);
+ debugCandidate(leftId, rightId, candidate, "replicated flipped");
}
if (possibleJoinNodes.stream().anyMatch(result -> result.cost.isUnknown())) {
return UNKNOWN_COST_RESULT;
}
- return resultOrdering.min(possibleJoinNodes);
+ JoinEnumerationResult best = resultOrdering.min(possibleJoinNodes);
+ debugCandidate(leftId, rightId, best, "best");
+
+ return best;
}
+
+ // Logs one candidate; a no-op when the enumerator was built without a debugger.
+ // Every candidate passed here was built with Optional.of(node), so planNode.get() cannot fail.
+ private void debugCandidate(String leftId, String rightId, JoinEnumerationResult candidate, String comment)
+ {
+ if (debugger != null) {
+ debugger.debug(leftId, rightId, candidate.cost, candidate.planNode.get(), comment);
+ }
+ }
}
@VisibleForTesting
@@ -413,4 +455,100 @@ public class ReorderJoins
return cost;
}
}
+
+ /**
+ * Throw-away debugging aid: writes one log file per instance to /tmp/reorder_joins.log_&lt;n&gt;.
+ * It logs the cumulative cost of every node under the multi-join sources, then every join candidate
+ * reported by the enumerator. Nodes are labelled by the first table scan found beneath them and by a
+ * per-instance sequential id.
+ */
+ private static class Debugger implements AutoCloseable
+ {
+ private static final AtomicInteger debuggerIdCounter = new AtomicInteger();
+
+ private final BufferedWriter writer;
+ private final Map<PlanNode, String> nodeNames;
+ private final Map<PlanNode, Integer> nodeIds = new HashMap<>();
+ private final AtomicInteger nodeIdCounter = new AtomicInteger();
+
+ public Debugger(MultiJoinNode multiJoinNode, Lookup lookup, Session session, Map<Symbol, Type> types)
+ {
+ try {
+ // UTF-8 so table names outside ASCII are not replaced by '?'
+ this.writer = Files.newWriter(new File("/tmp/reorder_joins.log_" + debuggerIdCounter.incrementAndGet()), Charsets.UTF_8);
+ }
+ catch (FileNotFoundException e) {
+ throw new RuntimeException(e);
+ }
+
+ try {
+ // distinct(): resolving a node that is not a group reference yields the node itself,
+ // and ImmutableMap.toImmutableMap rejects duplicate keys
+ nodeNames = multiJoinNode.getSources().stream()
+ .flatMap(node -> Stream.of(node, lookup.resolve(node)))
+ .distinct()
+ .collect(ImmutableMap.toImmutableMap(Function.identity(), source -> tableName(source, lookup)));
+
+ multiJoinNode.getSources().stream()
+ .flatMap(source -> TreeTraverser.<PlanNode>using(node -> lookup.resolve(node).getSources()).preOrderTraversal(source).stream())
+ .forEach(node -> debug(lookup.getCumulativeCost(node, session, types), node, nodeNames.get(node)));
+ }
+ catch (RuntimeException failure) {
+ // do not leak the open file when initialization fails
+ closeQuietly(failure);
+ throw failure;
+ }
+ }
+
+ // Names a source after the first table scan beneath it; sources without a table scan get an id-based label
+ private static String tableName(PlanNode source, Lookup lookup)
+ {
+ Optional<PlanNode> tableScan = searchFrom(source, lookup)
+ .where(TableScanNode.class::isInstance)
+ .findFirst();
+ return tableScan
+ .map(node -> ((TableScanNode) node).getTable().toString())
+ .orElseGet(() -> "<no table scan " + source.getId() + ">");
+ }
+
+ // Sorted, comma-separated short names of the given sources (tpch catalog and scale factor stripped)
+ private String id(Set<PlanNode> sources)
+ {
+ return sources.stream()
+ .map(this::name)
+ .map(s -> s.replaceAll("tpch:", ""))
+ .map(s -> s.replaceAll(":sf.*", ""))
+ .sorted()
+ .collect(joining(","));
+ }
+
+ // Table-scan name of a source, falling back to its sequential id instead of failing on a null name
+ private String name(PlanNode node)
+ {
+ String name = nodeNames.get(node);
+ return name != null ? name : "#" + id(node);
+ }
+
+ private void debug(String leftId, String rightId, PlanNodeCostEstimate cost, PlanNode planNode, String comment)
+ {
+ writeLine(format("%s - %s [%s] is (cpu: %.0f, mem: %.0f, net: %.0f) (%s)",
+ leftId, rightId, comment, cost.getCpuCost(), cost.getMemoryCost(), cost.getNetworkCost(), resultHierarchy(planNode)));
+ }
+
+ public void debug(PlanNodeCostEstimate cost, PlanNode planNode, String comment)
+ {
+ writeLine(format("[%s] is (cpu: %.0f, mem: %.0f, net: %.0f) (%s) ",
+ comment, cost.getCpuCost(), cost.getMemoryCost(), cost.getNetworkCost(), resultHierarchy(planNode)));
+ }
+
+ private void writeLine(String line)
+ {
+ try {
+ writer.append(line);
+ writer.newLine();
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ // "<id>" for a leaf, otherwise "<id> from <child ids>"
+ private String resultHierarchy(PlanNode planNode)
+ {
+ String resultId = String.valueOf(id(planNode));
+ String from = planNode.getSources().stream()
+ .mapToInt(this::id)
+ .mapToObj(Integer::toString)
+ .collect(joining(","));
+ return from.isEmpty() ? resultId : resultId + " from " + from;
+ }
+
+ // Sequential per-debugger id, assigned the first time a node is seen
+ private Integer id(PlanNode planNode)
+ {
+ return nodeIds.computeIfAbsent(planNode, node -> nodeIdCounter.incrementAndGet());
+ }
+
+ private void closeQuietly(RuntimeException failure)
+ {
+ try {
+ writer.close();
+ }
+ catch (IOException closeFailure) {
+ failure.addSuppressed(closeFailure);
+ }
+ }
+
+ @Override
+ public void close()
+ {
+ try {
+ // BufferedWriter.close() flushes before closing
+ writer.close();
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java
index c244977..11d5a5f 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java
@@ -89,7 +89,7 @@ public class TestJoinEnumerator
TRUE_LITERAL,
ImmutableList.of(a1, b1));
ReorderJoins.JoinEnumerator joinEnumerator = new ReorderJoins.JoinEnumerator(
- idAllocator,
+ null, idAllocator,
new SymbolAllocator(),
queryRunner.getDefaultSession(),
queryRunner.getLookup(),
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment