Skip to content

Instantly share code, notes, and snippets.

@peterk87
Created May 2, 2013 21:41
Show Gist options
  • Save peterk87/5505691 to your computer and use it in GitHub Desktop.
Save peterk87/5505691 to your computer and use it in GitHub Desktop.
Python: hierarchically clustered heatmap using Matplotlib
## {{{ http://code.activestate.com/recipes/578175/ (r1)
### hierarchical_clustering.py
#Copyright 2005-2012 J. David Gladstone Institutes, San Francisco California
#Author Nathan Salomonis - [email protected]
#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is furnished
#to do so, subject to the following conditions:
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
#INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
#PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
#HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
#OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
#SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#################
### Imports a tab-delimited expression matrix and produces a hierarchically clustered heatmap
#################
import matplotlib.pyplot as pylab
from matplotlib import mpl
import scipy
import scipy.cluster.hierarchy as sch
import scipy.spatial.distance as dist
import numpy
import string
import time
import sys, os
import getopt
################# Perform the hierarchical clustering #################
def heatmap(x, row_header, column_header, row_method,
            column_method, row_metric, column_metric,
            color_gradient):#, filename):
    """
    Hierarchically cluster a matrix and show it as a heatmap with dendrograms.

    This code is based in large part on the prototype methods:
    http://old.nabble.com/How-to-plot-heatmap-with-matplotlib--td32534593.html
    http://stackoverflow.com/questions/7664826/how-to-get-flat-clustering-corresponding-to-color-clusters-in-the-dendrogram-cre

    x is an m by n ndarray, m observations, n genes

    Parameters:
        x              2-D numpy array holding the values to draw.
        row_header     labels for the rows of x (one per row).
        column_header  labels for the columns of x (one per column).
        row_method     scipy linkage method for rows ('average', 'single',
                       'centroid', 'complete', ...), or None to leave the
                       rows in their input order.
        column_method  same as row_method, for the columns.
        row_metric     scipy distance metric used between rows.
        column_metric  scipy distance metric used between columns.
        color_gradient one of 'red_white_blue', 'red_black_sky',
                       'red_black_blue', 'red_black_green',
                       'yellow_black_blue', 'seismic',
                       'green_white_purple', 'coolwarm'.

    Returns None; the figure is displayed with pylab.show().

    Raises ValueError when color_gradient is not one of the names above.
    """
    # Imported locally so this function does not depend on the legacy
    # ``matplotlib.mpl`` alias.
    import matplotlib.colors as mcolors
    import matplotlib.colorbar as mcolorbar

    print("\nPerforming hiearchical clustering using %s for columns and %s for rows" % (column_metric,row_metric))

    ### Define the color gradient to use based on the provided name
    # Values are callables so that only the selected colormap is built.
    gradients = {
        'red_white_blue': lambda: pylab.cm.bwr,
        'red_black_sky': RedBlackSkyBlue,
        'red_black_blue': RedBlackBlue,
        'red_black_green': RedBlackGreen,
        'yellow_black_blue': YellowBlackBlue,
        'seismic': lambda: pylab.cm.seismic,
        'green_white_purple': lambda: pylab.cm.PiYG_r,
        'coolwarm': lambda: pylab.cm.coolwarm,
    }
    if color_gradient not in gradients:
        # Fail before any clustering work instead of hitting an unbound
        # ``cmap`` when the matrix is finally drawn.
        raise ValueError("Unknown color_gradient %r; expected one of: %s"
                         % (color_gradient, ', '.join(sorted(gradients))))
    cmap = gradients[color_gradient]()

    ### Scale the max and min colors so that 0 is white/black
    vmax = max([x.max(), abs(x.min())])
    vmin = vmax*-1
    norm = mcolors.Normalize(vmin/2, vmax/2) ### adjust the max and min to scale these colors

    ### Scale the Matplotlib window size
    default_window_height = 8.5
    default_window_width = 12
    fig = pylab.figure(figsize=(default_window_width, default_window_height)) ### could use the matrix size to scale here
    color_bar_w = 0.015 ### Sufficient size to show

    ## calculate positions for all elements (figure-fraction coordinates)
    gap = 0.004 ### spacing between neighbouring axes

    # ax1, placement of dendrogram 1, on the left of the heatmap
    ### The second value controls the position of the matrix relative to the bottom of the view
    ax1_x, ax1_y, ax1_w, ax1_h = 0.05, 0.22, 0.2, 0.6

    # axr, placement of row side colorbar (directly right of ax1)
    axr_x = ax1_x + ax1_w + gap
    axr_y = ax1_y
    axr_w = color_bar_w
    axr_h = ax1_h

    # axc, placement of column side colorbar (directly above the matrix)
    axc_x = axr_x + axr_w + gap
    axc_y = ax1_y + ax1_h + gap
    axc_w = 0.5
    axc_h = color_bar_w

    # axm, placement of heatmap for the data matrix
    axm_x = axr_x + axr_w + gap
    axm_y = ax1_y
    axm_w = axc_w
    axm_h = ax1_h

    # ax2, placement of dendrogram 2, on the top of the heatmap
    ax2_x = axr_x + axr_w + gap
    ax2_y = ax1_y + ax1_h + gap + axc_h + gap
    ax2_w = axc_w
    ax2_h = 0.15 ### controls height of the dendrogram

    # axcb - placement of the color legend
    axcb_x, axcb_y, axcb_w, axcb_h = 0.07, 0.88, 0.18, 0.09

    # Compute and plot top dendrogram
    if column_method is not None:
        start_time = time.time()
        # Cluster on the condensed distance vector computed with the requested
        # metric.  Passing the square-form matrix to linkage() would make scipy
        # treat each row of the distance matrix as an observation.
        d2 = dist.pdist(x.T, metric=column_metric)
        # The new axes become the current axes, which is where dendrogram() draws.
        ax2 = fig.add_axes([ax2_x, ax2_y, ax2_w, ax2_h], frame_on=True)
        Y2 = sch.linkage(d2, method=column_method) ### array-clustering method - 'average', 'single', 'centroid', 'complete'
        Z2 = sch.dendrogram(Y2)
        ind2 = sch.fcluster(Y2, 0.7*max(Y2[:,2]), 'distance') ### This is the default behavior of dendrogram
        ax2.set_xticks([]) ### Hides ticks
        ax2.set_yticks([])
        time_diff = str(round(time.time()-start_time,1))
        print('Column clustering completed in %s seconds' % time_diff)
    else:
        ind2 = ['NA']*len(column_header) ### Used for exporting the flat cluster data

    # Compute and plot left dendrogram.
    if row_method is not None:
        start_time = time.time()
        d1 = dist.pdist(x, metric=row_metric)
        ax1 = fig.add_axes([ax1_x, ax1_y, ax1_w, ax1_h], frame_on=True) # frame_on may be False
        Y1 = sch.linkage(d1, method=row_method) ### gene-clustering method - 'average', 'single', 'centroid', 'complete'
        Z1 = sch.dendrogram(Y1, orientation='right')
        ind1 = sch.fcluster(Y1, 0.7*max(Y1[:,2]), 'distance') ### This is the default behavior of dendrogram
        ax1.set_xticks([]) ### Hides ticks
        ax1.set_yticks([])
        time_diff = str(round(time.time()-start_time,1))
        print('Row clustering completed in %s seconds' % time_diff)
    else:
        ind1 = ['NA']*len(row_header) ### Used for exporting the flat cluster data

    # Plot distance matrix.
    axm = fig.add_axes([axm_x, axm_y, axm_w, axm_h]) # axes for the data matrix
    xt = x
    if column_method is not None:
        idx2 = Z2['leaves'] ### apply the clustering for the array-dendrograms to the actual matrix data
        xt = xt[:,idx2]
        # fcluster() returns a 1-D array, so a single index reorders it.
        ind2 = ind2[idx2] ### reorder the flat cluster to match the order of the leaves the dendrogram
    if row_method is not None:
        idx1 = Z1['leaves'] ### apply the clustering for the gene-dendrograms to the actual matrix data
        xt = xt[idx1,:] # xt is transformed x
        ind1 = ind1[idx1] ### reorder the flat cluster to match the order of the leaves the dendrogram
    ### taken from http://stackoverflow.com/questions/2982929/plotting-results-of-hierarchical-clustering-ontop-of-a-matrix-of-data-in-python/3011894#3011894
    axm.matshow(xt, aspect='auto', origin='lower', cmap=cmap, norm=norm) ### norm=norm added to scale coloring of expression with zero = white or black
    axm.set_xticks([]) ### Hides x-ticks
    axm.set_yticks([])

    # Add text
    n_rows, n_cols = x.shape[0], x.shape[1]
    if row_method is not None:
        new_row_header = [row_header[idx1[i]] for i in range(n_rows)]
    else: ### When not clustering rows
        new_row_header = [row_header[i] for i in range(n_rows)]
    if len(row_header) < 100: ### Don't visualize gene associations when more than 100 rows
        for i, label in enumerate(new_row_header):
            axm.text(n_cols-0.5, i, ' '+label)

    if column_method is not None:
        new_column_header = [column_header[idx2[i]] for i in range(n_cols)]
    else: ### When not clustering columns
        new_column_header = [column_header[i] for i in range(n_cols)]
    for i, label in enumerate(new_column_header):
        axm.text(i, -0.9, ' '+label, rotation=270, verticalalignment="top") # rotation could also be degrees

    # Plot colside colors
    # axc --> axes for column side colorbar
    if column_method is not None:
        axc = fig.add_axes([axc_x, axc_y, axc_w, axc_h]) # axes for column side colorbar
        cmap_c = mcolors.ListedColormap(['r', 'g', 'b', 'y', 'w', 'k', 'm'])
        dc = numpy.array(ind2, dtype=int).reshape(1, len(ind2))
        axc.matshow(dc, aspect='auto', origin='lower', cmap=cmap_c)
        axc.set_xticks([]) ### Hides ticks
        axc.set_yticks([])

    # Plot rowside colors
    # axr --> axes for row side colorbar
    if row_method is not None:
        axr = fig.add_axes([axr_x, axr_y, axr_w, axr_h]) # axes for row side colorbar
        dr = numpy.array(ind1, dtype=int).reshape(len(ind1), 1)
        cmap_r = mcolors.ListedColormap(['r', 'g', 'b', 'y', 'w', 'k', 'm'])
        axr.matshow(dr, aspect='auto', origin='lower', cmap=cmap_r)
        axr.set_xticks([]) ### Hides ticks
        axr.set_yticks([])

    # Plot color legend
    axcb = fig.add_axes([axcb_x, axcb_y, axcb_w, axcb_h], frame_on=False) # axes for colorbar
    cb = mcolorbar.ColorbarBase(axcb, cmap=cmap, norm=norm, orientation='horizontal')
    axcb.set_title("colorkey")

    # Export / save-to-file step, disabled in this version (needs the
    # commented-out ``filename`` parameter):
    # if '/' in filename:
    #     dataset_name = string.split(filename,'/')[-1][:-4]
    #     root_dir = string.join(string.split(filename,'/')[:-1],'/')+'/'
    # else:
    #     dataset_name = string.split(filename,'\\')[-1][:-4]
    #     root_dir = string.join(string.split(filename,'\\')[:-1],'\\')+'\\'
    # filename = root_dir+'Clustering-%s-hierarchical_%s_%s.pdf' % (dataset_name,column_metric,row_metric)
    # cb.set_label("Differential Expression (log2 fold)")
    # exportFlatClusterData(filename, new_row_header,new_column_header,xt,ind1,ind2)
    # ### Render the graphic
    # if len(row_header)>50 or len(column_header)>50:
    #     pylab.rcParams['font.size'] = 5
    # else:
    #     pylab.rcParams['font.size'] = 8
    # pylab.savefig(filename)
    # print 'Exporting:',filename
    # filename = filename[:-3]+'png'
    # pylab.savefig(filename, dpi=100) #,dpi=200
    pylab.show()
def getColorRange(x):
    """
    Return (vmax, vmin) for normalizing a colormap over the values of x.

    When every value is strictly positive or every value is strictly
    negative, the actual maximum and minimum are returned.  Otherwise the
    range is made symmetric about zero using the largest absolute extreme.
    x must provide .max() and .min().
    """
    highest = x.max()
    lowest = x.min()
    all_negative = highest < 0 and lowest < 0
    all_positive = highest > 0 and lowest > 0
    if all_negative or all_positive:
        # One-sided data: keep the real extremes.
        return highest, lowest
    # Data touches or spans zero: centre the range on zero.
    bound = max([highest, abs(lowest)])
    return bound, -1*bound
################# Export the flat cluster data #################
def exportFlatClusterData(filename, new_row_header,new_column_header,xt,ind1,ind2):
    """
    Export the clustered results as text files, only indicating the
    flat-clusters rather than the tree.

    Two tab-delimited files are written next to ``filename``, using its path
    with the extension replaced:
      * ``.txt`` - a header line, a line of column flat-cluster ids, then one
        line per row: row name, row flat-cluster id, row values.
      * ``.cdt`` - UNIQID/NAME/GWEIGHT header, an EWEIGHT line of 1s, then one
        line per row: row name (twice), weight 1, row values.

    Parameters:
        filename           path whose extension is swapped for .txt / .cdt
                           (the caller passes the .pdf figure name).
        new_row_header     row labels in the clustered (display) order.
        new_column_header  column labels in the clustered (display) order.
        xt                 the clustered matrix; xt[i] belongs to new_row_header[i].
        ind1               flat-cluster id per row, aligned with xt.
        ind2               flat-cluster id per column, aligned with the columns.
    """
    # Derive both outputs from the extension-less path, so a name without
    # '.pdf' can no longer make the .cdt file overwrite the .txt file.
    root = os.path.splitext(filename)[0]
    txt_filename = root + '.txt'
    cdt_filename = root + '.cdt'

    column_names = list(new_column_header)

    ### The clusters, dendrogram and flat clusters are drawn bottom-up, so we need to reverse the order to match
    # The row cluster ids are reversed together with the labels and the data so
    # each exported row keeps its own cluster id.
    row_names = new_row_header[::-1]
    row_clusters = ind1[::-1]
    rows = xt[::-1]

    with open(txt_filename, 'w') as export_text:
        ### format column-names for export
        export_text.write('\t'.join(['UID', 'row_clusters-flat'] + column_names) + '\n')
        ### format column-flat-clusters for export
        export_text.write('\t'.join(['column_clusters-flat', ''] + [str(c) for c in ind2]) + '\n')
        ### Export each row in the clustered data matrix xt
        for name, cluster, row in zip(row_names, row_clusters, rows):
            export_text.write('\t'.join([name, str(cluster)] + [str(v) for v in row]) + '\n')

    ### Export as CDT file
    with open(cdt_filename, 'w') as export_cdt:
        ### format column-names for export
        export_cdt.write('\t'.join(['UNIQID', 'NAME', 'GWEIGHT'] + column_names) + '\n')
        ### every column gets an experiment weight of 1
        export_cdt.write('\t'.join(['EWEIGHT', '', ''] + ['1']*len(column_names)) + '\n')
        ### Export each row in the clustered data matrix xt
        for name, row in zip(row_names, rows):
            export_cdt.write('\t'.join([name, name, '1'] + [str(v) for v in row]) + '\n')
