Skip to content

Instantly share code, notes, and snippets.

@davipatti
Created January 24, 2019 14:04
Show Gist options
  • Save davipatti/38726f18b7f54399d046e70491be20a4 to your computer and use it in GitHub Desktop.
Save davipatti/38726f18b7f54399d046e70491be20a4 to your computer and use it in GitHub Desktop.
Compute and plot a layout for a dendropy.Tree instance in matplotlib.
def attachX(tree):
    """Store on every node its horizontal position as ``node.x``.

    The root (the node without a parent) is placed at x = 0. Every other node
    is placed at its own edge length plus its parent's x. Nodes are visited
    in preorder, so a parent's x is always set before its children's.

    Args:
        tree (dendropy.Tree): tree whose non-root edges have lengths.
    """
    for current in tree.preorder_node_iter():
        parent = current.parent_node
        if parent is not None:
            current.x = current.edge.length + parent.x
        else:
            current.x = 0
def attatchYLeaves(tree):
    """Give each leaf an integer ``y`` equal to its position in leaf order.

    The first leaf yielded by ``tree.leaf_node_iter()`` gets y = 0, the next
    y = 1, and so on.

    Args:
        tree (dendropy.Tree): tree whose leaves are to be positioned.
    """
    position = 0
    for leaf in tree.leaf_node_iter():
        leaf.y = position
        position += 1
def attatchYInternal(tree):
    """Place every internal node vertically at the mean ``y`` of its children.

    Nodes are visited in postorder, so every child's ``y`` is available before
    its parent is processed. Leaves (nodes with no children) are left
    untouched; they are expected to already carry ``y`` (as set by
    ``attatchYLeaves``).

    Args:
        tree (dendropy.Tree): tree whose leaves already have ``y`` set.
    """
    for node in tree.postorder_node_iter():
        child_ys = [child.y for child in node.child_node_iter()]
        if not child_ys:
            # A leaf: its position comes from the leaf ordering.
            continue
        # Always recompute, even if the node already has a ``y``. Skipping
        # nodes that merely *have* the attribute would leave stale values
        # when the layout is recomputed after the tree is reordered.
        node.y = sum(child_ys) / len(child_ys)
def treeXlim(tree):
    """Return the horizontal extent of the layout as ``(min_x, max_x)``.

    Reads ``node.x`` from every node of *tree*. Raises ``ValueError`` if the
    tree yields no nodes.
    """
    positions = [node.x for node in tree.preorder_node_iter()]
    lowest = min(positions)
    highest = max(positions)
    return lowest, highest
def treeYlim(tree):
    """Return the vertical extent of the layout as ``(min_y, max_y)``.

    Reads ``node.y`` from every node of *tree*. Raises ``ValueError`` if the
    tree yields no nodes.
    """
    heights = [node.y for node in tree.preorder_node_iter()]
    bottom = min(heights)
    top = max(heights)
    return bottom, top
def drawEdges(tree, **kws):
    """Draw every non-root edge of *tree* as a straight line segment.

    Each edge becomes one segment from its head node (the child) to its tail
    node (the parent), using the nodes' ``x``/``y`` layout attributes. The
    root's edge has no tail node and is skipped.

    Args:
        tree (dendropy.Tree): tree whose nodes already carry ``x`` and ``y``.
        **kws: passed to ``matplotlib.collections.LineCollection``.

    Returns:
        matplotlib.collections.LineCollection: the collection added to the
        current axes.
    """
    # Imported locally: no module-level import of these names is visible in
    # this file, and referencing them unimported raises NameError.
    import matplotlib.collections
    import matplotlib.pyplot as plt

    segments = [
        (
            (edge.head_node.x, edge.head_node.y),
            (edge.tail_node.x, edge.tail_node.y),
        )
        for edge in tree.edges()
        if edge.tail_node is not None
    ]
    lc = matplotlib.collections.LineCollection(segments=segments, **kws)
    # add_collection (unlike add_artist) also updates the axes' data limits,
    # so autoscaling includes the edges when this is used on its own.
    plt.gca().add_collection(lc)
    return lc
def drawLeaves(tree, **kws):
    """Scatter-plot the leaves of *tree* at their ``(x, y)`` layout positions.

    Args:
        tree (dendropy.Tree): tree whose leaves already carry ``x`` and ``y``.
        **kws: passed to ``plt.scatter``.

    Returns:
        The artist returned by ``plt.scatter``.
    """
    # Imported locally: no module-level import of ``plt`` is visible in this
    # file, and referencing it unimported raises NameError.
    import matplotlib.pyplot as plt

    # Materialise the leaves once instead of iterating the tree twice.
    leaves = list(tree.leaf_node_iter())
    x = [leaf.x for leaf in leaves]
    y = [leaf.y for leaf in leaves]
    return plt.scatter(x, y, **kws)
def drawTree(tree, **kws):
    """Compute a layout for *tree* and draw it on the current matplotlib axes.

    Layout:
      - a node's x is the sum of edge lengths on its path from the root;
      - leaves get y = 0, 1, 2, ... in leaf order;
      - each internal node sits at the mean y of its children.

    Edges are drawn as straight lines and leaves as scatter points. The axes
    limits are then fitted to the tree and the y ticks are hidden.

    Args:
        tree (dendropy.Tree): Must have edge lengths (the root's edge length
            is not used).
        linewidth (float): edge line width. Default 0.5.
        linecolor: edge colour. Default "black".
        leafcolor: leaf marker colour. Default "black".
        leafsize (float): leaf marker size (``s`` of ``plt.scatter``).
            Default 1.

    Raises:
        TypeError: if a keyword argument other than those above is given.
    """
    # Imported locally: no module-level import of ``plt`` is visible in this
    # file, and referencing it unimported raises NameError.
    import matplotlib.pyplot as plt

    # Optional keyword args
    linewidth = kws.pop("linewidth", 0.5)
    linecolor = kws.pop("linecolor", "black")
    leafcolor = kws.pop("leafcolor", "black")
    leafsize = kws.pop("leafsize", 1)
    if kws:
        # Reject leftovers rather than silently dropping them (e.g. typos).
        raise TypeError(
            "drawTree() got unexpected keyword argument(s): "
            + ", ".join(sorted(kws))
        )

    # Compute layout
    attachX(tree)
    attatchYLeaves(tree)
    attatchYInternal(tree)

    # Draw
    drawEdges(tree, linewidth=linewidth, color=linecolor)
    drawLeaves(tree, c=leafcolor, s=leafsize)

    # Finalise
    plt.ylim(treeYlim(tree))
    plt.xlim(treeXlim(tree))
    plt.yticks([])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment