"""Use matplotlib to draw phylogenetic trees from ETE3."""
# Adapted from earlier code (the original source link is not recorded here).
# Minor edits to make it run in Python 3,
# and added label_func argument to control the displayed leaf names.
from itertools import chain
from matplotlib import pyplot as plt
from matplotlib.collections import LineCollection
import numpy as np
def round_sig(x, sig=2):
    """
    Round a number to a given count of significant figures.

    :param x: number to round
    :param 2 sig: number of significant figures to keep

    :returns: ``x`` rounded with the built-in ``round``. Zero is returned
       unchanged, since it has no order of magnitude to round against.
    """
    if x == 0:
        # log10(0) is -inf and int(-inf) raises OverflowError, so bail out early.
        return x
    # Position of the leading digit: 0 for 1-9, 1 for 10-99, -1 for 0.1-0.9, ...
    magnitude = int(np.floor(np.log10(abs(x))))
    return round(x, sig - magnitude - 1)
def to_coord(x, y, xmin, xmax, ymin, ymax, plt_xmin, plt_ymin, plt_width, plt_height):
    """
    Map a point from data ranges onto a target rectangle.

    :param x: horizontal position, expressed in the [xmin, xmax] range
    :param y: vertical position, expressed in the [ymin, ymax] range
    :param xmin: lower bound of the horizontal data range
    :param xmax: upper bound of the horizontal data range
    :param ymin: lower bound of the vertical data range
    :param ymax: upper bound of the vertical data range
    :param plt_xmin: horizontal origin of the target rectangle
    :param plt_ymin: vertical origin of the target rectangle
    :param plt_width: width of the target rectangle
    :param plt_height: height of the target rectangle

    :returns: a tuple (x, y) with the rescaled coordinates
    """
    # Relative position of the point inside each data range.
    frac_x = (x - xmin) / (xmax - xmin)
    frac_y = (y - ymin) / (ymax - ymin)
    # Stretch to the rectangle size, then shift to its origin.
    return frac_x * plt_width + plt_xmin, frac_y * plt_height + plt_ymin