################# Create Custom Color Gradients #################
#http://matplotlib.sourceforge.net/examples/pylab_examples/custom_cmap.html
def RedBlackSkyBlue():
    """
    Build a 256-level colormap running from sky blue through near-black to red.

    Each channel entry is (x, y0, y1): at position x the channel equals y0
    when approached from below and y1 when leaving upwards.
    """
    # Imported here because the module-level imports never bind the name
    # ``matplotlib`` itself, so ``matplotlib.colors`` was a NameError.
    from matplotlib.colors import LinearSegmentedColormap

    cdict = {'red':   ((0.0, 0.0, 0.0),
                       (0.5, 0.0, 0.1),
                       (1.0, 1.0, 1.0)),
             'green': ((0.0, 0.0, 0.9),
                       (0.5, 0.1, 0.0),
                       (1.0, 0.0, 0.0)),
             'blue':  ((0.0, 0.0, 1.0),
                       (0.5, 0.1, 0.0),
                       (1.0, 0.0, 0.0))
             }
    return LinearSegmentedColormap('my_colormap', cdict, 256)
def RedBlackBlue():
    """
    Build a 256-level colormap running from blue through near-black to red.

    Each channel entry is (x, y0, y1): at position x the channel equals y0
    when approached from below and y1 when leaving upwards.  The green
    channel stays at zero throughout.
    """
    # Imported here because the module-level imports never bind the name
    # ``matplotlib`` itself, so ``matplotlib.colors`` was a NameError.
    from matplotlib.colors import LinearSegmentedColormap

    cdict = {'red':   ((0.0, 0.0, 0.0),
                       (0.5, 0.0, 0.1),
                       (1.0, 1.0, 1.0)),
             'green': ((0.0, 0.0, 0.0),
                       (1.0, 0.0, 0.0)),
             'blue':  ((0.0, 0.0, 1.0),
                       (0.5, 0.1, 0.0),
                       (1.0, 0.0, 0.0))
             }
    return LinearSegmentedColormap('my_colormap', cdict, 256)
def RedBlackGreen():
    """
    Build a 256-level colormap running from green through near-black to red.

    Each channel entry is (x, y0, y1): at position x the channel equals y0
    when approached from below and y1 when leaving upwards.  The blue
    channel stays at zero throughout.
    """
    # Imported here because the module-level imports never bind the name
    # ``matplotlib`` itself, so ``matplotlib.colors`` was a NameError.
    from matplotlib.colors import LinearSegmentedColormap

    cdict = {'red':   ((0.0, 0.0, 0.0),
                       (0.5, 0.0, 0.1),
                       (1.0, 1.0, 1.0)),
             'blue':  ((0.0, 0.0, 0.0),
                       (1.0, 0.0, 0.0)),
             'green': ((0.0, 0.0, 1.0),
                       (0.5, 0.1, 0.0),
                       (1.0, 0.0, 0.0))
             }
    return LinearSegmentedColormap('my_colormap', cdict, 256)
def YellowBlackBlue():
    """
    Build a 256-level colormap running from blue/cyan through near-black to yellow.

    Each channel entry is (x, y0, y1): at position x the channel equals y0
    when approached from below and y1 when leaving upwards.
    """
    # Imported here because the module-level imports never bind the name
    # ``matplotlib`` itself, so ``matplotlib.colors`` was a NameError.
    from matplotlib.colors import LinearSegmentedColormap

    cdict = {'red':   ((0.0, 0.0, 0.0),
                       (0.5, 0.0, 0.1),
                       (1.0, 1.0, 1.0)),
             'green': ((0.0, 0.0, 0.8),
                       (0.5, 0.1, 0.0),
                       (1.0, 1.0, 1.0)),
             'blue':  ((0.0, 0.0, 1.0),
                       (0.5, 0.1, 0.0),
                       (1.0, 0.0, 0.0))
             }
    ### yellow is created by adding y = 1 to RedBlackSkyBlue green last tuple
    ### modulate between blue and cyan using the last y var in the first green tuple
    return LinearSegmentedColormap('my_colormap', cdict, 256)