def plot_tree(tree, align_names=False, name_offset=None, max_dist=None, font_size=9, axe=None, label_func=None, **kwargs):
    """
    Plots a ete3.Tree object using matplotlib.

    :param tree: ete Tree object. The drawing works on ``tree.copy()``, so the
       branch lengths of the object passed in are left untouched.
    :param False align_names: if True names will be aligned vertically and
       joined to their tips by a black line
    :param None name_offset: offset relative to tips to write leaf names, in
       branch-length scale. Defaults to one hundredth of the longest
       root-to-leaf distance.
    :param None max_dist: if defined any branch longer than the given value will be
       reduced by this same value, and drawn with a break mark and a
       ``+<max_dist>`` label.
    :param 9 font_size: to write text
    :param None axe: a matplotlib Axes object on which the tree will be painted.
       When None, ``plt.subplot(111)`` is used.
    :param None label_func: a function applied to each leaf name to get the
       displayed text, e.g., lambda x: my_names_dict.get(x, x)
    :param kwargs: accepted for call compatibility; the current implementation
       does not forward them to the edge or marker drawing.

    :returns: a dictionary mapping each non-root node (of the internal copy of
       the tree) to the (x, y) coordinates where its branch ends
    """
    if axe is None:
        axe = plt.subplot(111)

    # Segment lists (``*c``) and the matching per-segment node styles.
    vlinec = []
    vlines = []
    hlinec = []
    hlines = []
    # Style and position of every node, used to draw the node markers.
    nodes = []
    nodex = []
    nodey = []
    # Guide lines from each tip to its aligned name.
    ali_lines = []

    def _add_hline(start, end, style):
        # Keep segments and styles in lock-step: LineCollection pairs them by index.
        hlinec.append((start, end))
        hlines.append(style)

    def _draw_edge_plain(c, x):
        # Single horizontal segment from the parent's x to the end of the branch.
        h = node_pos[c]
        _add_hline((x, h), (x + c.dist, h), c._get_style())
        return (x + c.dist, h)

    def _draw_edge_cut(c, x):
        # Same as the plain version, but shortened branches get a break mark.
        h = node_pos[c]
        style = c._get_style()
        if c in cut_edge:
            offset = max_x / 600.
            middle = x + c.dist / 2
            # two halves of the branch, leaving a small gap in the middle
            _add_hline((x, h), (middle - offset, h), style)
            _add_hline((middle + offset, h), (x + c.dist, h), style)
            # two slanted strokes across the gap
            _add_hline((middle, h - 0.05), (middle - 2 * offset, h + 0.05), style)
            _add_hline((middle + 2 * offset, h - 0.05), (middle, h + 0.05), style)
            axe.text(middle, h - 0.07, '+%g' % max_dist, va='top',
                     ha='center', size=2. * font_size / 3)
        else:
            _add_hline((x, h), (x + c.dist, h), style)
        return (x + c.dist, h)

    draw_edge = _draw_edge_plain if max_dist is None else _draw_edge_cut

    # Work on a copy: branch lengths may be shortened below.
    tree = tree.copy()
    # Longest root-to-leaf distance (measured before any shortening); used to
    # align leaf names and to size the break marks.
    max_x = max(n.get_distance(tree) for n in tree.iter_leaves())

    coords = {}
    # Leaves get consecutive integer rows, in reverse leaf order.
    node_pos = dict((leaf, i) for i, leaf in enumerate(tree.get_leaves()[::-1]))

    # reduce branch length
    cut_edge = set()
    if max_dist is not None:
        for n in tree.iter_descendants():
            if n.dist > max_dist:
                n.dist -= max_dist
                cut_edge.add(n)

    if name_offset is None:
        name_offset = max_x / 100.

    # draw tree: postorder guarantees children are placed before their parent
    for n in chain(tree.iter_descendants(strategy='postorder'), [tree]):
        style = n._get_style()
        x = sum(anc.dist for anc in n.iter_ancestors()) + n.dist
        if n.is_leaf():
            name = n.name if label_func is None else label_func(n.name)
            y = node_pos[n]
            if align_names:
                axe.text(max_x + name_offset, y, name,
                         va='center', size=font_size)
                ali_lines.append(((x, y), (max_x + name_offset, y)))
            else:
                axe.text(x + name_offset, y, name,
                         va='center', size=font_size)
        else:
            # internal node sits at the mean row of its children
            y = np.mean([node_pos[child] for child in n.children])
            node_pos[n] = y

            # draw vertical line
            vlinec.append(((x, node_pos[n.children[0]]), (x, node_pos[n.children[-1]])))
            vlines.append(style)

            # draw horizontal lines
            for child in n.children:
                coords[child] = draw_edge(child, x)
        nodes.append(style)
        nodex.append(x)
        nodey.append(y)

    # draw root
    draw_edge(tree, 0)

    # Indexed by the integer line-type stored in the node style.
    lstyles = ['-', '--', ':']
    hline_col = LineCollection(hlinec, colors=[st['hz_line_color'] for st in hlines],
                               linestyle=[lstyles[st['hz_line_type']] for st in hlines],
                               linewidth=[(st['hz_line_width'] + 1.) / 2 for st in hlines])
    vline_col = LineCollection(vlinec, colors=[st['vt_line_color'] for st in vlines],
                               linestyle=[lstyles[st['vt_line_type']] for st in vlines],
                               linewidth=[(st['vt_line_width'] + 1.) / 2 for st in vlines])
    ali_line_col = LineCollection(ali_lines, colors='k')

    # The collections only show up once they are attached to the axes.
    axe.add_collection(hline_col)
    axe.add_collection(vline_col)
    axe.add_collection(ali_line_col)

    # Node markers, one scatter call per shape; style shape names are mapped to
    # matplotlib marker codes, unknown names are passed through as-is.
    nshapes = {'circle': 'o', 'square': 's', 'sphere': 'o'}
    shapes = set(st['shape'] for st in nodes)
    for shape in shapes:
        indexes = [i for i, st in enumerate(nodes) if st['shape'] == shape]
        scat = axe.scatter([nodex[i] for i in indexes],
                           [nodey[i] for i in indexes],
                           s=0, marker=nshapes.get(shape, shape))
        scat.set_sizes([(nodes[i]['size'])**2 / 2 for i in indexes])
        scat.set_color([nodes[i]['fgcolor'] for i in indexes])

    # scale line: a bar about a fifth of the x range, rounded to 1 significant figure
    xmin, xmax = axe.get_xlim()
    ymin, ymax = axe.get_ylim()
    diffy = ymax - ymin
    dist = round_sig((xmax - xmin) / 5, sig=1)
    ymin -= diffy / 100.
    axe.plot([xmin, xmin + dist], [ymin, ymin], color='k')
    # end ticks of the scale bar
    axe.plot([xmin, xmin], [ymin - diffy / 200., ymin + diffy / 200.], color='k')
    axe.plot([xmin + dist, xmin + dist], [ymin - diffy / 200., ymin + diffy / 200.],
             color='k')
    axe.text((xmin + xmin + dist) / 2, ymin - diffy / 200., dist, va='top',
             ha='center', size=font_size)
    return coords