################# General data import methods #################
def importData(filename):
    """
    Read a tab-delimited expression matrix from ``filename``.

    The first line is the header; its first field is ignored and the rest
    become the column labels.  Every following line is a row label followed
    by numeric values.  Rows are skipped when any field is empty or a single
    space (missing data), when they carry no values, or when all their
    values are identical.

    Returns (matrix, column_header, row_header), where matrix is a numpy
    array of floats with one row per kept line.

    Raises ValueError when the file contains no lines at all, or when a
    value cannot be converted to float.
    """
    start_time = time.time()
    matrix = []
    row_header = []
    column_header = None

    # The dataset name is the file name without its 4-character extension.
    separator = '/' if '/' in filename else '\\'
    dataset_name = filename.split(separator)[-1][:-4]

    # Python 2 needs 'U' for universal newlines; Python 3 text mode already
    # provides them and rejects the flag from 3.11 on.
    mode = 'rU' if sys.version_info[0] < 3 else 'r'
    with open(filename, mode) as handle:
        for line in handle:
            ### remove end-of-line character - file is tab-delimited
            # rstrip() rather than line[:-1], so a final line without a
            # newline does not lose its last data character.
            t = line.rstrip('\r\n').split('\t')
            if column_header is None:
                column_header = t[1:]
                continue
            if ' ' in t or '' in t: ### Occurs for rows with missing data
                continue
            s = [float(value) for value in t[1:]]
            # Keep only rows that have values and are not constant.
            if s and abs(max(s)-min(s)) > 0:
                matrix.append(s)
                row_header.append(t[0])

    if column_header is None:
        # Not even a header line was read.
        raise ValueError('No data in input file.')

    time_diff = str(round(time.time()-start_time,1))
    print('\n%d rows and %d columns imported for %s in %s seconds...' % (len(matrix),len(column_header),dataset_name,time_diff))
    return numpy.array(matrix), column_header, row_header
# if __name__ == '__main__':
# ################ Default Methods ################
# Default clustering settings, named as in the (commented-out) example
# heatmap() call at the end of this file.
row_method, row_metric = 'average', 'cityblock'  # 'cosine' is an alternative row metric
column_method, column_metric = 'single', 'euclidean'
color_gradient = 'red_white_blue'
# """ Running with cosine or other distance metrics can often produce negative Z scores
# during clustering, so adjustments to the clustering may be required.
# see: http://docs.scipy.org/doc/scipy/reference/cluster.hierarchy.html
# see: http://docs.scipy.org/doc/scipy/reference/spatial.distance.html
# color_gradient = red_white_blue|red_black_sky|red_black_blue|red_black_green|yellow_black_blue|green_white_purple'
# """
# ################ Command-line arguments ################
# if len(sys.argv[1:])<=1: ### Indicates that there are insufficient number of command-line arguments
# print "Warning! Please designate a tab-delimited input expression file in the command-line"
# print "Example: python hierarchical_clustering.py --i /Users/me/logfolds.txt"
# sys.exit()
# else:
# options, remainder = getopt.getopt(sys.argv[1:],'', ['i=','row_header','column_method',
# 'row_metric','column_metric','color_gradient'])
# for opt, arg in options:
# if opt == '--i': filename=arg
# elif opt == '--row_header': row_header=arg
# elif opt == '--column_method': column_method=arg
# elif opt == '--row_metric': row_metric=arg
# elif opt == '--column_metric': column_metric=arg
# elif opt == '--color_gradient': color_gradient=arg
# else:
# print "Warning! Command-line argument: %s not recognized. Exiting..." % opt; sys.exit()
# matrix, column_header, row_header = importData(filename)
# if len(matrix)>0:
# try:
# heatmap(matrix, row_header, column_header, row_method, column_method, row_metric, column_metric, color_gradient, filename)
# except Exception:
# print 'Error using %s ... trying euclidean instead' % row_metric
# row_metric = 'euclidean'
# try:
# heatmap(matrix, row_header, column_header, row_method, column_method, row_metric, column_metric, color_gradient, filename)
# except IOError:
# print 'Error with clustering encountered'
## end of http://code.activestate.com/recipes/578175/ }}}
# read in a matrix using above importData function
# matrix, column_header, row_header = importData('matrix_file.txt')
# create biclustered heatmap from imported matrix
# heatmap(matrix, row_header, column_header, row_method, column_method, row_metric, column_metric, color_gradient)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment