Skip to content

Instantly share code, notes, and snippets.

@olgabot
Created October 16, 2014 20:25
Show Gist options
  • Save olgabot/2e3e7eed578e702fa091 to your computer and use it in GitHub Desktop.
Save olgabot/2e3e7eed578e702fa091 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"metadata": {
"name": "",
"signature": "sha256:be61b723a8bb38e4aad32e1f06f98a25c9a73cbe7ba4dd39a31e4390580d42e0"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": [
"%matplotlib inline"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"\n",
"from seaborn.matrix import _HeatMapper\n"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 2
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from itertools import chain\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from seaborn.utils import axis_ticklabels_overlap, despine\n",
"from seaborn.axisgrid import Grid\n",
"\n",
"class _DendrogramPlotter(object):\n",
" \"\"\"Plotting object for drawing a tree-diagram of the relationships between columns of data\"\"\"\n",
" \n",
" def __init__(self, data, linkage=None, use_fastcluster=False, metric='euclidean', method='average', axis=1, ax=None):\n",
" \"\"\"Plot a dendrogram of the relationships between the columns of data\n",
" \n",
" Parameters\n",
" ----------\n",
" data : pandas.DataFrame\n",
" Rectangular data\n",
" \n",
" \"\"\"\n",
" if axis == 1:\n",
" data = data.T\n",
"\n",
" if isinstance(data, pd.DataFrame):\n",
" plot_data = data.values\n",
" else:\n",
" plot_data = np.asarray(data)\n",
" data = pd.DataFrame(plot_data)\n",
" \n",
" self.plot_data = plot_data\n",
" self.data = data\n",
" \n",
" self.shape = self.data.shape\n",
" self.use_fastcluster = use_fastcluster\n",
" self.metric = metric\n",
" self.method = method\n",
" self.axis = axis\n",
" \n",
" if ax is None:\n",
" ax = plt.gca()\n",
" self.ax = ax\n",
" \n",
" \n",
" if linkage is None:\n",
" self.linkage = self.calculated_linkage\n",
" else:\n",
" self.linkage = linkage\n",
" \n",
" @property\n",
" def calculated_linkage(self): \n",
"\n",
" if self.use_fastcluster:\n",
" return self.linkage_function(self.array, method=self.method,\n",
" metric=self.metric)\n",
" else:\n",
" from scipy.spatial import distance\n",
"\n",
" pairwise_dists = distance.squareform(\n",
" distance.pdist(self.data, metric=self.metric))\n",
" return self.linkage_function(pairwise_dists, method=self.method)\n",
"\n",
" @property\n",
" def linkage_function(self):\n",
" \"\"\"\n",
" Notes\n",
" -----\n",
" If the product of the number of rows and cols exceeds\n",
" 10000, this will try to import fastcluster, and raise a warning if it\n",
" does not exist. Vanilla scipy.cluster.hierarchy.linkage will take a\n",
" long time on these matrices.\n",
" \"\"\"\n",
" if np.product(self.shape) >= 10000 or self.use_fastcluster:\n",
" try:\n",
" import fastcluster\n",
"\n",
" self.use_fastcluster = True\n",
" return fastcluster.linkage_vector\n",
" except ImportError:\n",
" raise warnings.warn(\n",
" 'Module \"fastcluster\" not found. The dataframe provided '\n",
" 'has shape {}, and one of the dimensions has greater than '\n",
" '1000 variables. Calculating linkage on such a matrix will'\n",
" ' take a long time with vanilla '\n",
" '\"scipy.cluster.hierarchy.linkage\", and we suggest '\n",
" 'fastcluster for such large datasets'.format(shape),\n",
" RuntimeWarning)\n",
" else:\n",
" import scipy.cluster.hierarchy as sch\n",
"\n",
" return sch.linkage\n",
" \n",
" @property\n",
" def dendrogram(self):\n",
" \"\"\"Calculates a dendrogram based on the linkage matrix\n",
"\n",
" Parameters\n",
" ----------\n",
" kws : dict\n",
" Keyword arguments for column or row plotting passed to clusterplot\n",
" linkage : numpy.array\n",
" Linkage matrix, usually created by scipy.cluster.hierarchy.linkage\n",
" orientation : str\n",
" (docstring stolen from scipy.cluster.hierarchy.dendrogram)\n",
" The direction to plot the dendrogram, which can be any\n",
" of the following strings:\n",
"\n",
" 'top' plots the root at the top, and plot descendent\n",
" links going downwards. (default).\n",
"\n",
" 'bottom'- plots the root at the bottom, and plot descendent\n",
" links going upwards.\n",
"\n",
" 'left'- plots the root at the left, and plot descendent\n",
" links going right.\n",
"\n",
" 'right'- plots the root at the right, and plot descendent\n",
" links going left.\n",
"\n",
" Returns\n",
" -------\n",
" dendrogram : dict\n",
" Dendrogram dictionary as returned by scipy.cluster.hierarchy\n",
" .dendrogram. The important key-value pairing is \"leaves\" which\n",
" tells the ordering of the matrix\n",
" \"\"\"\n",
" import scipy.cluster.hierarchy as sch\n",
"\n",
" sch.set_link_color_palette(['k'])\n",
"\n",
"# if kws['cluster']:\n",
"# dendrogram = \n",
"# else:\n",
"# dendrogram = {'leaves': list(range(self.linkage.shape[0] + 1))}\n",
" return sch.dendrogram(self.linkage, no_plot=True, color_list=['k'], color_threshold=-np.inf)\n",
"\n",
" @property\n",
" def reordered_ind(self):\n",
" return self.dendrogram['leaves']\n",
" \n",
" def plot(self, label=False):\n",
" \"\"\"Plots a dendrogram on the figure at the gridspec location using\n",
" the linkage matrix\n",
"\n",
" Both the computation and plotting must be in this same function because\n",
" scipy.cluster.hierarchy.dendrogram does ax = plt.gca() and cannot be\n",
" specified its own ax object.\n",
"\n",
" Parameters\n",
" ----------\n",
" ax : matplotlib.axes.Axes\n",
" Axes object upon which the dendrogram is plotted\n",
" \"\"\"\n",
" if self.axis == 0:\n",
" X = self.dendrogram['dcoord']\n",
" Y = self.dendrogram['icoord']\n",
" else:\n",
" X = self.dendrogram['icoord']\n",
" Y = self.dendrogram['dcoord']\n",
"\n",
" for x, y in zip(X, Y):\n",
" self.ax.plot(x, y, color='k', linewidth=0.5)\n",
"\n",
" if self.axis == 0:\n",
" self.ax.invert_xaxis()\n",
" ymax = min(map(min, Y)) + max(map(max, Y))\n",
" self.ax.set_ylim(0, ymax)\n",
" else:\n",
" xmax = min(map(min, X)) + max(map(max, X))\n",
" self.ax.set_xlim(0, xmax)\n",
"\n",
" despine(ax=self.ax, bottom=True, left=True)\n",
"# self.ax.set_axis_bgcolor('white')\n",
"# self.ax.grid(False)\n",
" \n",
" if label and self.axis == 0:\n",
" # Get the ytick locations out of the y-values\n",
" yticks = 10*np.arange(self.data.shape[0])+5\n",
" self.ax.set_yticks(range(yticks))\n",
" ytl = self.ax.set_yticklabels(self.data.index[self.reordered_ind])\n",
" if axis_ticklabels_overlap(ytl):\n",
" plt.setp(ytl, rotation=\"vertical\")\n",
" self.ax.set_xticks([]) \n",
" elif not label and self.axis == 0:\n",
" self.ax.set_yticks([])\n",
" elif label and self.axis == 1:\n",
" # Get the xtick locations out of the x-values\n",
" xticks = 10*np.arange(self.data.shape[0])+5\n",
" self.ax.set_xticks(xticks)\n",
" xtl = self.ax.set_xticklabels(self.data.index[self.reordered_ind])\n",
" if axis_ticklabels_overlap(xtl):\n",
" plt.setp(xtl, rotation=\"horizontal\")\n",
" self.ax.set_yticks([])\n",
" elif not label and self.axis == 1:\n",
" self.ax.set_xticks([])\n",
" \n",
" return self\n",
"\n",
"\n",
"def dendrogramplot(data, linkage=None, use_fastcluster=False, axis=1, ax=None,\n",
" label=True, metric='euclidean', method='single', linkage_kws=None,\n",
" rotate=False):\n",
" \"\"\"Draw a tree diagram of the similarities within the rows or columns of a matrix\n",
" \n",
" data : pandas.DataFrame\n",
" Rectangular data\n",
" linkage : numpy.array, optional\n",
" Linkage matrix\n",
" use_fastcluster : bool, default False\n",
" Whether or not to use the \"fastcluster\" package to calculate linkage\n",
" axis : int, default 1\n",
" Which axis to use to calculate linkage \n",
" \n",
" \n",
" \"\"\"\n",
" \n",
" p = _DendrogramPlotter(data, linkage=linkage, \n",
" use_fastcluster=use_fastcluster, \n",
" axis=axis, ax=ax, metric=metric, method=method)\n",
" p.plot(label=label)\n",
" \n",
" return p\n",
"\n",
"class DendrogramGrid(Grid):\n",
" \n",
" def __init__(self, data, fig):\n",
" if fig is None:\n",
" if figsize is None:\n",
" # width = min(self.data2d.shape[1] * 0.5, 40)\n",
" # height = min(self.data2d.shape[0] * 0.5, 40)\n",
" width, height = 10, 10\n",
" figsize = (width, height)\n",
" fig = plt.figure(figsize=figsize)\n",
"\n",
" self.fig = fig\n",
" width_ratios = self.get_fig_width_ratios(self.row_kws['side_colors'],\n",
" figsize=figsize,\n",
" # colorbar_kws['loc'],\n",
" dimension='width')\n",
"\n",
" height_ratios = self.get_fig_width_ratios(self.col_kws['side_colors'],\n",
" figsize=figsize,\n",
" dimension='height')\n",
" # nrows = 3 if self.col_kws['side_colors'] is None else 4\n",
" # ncols = 2 if self.row_kws['side_colors'] is None else 3\n",
" nrows = 3 if self.col_kws['side_colors'] is None else 4\n",
" ncols = 3 if self.row_kws['side_colors'] is None else 4\n",
"\n",
" self.gs = gridspec.GridSpec(nrows, ncols, wspace=0.01, hspace=0.01,\n",
" width_ratios=width_ratios,\n",
" height_ratios=height_ratios)\n",
"\n",
" # self.row_dendrogram_ax = self.fig.add_subplot(self.gs[nrows-1, 0])\n",
" # self.col_dendrogram_ax = self.fig.add_subplot(self.gs[0:2, ncols-1])\n",
" self.ax_row_dendrogram = self.fig.add_subplot(self.gs[nrows-1, 0:2])\n",
" self.ax_col_dendrogram = self.fig.add_subplot(self.gs[0:2, ncols-1])\n",
"\n",
" self.ax_row_side_colors = None\n",
" self.ax_col_side_colors = None\n",
"\n",
" if self.col_kws['side_colors'] is not None:\n",
" self.col_side_colors_ax = self.fig.add_subplot(\n",
" self.gs[nrows - 2, ncols - 1])\n",
" if self.row_kws['side_colors'] is not None:\n",
" self.row_side_colors_ax = self.fig.add_subplot(\n",
" self.gs[nrows - 1, ncols - 2])\n",
"\n",
" self.ax_heatmap = self.fig.add_subplot(self.gs[nrows - 1, ncols - 1])\n",
" # self.heatmap_ax = self.fig.add_subplot(self.gs[nrows-2, ncols-2])\n",
"\n",
" # colorbar for scale to right of heatmap\n",
" self.ax_colorbar = self.fig.add_subplot(self.gs[0, 0])\n",
" # self.colorbar_ax = self.fig.add_subplot(self.gs[nrows-1, ncols-1]) \n",
" \n",
" def savefig(self, *args, **kwargs):\n",
" if 'bbox_inches' not in kwargs:\n",
" kwargs['bbox_inches'] = 'tight'\n",
" self.fig.savefig(*args, **kwargs)\n",
" \n",
" def plot_dendrograms(self, kws):\n",
" \n",
" # Plot the column dendrogram\n",
" self.dendrogram_col = dendrogramplot(self.data2d, use_fastcluster=self.use_fastcluster, \n",
" metric=self.metric, method=self.method, label=False,\n",
" axis=1, ax=self.ax_col_dendrogram)\n",
" \n",
" # Plot the row dendrogram\n",
" self.dendrogram_row = dendrogramplot(self.data2d, use_fastcluster=self.use_fastcluster, \n",
" metric=self.metric, method=self.method, label=False,\n",
" axis=0, ax=self.ax_row_dendrogram)\n",
" pass\n",
" \n",
" def plot_sidecolors(self, kws):\n",
" pass\n",
" \n",
" \n",
" def plot_matrix(self, ax, cax, kws):\n",
" super(_ClusteredHeatMapper, self).plot(ax, cax, kws)\n",
" \n",
"\n",
"class _ClusteredHeatMapper(_HeatMapper):\n",
" pass\n",
"\n"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 3
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"%pdb"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"Automatic pdb calling has been turned ON\n"
]
}
],
"prompt_number": 4
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"import matplotlib as mpl"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 5
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"mpl.__version__"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 6,
"text": [
"u'1.4.0'"
]
}
],
"prompt_number": 6
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"data = pd.DataFrame(np.random.randn(200).reshape(10, 20))\n",
"\n",
"dendrogramplot(data)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Cannot get window extent w/o renderer",
"output_type": "pyerr",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-7-1eba80a04260>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataFrame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m200\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mdendrogramplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-3-8aa9b5a76ea8>\u001b[0m in \u001b[0;36mdendrogramplot\u001b[0;34m(data, linkage, use_fastcluster, axis, ax, label, metric, method, linkage_kws, rotate)\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[0muse_fastcluster\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0muse_fastcluster\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 217\u001b[0m axis=axis, ax=ax, metric=metric, method=method)\n\u001b[0;32m--> 218\u001b[0;31m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlabel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 219\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 220\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-3-8aa9b5a76ea8>\u001b[0m in \u001b[0;36mplot\u001b[0;34m(self, label)\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0max\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_xticks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxticks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 188\u001b[0m \u001b[0mxtl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0max\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_xticklabels\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreordered_ind\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 189\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0maxis_ticklabels_overlap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxtl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 190\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxtl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrotation\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"horizontal\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0max\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_yticks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/Users/olga/workspace-git/seaborn/seaborn/utils.pyc\u001b[0m in \u001b[0;36maxis_ticklabels_overlap\u001b[0;34m(labels)\u001b[0m\n\u001b[1;32m 372\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 373\u001b[0m \"\"\"\n\u001b[0;32m--> 374\u001b[0;31m \u001b[0mbboxes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_window_extent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ml\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 375\u001b[0m \u001b[0moverlaps\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcount_overlaps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbboxes\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbboxes\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 376\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moverlaps\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python2.7/site-packages/matplotlib/text.pyc\u001b[0m in \u001b[0;36mget_window_extent\u001b[0;34m(self, renderer, dpi)\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_renderer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 738\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_renderer\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 739\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Cannot get window extent w/o renderer'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 740\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 741\u001b[0m \u001b[0mbbox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdescent\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_layout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_renderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: Cannot get window extent w/o renderer"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
"> \u001b[0;32m/usr/local/lib/python2.7/site-packages/matplotlib/text.py\u001b[0m(739)\u001b[0;36mget_window_extent\u001b[0;34m()\u001b[0m\n",
"\u001b[0;32m 738 \u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_renderer\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m--> 739 \u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Cannot get window extent w/o renderer'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\u001b[0;32m 740 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"stream": "stdout",
"text": [
"ipdb> q\n"
]
},
{
"output_type": "stream",
"stream": "stderr",
"text": [
"/usr/local/lib/python2.7/site-packages/matplotlib/figure.py:1644: UserWarning: This figure includes Axes that are not compatible with tight_layout, so its results might be incorrect.\n",
" warnings.warn(\"This figure includes Axes that are not \"\n"
]
},
{
"metadata": {},
"output_type": "display_data",
"png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGupJREFUeJzt3XtUlHUex/EPlxCVNEs3TSu0GyrrgWACdRCVivCy1mbp\nKW3TdRXdrlu7uqfMUneVzjHbUqTU9cK6ZmGtHsPaUNGOkguheEmrVWpt0w6WsR4UcYZn//DIEdNh\n5hlm+Cnv1zn+4UPPt8/gwGeey/wmxLIsSwAAGCa0qQMAAHAhFBQAwEgUFADASBQUAMBIFBQAwEgU\nFADASOEN/Qdr167VtGnT6m07efKkHnzwQU2fPj1gwQAAzVuIr++D2rZtm6ZMmaJ33nlH1157baBy\nAQCaOZ8KqqqqShkZGZo2bZrS0tICmQsA0Mz5dA1q0aJFiomJoZwAAAHX4DWos6qqqrRixQotWrQo\nkHkAAJDkwxFUQUGBOnfurF69egUyDwAAknwoqE2bNikjIyOQWQAAqON1QZWVlSkuLi6QWQAAqONV\nQbndbn333Xfq0KFDoPMAACDJxvugAAAIBq/v4guk2bNnq7q6uqljAMaLjIzUlClTmjoGEBRGFFR1\ndbVefPHFpo4BGI+fEzQnLBYLADASBQUAMBIFBQAwkhHXoC4F3MgBE+zfv5/rUGhywbpZh4LyEjdy\nAMAZwfpdyCk+AICROIICAEjy/lKGt6ea/T0VSEEBACQ1/qUMf2dxig8AYCQKCgBgpMvuFF+gbgcP\nxO29rKsGABd32RXUpXQ7+KWSEwCaAqf4AABGuuyOoBpbIFeQCOSqAJw+BHCpo6AacCmdMjzXpZgZ\nAM7FKT4AgJEoKACAkSgoAICRKCgAgJG4SSLIgvW5UsH63CDuFgQQKBRUkF2qdwVezOX0WACYhVN8\nAAAjNVhQR44c0YQJE5SQkKDU1FTl5uYGIxcAoJnzeIrPsixNmjRJvXv3VnZ2tsrLy/Xwww/r5z//\nueLi4oKVER4E65rWxQTrWpcnXAcDLk8eC6qsrEwVFRV69tlnFRISoptvvllvvfWW2rVrF6x8aMDl\ndk3Ljub++IHLlcdTfHv37tUtt9yil19+WU6nU+np6SorK9NVV10VrHwAgGbK4xFUZWWltm/fruTk\nZBUWFmr37t0aN26cunTposTExGBlBAA0Qx6PoCIiItS2bVuNHz9e4eHhio+P1913360NGzYEKx8A\noJnyWFDdunWT2+1WbW1t3Ta32x3wUAAAeCyovn37KjIyUvPmzZPb7VZpaakKCgqUkZERrHwAgGbK\n4zWoFi1aKDc3V9OnT1efPn0UFRWlqVOnqlevXsHKBwBophpc6uiGG27QokWLgpEFAIA6LHUEADAS\nBQUAMBIFBQAwEgUFADASnweFOk298KxdJixYaweL3AKeUVCo01QLz/pbjDExMY2YJnh27tx5SRar\nPyhl+IKCQpO7HFZkt1Oyl2qx+qM5lrIdFPkZFBTQCC6HkoU5eC6dwU0SAAAjUVAAACNRUAAAI1FQ\nAAAjUVAAACNRUAAAI1FQAAAjUVAAACNRUAAAI1FQAAAjUVAAACNRUAAAI7FYLAAEiber3nv7GWeX\n+6rnFBSarcb8gMbG/tDEy/0XT3PV2KveX+6rnlNQaLZM/ogMU3MBwcQ1KACAkSgoAICRGiyoxYsX\nKzY2VvHx8XV/Pv3002BkAwA0Yw1eg9q3b5+eeeYZjRkzJhh5AACQ5MUR1L59+xQTExOMLAAA1PFY\nUCdPnlR5ebmWLVsmp9OpQYMGafXq1cHKBgBoxjye4vv++++VkJCghx56SH369NHOnTs1ceJEdejQ\nQf369QtWRgBAM+SxoLp06aLc3Ny6vycmJmrYsGEqKCigoIBzNOabfqXGfeMvb/rFpcpjQe3Zs0db\nt27VhAkT6rZVV1erVatWAQ8GXEoa+02/jVl4O3fupO
xwSfJYUFFRUcrOzlZ0dLTuuusubd++Xfn5\n+VqxYkWw8gHNkqmFR9khmDwWVHR0tF577TXNmTNHU6ZMUadOnZSVlaXu3bsHKx+ARmDisk6m5YF5\nGnwfVGpqqlJTU4ORBZcZVm4G4A8Wi0XAsHIzAH+wFh8AwEgUFADASBQUAMBIFBQAwEgUFADASBQU\nAMBI3GYOwGeNsTJFY6w3yHvjLm8UFACfNeZ73PwpO3+WXqLczEdBAWhSTbXuoLflRpE1HQoKwGWF\nFUwuH9wkAQAwEgUFADASBQUAMBIFBQAwEgUFADASBQUAMBIFBQAwEgUFADASBQUAMBIFBQAwEgUF\nADASBQUAMBIFBQAwktcFdfToUfXu3VuFhYUBjAMAwBleF9Rzzz2nyspKhYSEBDIPAACSvCyolStX\nqlWrVurYsWOg8wAAIMmLgiovL9fSpUv50C4AQFB5LCiXy6XJkydr6tSpatu2bbAyAQDguaCys7MV\nExMjp9NZt82yrICHAgDAY0GtX79e+fn5cjgccjgcOnz4sJ5++mktXLgwWPkAAM1UuKcvrl+/vt7f\nBw4cqGnTpik1NTWgoQAA4I26AAAjeTyCOt/GjRsDlQMAgHo4ggIAGImCAgAYiYICABiJggIAGImC\nAgAYiYICABiJggIAGImCAgAYiYICABiJggIAGImCAgAYiYICABiJggIAGImCAgAYiYICABiJggIA\nGImCAgAYiYICABiJggIAGImCAgAYiYICABiJggIAGImCAgAYiYICABipwYLKz89XRkaG4uPjNWTI\nEBUUFAQjFwCgmQv39MXy8nI999xzWrJkieLi4lRUVKTx48fr448/1lVXXRWsjACAZshjQXXt2lXb\ntm1Ty5Yt5XK5VFFRoaioKF1xxRXBygcAaKY8FpQktWzZUocOHVJ6erosy9JLL72k1q1bByMbAKAZ\na7CgJOm6667T7t27VVxcrIkTJ+qGG25QcnJyoLMBAJoxr+7iCwsLU1hYmJKTk5Wens6NEgCAgPNY\nUJs3b9aYMWPqbaupqVHbtm0DGgoAAI8F1bNnT+3Zs0dr1qxRbW2tNm/erC1btmjIkCHBygcAaKY8\nFlT79u21YMECLV++XA6HQ6+//rqys7PVtWvXYOUDADRTDd4kkZiYqNWrVwcjCwAAdVjqCABgJAoK\nAGAkCgoAYCQKCgBgJAoKAGAkCgoAYCQKCgBgJAoKAGAkCgoAYCQKCgBgJAoKAGAkCgoAYCQKCgBg\nJAoKAGAkCgoAYCQKCgBgJAoKAGAkCgoAYCQKCgBgJAoKAGAkCgoAYCQKCgBgJAoKAGAkCgoAYKQG\nC6qkpEQPPPCAEhMTddddd2nVqlXByAUAaObCPX2xsrJSkyZN0rRp0zR48GB99tlnGjNmjG644Qb1\n7t07WBkBAM2QxyOow4cPa8CAARo8eLAkqUePHkpKSlJpaWlQwgEAmi+PBRUTE6OsrKy6v1dWVqqk\npETdu3cPeDAAQPPm9U0Sx48fV2ZmpmJjYzVw4MBAZgIAwLuCOnTokEaOHKl27dpp3rx5gc4EAEDD\nBbV3716NGDFC/fr1U3Z2tiIiIoKRCwDQzHm8i+/o0aMaN26cfv3rX2vcuHHBygQAgOcjqLy8PB07\ndkzz589XfHx83Z9XX301WPkAAM2UxyOozMxMZWZmBisLAAB1WOoIAGAkCgoAYCQKCgBgJAoKAGAk\nCgoAYCQKCgBgJAoKAGAkCgoAYCQKCgBgJAoKAGAkCgoAYCQKCgBgJAoKAGAkCgoAYCQKCgBgJAoK\nAGAkCgoAYCQKCgBgJAoKAGAkCgoAYCQKCgBgJAoKAGAkCgoAYCQKCgBgJJ8LateuXUpJSQlEFgAA\n6nhdUJZlKS8vT2PHjpXL5QpkJgAAvC+onJwc5ebmauLEibIsK5CZAADwvqCGDx+uNWvWKDY2NpB5\nAACQ5ENBdejQIZ
A5AACoh7v4AABGoqAAAEaioAAARrJVUCEhIY2dAwCAenwuqKSkJBUVFQUiCwAA\ndTjFBwAwEgUFADASBQUAMBIFBQAwEgUFADASBQUAMBIFBQAwEgUFADASBQUAMBIFBQAwEgUFADAS\nBQUAMBIFBQAwEgUFADASBQUAMBIFBQAwEgUFADASBQUAMBIFBQAwEgUFADASBQUAMBIFBQAwEgUF\nADASBQUAMBIFBQAwklcF9dlnn2n48OGKj4/Xvffeq7KyskDnAgA0cw0W1KlTp5SZmanhw4erpKRE\no0eP1sSJE3XixIlg5AMANFMNFtQnn3yisLAwjRw5UmFhYbr//vt1zTXXaPPmzcHIBwBophosqPLy\nct100031tnXt2lUHDx4MWCgAABosqBMnTqhly5b1trVs2VLV1dUBCwUAQIhlWZan/2Dp0qXaunWr\nFi5cWLftiSeeUI8ePZSZmRnwgACA5qnBI6hu3bqpvLy83rby8nLdfPPNAQsFAECDBZWcnKyamhr9\n7W9/0+nTp5WXl6cffvhBTqczGPkAAM1Ug6f4JOnzzz/XtGnT9MUXXyg6OlovvviievXqFYx8AIBm\nyquCAgAg2FjqCABgJAoKAGAkCgoAYCQKCgBgJAoKAGCkJiuokpISPfDAA0pMTNRdd92lVatW+TXv\nyJEjmjBhghISEpSamqrc3Fy/5uXn5ysjI0Px8fEaMmSICgoK/Jq3ceNGDRkyRLfffrvuuecerVu3\nzucZu3btUkpKSt3fjxw5okmTJikpKUlOp1MzZ85UTU2N7Xm7d+9W9+7dFR8fX/fnzTff9Cnj2rVr\n6+0fHx+vmJgYvfDCCz7NuVC+mpoazZgxQ8nJyUpKStLzzz+v06dP+zz3YvP92Xf//v16+OGH655/\n2dnZfs07cOCAHnnkETkcDjmdTr3yyiuyc8NtY3xUzvnZKisr9dvf/laJiYkaMGCA8vLy/Jp3Vm1t\nrUaPHq2srCyfM57r6NGj6t27twoLC23tf26+b7/99ifP5549eyo9Pd32zLOOHTumtLQ0/fvf/7Y9\n59ChQxo3bpwcDofS09P1j3/8w6dcZy1evFixsbH1Huenn35qa5YkFRUV6d5779Xtt9+ukSNHateu\nXfYGWU3gxx9/tBwOh7Vu3TrLsixr79691h133GFt27bN1rza2lrrvvvus15++WXL5XJZX375pXXH\nHXdYO3bssDXv4MGDVlxcXN3+27Zts2JjY61jx47ZmnfixAkrNjbW+vDDDy3Lsqzi4mKrZ8+e1n//\n+1+v9q+trbXeeecdKyEhwUpOTq7bPmrUKGvGjBnWqVOnrIqKCuvBBx+05s6da3veqlWrrAkTJvj4\n6DzbunWrlZKSYh05csTrfS6Wb9asWdavfvUrq7Ky0vrxxx+tESNGWDk5OT5nuth8u/u63W5rwIAB\n1vLlyy3Lsqxvv/3Wcjqd1oYNG2xneeihh6xZs2ZZbrfbOnLkiJWWlma99957PmWtrq62UlJSrJUr\nV1oul8vKy8uzevfubVVVVdl+rJZlWY8//rj1hz/8wTp16pRVVlZm3XHHHdbOnTttzztr4cKFVvfu\n3a2srCzvH+QFjB8/3urevbtVWFjo037ePC8qKiosp9Npffzxx37NLC4utu655x4rJibG+vLLL23N\ncblc1pAhQ6znnnvOOnXqlPXFF19Yffv29flxW5ZlPfPMM9Zf//pXn/e7kEOHDllxcXHWO++8Y7nd\nbuuDDz6wHA6HVVFR4fOsJjmCOnz4sAYMGKDBgwdLknr06KGkpCSVlpbamldWVqaKigo9++yzCgsL\n080336y33npL0dHRtuZ17dpV27ZtU1xcnFwulyoqKhQVFaUrrrjC1ryQkBC1bt1aLpdLlmUpJCRE\nV1xxhcLCwrzaPycnR7m5uZo4cWLdq+iamhq1bt1aEydOVEREhNq3b6+hQ4dqx44d
tuZJZ15tx8TE\n2HqMF1JVVaUpU6Zo2rRpuvbaa73e70L5Tp8+rbfffltTp05VmzZt1LZtW7322msaOnSoz7ku9vjt\n7hsaGqr8/HyNHj1almXphx9+UG1tra666irbWaKiouRyueR2u2VZlkJDQ3+yaHND/P2onAtlq6qq\n0oYNG/T4448rIiJCvXr10tChQ7165e7p+75//3699957uvPOO20dKZ61cuVKtWrVSh07dvR5X2+e\nFy+88IIGDRrk9Uo6F5pZUlKip556SpmZmV4/1gvN+eqrr3TgwAE9//zzioiI0C233KIHH3xQq1ev\n9mrmufbt29doP/tbtmzRbbfdpuHDhys0NFTp6em69dZb9cEHH/g8q0kKKiYmpt5hfGVlpUpKStS9\ne3db8/bu3atbbrlFL7/8spxOp9LT01VWVubVL4iLadmypQ4dOqRevXpp8uTJevrpp9W6dWtbsyIj\nI5WVlaU//vGPio2N1ahRo/TCCy94/Ut7+PDhWrNmjWJjY+u2RUREKCcnR9dcc03dto0bN3r1PbzQ\nPOnMk7S0tFRpaWkaMGCAsrKyfDpleL5FixYpJiZGaWlpPu13oXxff/213G63ysrKlJ6ern79+mnp\n0qX62c9+5nOuiz1+f/aNjIyUJN155526//771bdvX8XHx9ueN3XqVG3YsEFxcXHq37+/EhISfD6t\n5O9H5Vzs3yE8PFxdunSp2xYdHe3VzIs91pqaGk2ZMkUzZ860/TMmnXm8S5cu1Ysvvmhr/4aeF0VF\nRdqxY4eeeuopv2beeuut2rhxo4YNG+bXHLfbrbCwsHovnENCQvTVV195PVeSTp48qfLyci1btkxO\np1ODBg2yVXJnWZalFi1a1NtmJ5dkwE0Sx48fV2ZmpmJjYzVw4EBbMyorK7V9+3a1a9dOhYWFmj17\ntmbMmKGSkhK/sl133XXavXu3lixZolmzZumTTz6xNeebb77R7373O82cOVNlZWXKycnRn/70J+3f\nv9+r/Tt06ODx65ZlaebMmfrqq680fvx42/OuvvpqDRw4UO+//76WL1+u7du36/XXX/cq4/mqqqq0\nYsUKPfbYYz7ve6F8P/74o06fPq3CwkKtXr1ab7/99k9W2fdnfmPtu379en300Ufas2eP5s+fb2te\nbW2tJk2apLS0NJWWlur9999XSUmJz9dp/f2onAtlO3HiRF0ZnxUZGenVzIt97+bMmaOUlJS6Qg8J\nCfEq37lcLpcmT56sqVOnqm3btj7v7ynfWW+++abGjh3r05HshWa2adNGERERfmfr1q2bOnfurDlz\n5qimpkZffvml3n33XZ9fVH7//fdKSEjQQw89pMLCQk2fPl2zZ8/Wli1bfJpzltPp1K5du/Thhx/K\n5XKpoKBAO3futPVit0kL6tChQxo5cqTatWunefPm2Z4TERGhtm3bavz48QoPD1d8fLzuvvtubdiw\nwa98YWFhCgsLU3JystLT023fKFFQUKAePXpo6NChCg8PV2pqqvr37681a9b4lU+Sqqur9eSTT2rr\n1q3Kzc3V1VdfbXvWggUL9OijjyoyMlLXX3+9MjMz9dFHH9maVVBQoM6dOzfamo0RERGqra3Vk08+\nqaioKHXs2FFjxozx++aVxhYREaHrr79e48aN0z//+U9bMz7//HMdPHhQkydPVosWLXTTTTdp/Pjx\nPhdUq1atflIcJ0+e9OsopWXLljp16lS9bdXV1WrVqpWteUVFRdq+fbueeOIJSWdebNk5xZedna2Y\nmJh6p978OVV4vsOHD6u4uFgPPPBAo830V3h4uLKzs7Vv3z6lpKRo5syZGjZsmK688kqf5nTp0kW5\nubnq16+fwsPDlZiYqGHDhtn+2brxxhs1d+5czZ8/X06nUwUFBUpLS1ObNm18ntVkBbV3716NGDFC\n/fr1U3Z2ts+vKM7VrVs3ud1u1dbW1m1zu922
523evFljxoypt62mpsb2K7PIyMif/FCHhYUpPDzc\ndkbpzFHFqFGj9L///U+rVq1S586dbc+qrKzUrFmzVFVVVbeturr6J6+WvbVp0yZlZGTYznO+6Oho\nhYaG1nsVdvaaXlP74YcflJaWpsrKyrpt/jxfIiIiZFlWvTsUQ0NDfb4GGoiPyrnxxht1+vRpHT58\nuFFmrl+/Xv/5z3/Up08fORwOrVu3TitWrPD5s+bWr1+v/Px8ORwOORwOHT58WE8//bStI+wL2bRp\nk5KSkvy6bNDYLMtSVVWVFi9erO3bt2vZsmU6fvy4evTo4dOcPXv26I033qi3zZ+f/aqqKnXq1Elr\n167VJ598otmzZ+vAgQM+55KaqKCOHj2qcePGaezYsZo8ebLf8/r27avIyEjNmzdPbrdbpaWlKigo\nsP0LsmfPntqzZ4/WrFmj2tpabd68WVu2bNGQIUNszevfv78OHjyod999V5Zl6V//+pcKCgp0zz33\n2JonnXlyPv744+rQoYMWLVpk69XJua688kpt2rRJ8+bNk8vl0tdff6033nhDv/zlL23NKysrU1xc\nnF+ZztWmTRvdeeedeuWVV3T8+HF99913WrZsmQYNGtRo/w+7rr76arVv315z587V6dOndeDAAS1e\nvFj333+/rXldu3bVbbfdptmzZ6umpkbffPONlixZ4vNjDcRH5URFRSktLU1z5sxRdXW1du3apXXr\n1tm6WUWSpk+frtLSUhUXF6u4uFhDhw7VqFGjlJOT49Oc9evXq6SkpG5Op06d9Oqrr+o3v/mNrVzn\nKysr8+qaYjCFhITomWee0dtvv63a2loVFRVp7dq1GjFihE9zoqKilJ2drQ8//LBuTn5+vu677z5b\nuY4dO6aRI0dq3759qqmp0dKlS1VZWWnrEk6TFFReXp6OHTum+fPn17vv/tVXX7U1r0WLFsrNzdWu\nXbvUp08f/f73v9fUqVNtn15q3769FixYoOXLl8vhcOj1119Xdna2unbtamtex44dlZOTo5UrV8rh\ncGjGjBnKyspSz549fZ519vz8jh07VFxcrKKiIjkcjrrv4ejRo23NCw0N1RtvvKHPP/9cycnJevjh\nh5WRkaFHHnnE54xut1vfffedX9d6zs8nSbNmzVKnTp00aNAg/eIXv5DT6dTYsWMbbb4/+/7lL3/R\nkSNH1LdvX2VmZurRRx/Vvffea2teaGio5s+fr6NHjyolJUWPPPKIBg8e7PO/RUREhBYuXKh169Yp\nKSlJf//737VgwQJbr4zPfawzZsyQy+VSamqqnnzySU2ePNnnnzV/vu/BcH6+b7/91u/n84Ues53v\nw7n7zJ07V3l5eUpISNCf//xnzZo1y+cjlejoaL322muaP3++EhIS6n4/2b1prUuXLnrppZf02GOP\nqXfv3tq4caOWLFli73lnmXCOBACA8zT5XXwAAFwIBQUAMBIFBQAwEgUFADASBQUAMBIFBQAwEgUF\nADASBQUAMNL/AREPSWuETQnOAAAAAElFTkSuQmCC\n",
"text": [
"<matplotlib.figure.Figure at 0x107e2b490>"
]
}
],
"prompt_number": 7
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from collections import Iterable\n",
"\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.gridspec as gridspec\n",
"import numpy as np\n",
"import warnings\n",
"\n",
"from . import utils\n",
"\n",
"\n",
"class _MatrixPlotter(object):\n",
" \"\"\"Plotter for 2D matrix data\n",
"\n",
" This will be used by the `clusteredheatmap`\n",
" \"\"\"\n",
"\n",
" def establish_variables(self, data, z_score=None,\n",
" standard_scale=None, **kws):\n",
" \"\"\"Extract variables from data or use directly.\"\"\"\n",
" self.data = data\n",
"\n",
" # Either the data is already in 2d matrix format, or need to do a pivot\n",
" if 'pivot_kws' in kws and kws['pivot_kws'] is not None:\n",
" self.data2d = self.data.pivot(**kws['pivot_kws'])\n",
" else:\n",
" self.data2d = self.data\n",
"\n",
" # if z_score is not None:\n",
" # if z_score\n",
"\n",
"\n",
" def z_score(self, data2d, axis=1):\n",
" \"\"\"Standardize the mean and variance of the data axis\n",
"\n",
" Parameters\n",
" ----------\n",
" data2d : pandas.DataFrame\n",
" Data to normalize\n",
" axis : int\n",
" Which axis to normalize across. If 0, normalize across rows, if 1,\n",
" normalize across columns. Default 1 (across columns)\n",
"\n",
" Returns\n",
" -------\n",
" normalized : pandas.DataFrame\n",
" Normalized data with a mean of 0 and variance of 1 across the\n",
" specified axis.\n",
" \"\"\"\n",
" if axis == 1:\n",
" z_scored = data2d\n",
" else:\n",
" z_scored = data2d.T\n",
"\n",
" z_scored = (z_scored - z_scored.mean()) / z_scored.var()\n",
"\n",
" if axis == 1:\n",
" return z_scored\n",
" else:\n",
" return z_scored.T\n",
"\n",
" def standard_scale(self, data2d, axis=1, vmin=0):\n",
" \"\"\"Divide the data by the difference between the max and min\n",
"\n",
" Parameters\n",
" ----------\n",
" data2d : pandas.DataFrame\n",
" Data to normalize\n",
" axis : int\n",
" Which axis to normalize across. If 0, normalize across rows, if 1,\n",
" normalize across columns. Default 1 (across columns)\n",
" vmin : int\n",
" If 0, then subtract the minimum of the data before dividing by\n",
" the range.\n",
"\n",
" Returns\n",
" -------\n",
" standardized : pandas.DataFrame\n",
" Data shifted by its minimum (when vmin is 0) and divided by its\n",
" range (max - min) along the specified axis.\n",
"\n",
" >>> import numpy as np\n",
" >>> d = np.arange(5, 8, 0.5)\n",
" >>> standard_scale(d)\n",
" [ 0. 0.2 0.4 0.6 0.8 1. ]\n",
" >>> standard_scale(d, vmin=None)\n",
" [ 2. 2.2 2.4 2.6 2.8 3. ]\n",
" \"\"\"\n",
" # Normalize these values to range from 0 to 1\n",
" if axis == 1:\n",
" standardized = data2d\n",
" else:\n",
" standardized = data2d.T\n",
"\n",
" if vmin == 0:\n",
" subtract = standardized.min()\n",
"\n",
" standardized = (standardized - subtract) / (\n",
" standardized.max() - standardized.min())\n",
"\n",
" if axis == 1:\n",
" return standardized\n",
" else:\n",
" return standardized.T\n",
"\n",
" def plot(self, *args, **kwargs):\n",
" raise NotImplementedError\n",
"\n",
"\n",
"class _ClusteredHeatmapPlotter(_MatrixPlotter):\n",
" \"\"\"Plotter of 2d matrix data, hierarchically clustered with dendrograms\n",
"\n",
" \"\"\"\n",
"\n",
" def __init__(self, data, pivot_kws=None,\n",
" color_scale='linear', linkage_method='median',\n",
" metric='euclidean', pcolormesh_kws=None,\n",
" dendrogram_kws=None,\n",
" row_kws=None, col_kws=None,\n",
" colorbar_kws=None,\n",
" use_fastcluster=False,\n",
" data_na_ok=None, labelsize=18, z_score=None,\n",
" standard_scale=None):\n",
" self.data = data\n",
" self.pivot_kws = pivot_kws\n",
" self.color_scale = color_scale\n",
" self.linkage_method = linkage_method\n",
" self.metric = metric\n",
" self.use_fastcluster = use_fastcluster\n",
"\n",
" self.labelsize = labelsize\n",
"\n",
" self.establish_variables(data, pivot_kws=pivot_kws, z_score=z_score,\n",
" standard_scale=standard_scale)\n",
" self.validate_data_na_ok(data_na_ok)\n",
" self.interpret_kws(row_kws, col_kws, pcolormesh_kws,\n",
" dendrogram_kws, colorbar_kws)\n",
" self.get_linkage()\n",
" self.row_dendrogram = self.calculate_dendrogram(self.row_kws,\n",
" self.row_linkage)\n",
" self.col_dendrogram = self.calculate_dendrogram(self.col_kws,\n",
" self.col_linkage)\n",
"\n",
"\n",
"\n",
"\n",
" def establish_axes(self, fig=None, figsize=None):\n",
" if fig is None:\n",
" if figsize is None:\n",
" # width = min(self.data2d.shape[1] * 0.5, 40)\n",
" # height = min(self.data2d.shape[0] * 0.5, 40)\n",
" width, height = 10, 10\n",
" figsize = (width, height)\n",
" fig = plt.figure(figsize=figsize)\n",
"\n",
" self.fig = fig\n",
" width_ratios = self.get_fig_width_ratios(self.row_kws['side_colors'],\n",
" figsize=figsize,\n",
" # colorbar_kws['loc'],\n",
" dimension='width')\n",
"\n",
" height_ratios = self.get_fig_width_ratios(self.col_kws['side_colors'],\n",
" figsize=figsize,\n",
" dimension='height')\n",
" # nrows = 3 if self.col_kws['side_colors'] is None else 4\n",
" # ncols = 2 if self.row_kws['side_colors'] is None else 3\n",
" nrows = 3 if self.col_kws['side_colors'] is None else 4\n",
" ncols = 3 if self.row_kws['side_colors'] is None else 4\n",
"\n",
" self.gs = gridspec.GridSpec(nrows, ncols, wspace=0.01, hspace=0.01,\n",
" width_ratios=width_ratios,\n",
" height_ratios=height_ratios)\n",
"\n",
" # self.row_dendrogram_ax = self.fig.add_subplot(self.gs[nrows-1, 0])\n",
" # self.col_dendrogram_ax = self.fig.add_subplot(self.gs[0:2, ncols-1])\n",
" self.row_dendrogram_ax = self.fig.add_subplot(self.gs[nrows-1, 0:2])\n",
" self.col_dendrogram_ax = self.fig.add_subplot(self.gs[0:2, ncols-1])\n",
"\n",
" self.row_side_colors_ax = None\n",
" self.col_side_colors_ax = None\n",
"\n",
" if self.col_kws['side_colors'] is not None:\n",
" self.col_side_colors_ax = self.fig.add_subplot(\n",
" self.gs[nrows - 2, ncols - 1])\n",
" if self.row_kws['side_colors'] is not None:\n",
" self.row_side_colors_ax = self.fig.add_subplot(\n",
" self.gs[nrows - 1, ncols - 2])\n",
"\n",
" self.heatmap_ax = self.fig.add_subplot(self.gs[nrows - 1, ncols - 1])\n",
" # self.heatmap_ax = self.fig.add_subplot(self.gs[nrows-2, ncols-2])\n",
"\n",
" # colorbar for scale goes in the top-left cell of the grid\n",
" self.colorbar_ax = self.fig.add_subplot(self.gs[0, 0])\n",
" # self.colorbar_ax = self.fig.add_subplot(self.gs[nrows-1, ncols-1])\n",
"\n",
" def interpret_kws(self, row_kws, col_kws, pcolormesh_kws,\n",
" dendrogram_kws, colorbar_kws):\n",
" \"\"\"Set defaults for keyword arguments\n",
" \"\"\"\n",
" # Interpret keyword arguments\n",
" self.row_kws = {} if row_kws is None else row_kws\n",
" self.col_kws = {} if col_kws is None else col_kws\n",
"\n",
" if 'side_colors' in self.row_kws and self.row_kws['side_colors'] \\\n",
" is not None:\n",
" assert len(self.row_kws['side_colors']) == self.data2d.shape[0]\n",
" if 'side_colors' in self.col_kws and self.col_kws['side_colors'] \\\n",
" is not None:\n",
" assert len(self.col_kws['side_colors']) == self.data2d.shape[1]\n",
"\n",
" for kws in (self.row_kws, self.col_kws):\n",
" kws.setdefault('linkage_matrix', None)\n",
" kws.setdefault('cluster', True)\n",
" kws.setdefault('label', True)\n",
" kws.setdefault('fontsize', 18)\n",
" kws.setdefault('side_colors', None)\n",
"\n",
" self.dendrogram_kws = {} if dendrogram_kws is None else dendrogram_kws\n",
" self.dendrogram_kws.setdefault('color_threshold', np.inf)\n",
" self.dendrogram_kws.setdefault('color_list', ['k'])\n",
" # even if the user specified no_plot as False, override because we\n",
" # have to control the plotting\n",
" if 'no_plot' in self.dendrogram_kws and \\\n",
" not self.dendrogram_kws['no_plot']:\n",
" warnings.warn('Cannot specify \"no_plot\" as False in '\n",
" 'dendrogram_kws')\n",
" self.dendrogram_kws['no_plot'] = True\n",
"\n",
" # Pcolormesh keyword arguments take more work\n",
" self.pcolormesh_kws = {} if pcolormesh_kws is None else pcolormesh_kws\n",
" self.pcolormesh_kws.setdefault('edgecolor', 'white')\n",
" self.pcolormesh_kws.setdefault('linewidth', 0.1)\n",
"\n",
" self.vmin = None if 'vmin' not in self.pcolormesh_kws else \\\n",
" self.pcolormesh_kws['vmin']\n",
" self.vmax = None if 'vmax' not in self.pcolormesh_kws else \\\n",
" self.pcolormesh_kws['vmax']\n",
" self.norm = None if 'norm' not in self.pcolormesh_kws else \\\n",
" self.pcolormesh_kws['norm']\n",
" self.cmap = None if 'cmap' not in self.pcolormesh_kws else \\\n",
" self.pcolormesh_kws['cmap']\n",
" self.edgecolor = self.pcolormesh_kws['edgecolor']\n",
" self.linewidth = self.pcolormesh_kws['linewidth']\n",
"\n",
" # Check if the matrix has values both above and below zero, or only\n",
" # above or only below zero. If both above and below, then the data is\n",
" # \"divergent\" and we will use a colormap with 0 centered at white,\n",
" # negative values blue, and positive values red. Otherwise, we will use\n",
" # the YlGnBu colormap.\n",
" vmax = self.data2d.max().max() if 'vmax' not in self.pcolormesh_kws \\\n",
" else self.pcolormesh_kws['vmax']\n",
" vmin = self.data2d.min().min() if 'vmin' not in self.pcolormesh_kws \\\n",
" else self.pcolormesh_kws['vmin']\n",
" log = self.color_scale == 'log'\n",
" self.divergent = (vmax > 0 and vmin < 0) and not log\n",
"\n",
" if self.color_scale == 'log':\n",
" if self.vmin is None:\n",
" self.vmin = self.data2d.replace(0, np.nan).dropna(\n",
" how='all').min().dropna().min()\n",
" if self.vmax is None:\n",
" self.vmax = self.data2d.dropna(how='all').max().dropna().max()\n",
" if self.norm is None:\n",
" self.norm = mpl.colors.LogNorm(self.vmin, self.vmax)\n",
" elif self.divergent:\n",
" abs_max = abs(self.data2d.max().max())\n",
" abs_min = abs(self.data2d.min().min())\n",
" vmaxx = max(abs_max, abs_min)\n",
" self.vmin = -vmaxx\n",
" self.vmax = vmaxx\n",
" self.norm = mpl.colors.Normalize(vmin=self.vmin, vmax=self.vmax)\n",
" else:\n",
" self.vmin = vmin\n",
" self.vmax = vmax\n",
" self.pcolormesh_kws.setdefault('vmin', self.vmin)\n",
" self.pcolormesh_kws.setdefault('vmax', self.vmax)\n",
"\n",
" self.mean = np.mean(self.data2d.values.flat)\n",
" self.colorbar_kws = {} if colorbar_kws is None else colorbar_kws\n",
" self.colorbar_kws.setdefault('fontsize', 14)\n",
" self.colorbar_kws.setdefault('label', '')\n",
" self.colorbar_kws.setdefault('orientation', 'horizontal')\n",
"\n",
" if self.cmap is None:\n",
" self.cmap = mpl.cm.RdBu_r if self.divergent else mpl.cm.YlGnBu\n",
" self.cmap.set_bad('white')\n",
" # Make sure there's no trailing `cmap` or `vmin` or `vmax` values\n",
" if 'cmap' in self.pcolormesh_kws:\n",
" self.pcolormesh_kws.pop('cmap')\n",
" if 'vmin' in self.pcolormesh_kws:\n",
" self.pcolormesh_kws.pop('vmin')\n",
" if 'vmax' in self.pcolormesh_kws:\n",
" self.pcolormesh_kws.pop('vmax')\n",
"\n",
" def validate_data_na_ok(self, data_na_ok):\n",
" if data_na_ok is None:\n",
" self.data2d_na_ok = self.data2d\n",
" else:\n",
" self.data2d_na_ok = data_na_ok\n",
"\n",
" if (self.data2d_na_ok.index != self.data2d.index).any():\n",
" raise ValueError(\n",
" 'data_na_ok must have the exact same indices as the 2d data')\n",
" if (self.data2d_na_ok.columns != self.data2d.columns).any():\n",
" raise ValueError(\n",
" 'data_na_ok must have the exact same columns as the 2d data')\n",
"\n",
" def calculate_linkage(self, values, row=True):\n",
" linkage_function = self.get_linkage_function(values.shape)\n",
"\n",
" if not row:\n",
" values = values.T\n",
"\n",
" if self.use_fastcluster:\n",
" return linkage_function(values, method=self.linkage_method,\n",
" metric=self.metric)\n",
" else:\n",
" from scipy.spatial import distance\n",
"\n",
" pairwise_dists = distance.squareform(\n",
" distance.pdist(values, metric=self.metric))\n",
" return linkage_function(pairwise_dists, method=self.linkage_method)\n",
"\n",
" def get_linkage(self):\n",
" \"\"\"Calculate linkage matrices\n",
"\n",
" These are then passed to the dendrogram functions to plot pairwise\n",
" similarity of samples\n",
" \"\"\"\n",
" if self.color_scale == 'log':\n",
" values = np.log10(self.data2d.values)\n",
" else:\n",
" values = self.data2d.values\n",
"\n",
" if self.row_kws['linkage_matrix'] is None:\n",
"\n",
" self.row_linkage = self.calculate_linkage(values, row=True)\n",
" else:\n",
" self.row_linkage = self.row_kws['linkage_matrix']\n",
"\n",
" if self.col_kws['linkage_matrix'] is None:\n",
" self.col_linkage = self.calculate_linkage(values, row=False)\n",
" else:\n",
" self.col_linkage = self.col_kws['linkage_matrix']\n",
"\n",
" def get_fig_width_ratios(self, side_colors,\n",
" dimension, figsize,\n",
" side_colors_ratio=0.05):\n",
" \"\"\"Get the proportions of the figure taken up by each axes\n",
" \"\"\"\n",
" i = 0 if dimension == 'height' else 1\n",
" if dimension not in ('height', 'width'):\n",
" raise AssertionError(\"{} is not a valid 'dimension' (valid: \"\n",
" \"'height', 'width')\".format(dimension))\n",
" figdim = figsize[i]\n",
" # Get resizing proportion of this figure for the dendrogram and\n",
" # colorbar, so only the heatmap gets bigger but the dendrogram stays\n",
" # the same size.\n",
" dendrogram = min(2. / figdim, .2)\n",
"\n",
" # add the colorbar\n",
" colorbar_width = .8 * dendrogram\n",
" colorbar_height = .2 * dendrogram\n",
" if dimension == 'width':\n",
" ratios = [colorbar_width, colorbar_height]\n",
" else:\n",
" ratios = [colorbar_height, colorbar_width]\n",
"\n",
" if side_colors is not None:\n",
" # Add room for the colors\n",
" ratios += [side_colors_ratio]\n",
"\n",
" # Add the ratio for the heatmap itself\n",
" ratios += [.8]\n",
"\n",
" return ratios\n",
"\n",
" @staticmethod\n",
" def color_list_to_matrix_and_cmap(colors, ind, row=True):\n",
" \"\"\"Turns a list of colors into a numpy matrix and matplotlib colormap\n",
" For 'heatmap()'\n",
" This only works for 1-column color lists.\n",
"\n",
" These arguments can now be plotted using matplotlib.pcolormesh(matrix,\n",
" cmap) and the provided colors will be plotted.\n",
"\n",
" Parameters\n",
" ----------\n",
" colors : list of matplotlib colors\n",
" Colors to label the rows or columns of a dataframe.\n",
" ind : list of ints\n",
" Ordering of the rows or columns, to reorder the original colors\n",
" by the clustered dendrogram order\n",
" row : bool\n",
" Is this to label the rows or columns? Default True.\n",
"\n",
" Returns\n",
" -------\n",
" matrix : numpy.array\n",
" A numpy array of integer values, where each corresponds to a color\n",
" from the originally provided list of colors\n",
" cmap : matplotlib.colors.ListedColormap\n",
"\n",
" \"\"\"\n",
" # TODO: Support multiple color labels on an element in the heatmap\n",
" import matplotlib as mpl\n",
"\n",
" colors_original = colors\n",
" colors = set(colors)\n",
" col_to_value = dict((col, i) for i, col in enumerate(colors))\n",
" matrix = np.array([col_to_value[col] for col in colors_original])[ind]\n",
"\n",
" # Is this row-side or column side?\n",
" if row:\n",
" # shape of matrix: nrows x 1\n",
" new_shape = (len(colors_original), 1)\n",
" else:\n",
" # shape of matrix: 1 x ncols\n",
" new_shape = (1, len(colors_original))\n",
" matrix = matrix.reshape(new_shape)\n",
"\n",
" cmap = mpl.colors.ListedColormap(colors)\n",
" return matrix, cmap\n",
"\n",
" def get_linkage_function(self, shape):\n",
" \"\"\"\n",
" Parameters\n",
" ----------\n",
" shape : tuple\n",
" (nrow, ncol) tuple of the shape of the dataframe\n",
" use_fastcluster : bool\n",
" Whether to use fastcluster (3rd party module) for clustering,\n",
" which is faster than the default scipy.cluster.hierarchy.linkage\n",
" module\n",
"\n",
" Returns\n",
" -------\n",
" linkage_function : function\n",
" Linkage function to use for clustering\n",
"\n",
" .. warning:: If the product of the number of rows and cols exceeds\n",
" 10000, this will try to import fastcluster, and raise a warning if it\n",
" does not exist. Vanilla scipy.cluster.hierarchy.linkage will take a\n",
" long time on these matrices.\n",
" \"\"\"\n",
" if np.product(shape) >= 10000 or self.use_fastcluster:\n",
" try:\n",
" import fastcluster\n",
"\n",
" self.use_fastcluster = True\n",
" linkage_function = fastcluster.linkage_vector\n",
" except ImportError:\n",
" raise warnings.warn(\n",
" 'Module \"fastcluster\" not found. The dataframe provided '\n",
" 'has shape {}, and one of the dimensions has greater than '\n",
" '1000 variables. Calculating linkage on such a matrix will'\n",
" ' take a long time with vanilla '\n",
" '\"scipy.cluster.hierarchy.linkage\", and we suggest '\n",
" 'fastcluster for such large datasets'.format(shape),\n",
" RuntimeWarning)\n",
" else:\n",
" import scipy.cluster.hierarchy as sch\n",
"\n",
" linkage_function = sch.linkage\n",
" return linkage_function\n",
"\n",
" def calculate_dendrogram(self, kws, linkage):\n",
" \"\"\"Calculates a dendrogram based on the linkage matrix\n",
"\n",
" Parameters\n",
" ----------\n",
" kws : dict\n",
" Keyword arguments for column or row plotting passed to clusterplot\n",
" linkage : numpy.array\n",
" Linkage matrix, usually created by scipy.cluster.hierarchy.linkage\n",
" orientation : str\n",
" (docstring stolen from scipy.cluster.hierarchy.linkage)\n",
" The direction to plot the dendrogram, which can be any\n",
" of the following strings:\n",
"\n",
" 'top' plots the root at the top, and plot descendent\n",
" links going downwards. (default).\n",
"\n",
" 'bottom'- plots the root at the bottom, and plot descendent\n",
" links going upwards.\n",
"\n",
" 'left'- plots the root at the left, and plot descendent\n",
" links going right.\n",
"\n",
" 'right'- plots the root at the right, and plot descendent\n",
" links going left.\n",
"\n",
" Returns\n",
" -------\n",
" dendrogram : dict\n",
" Dendrogram dictionary as returned by scipy.cluster.hierarchy\n",
" .dendrogram. The important key-value pairing is \"leaves\" which\n",
" tells the ordering of the matrix\n",
" \"\"\"\n",
" import scipy.cluster.hierarchy as sch\n",
"\n",
" sch.set_link_color_palette(['k'])\n",
"\n",
" if kws['cluster']:\n",
" dendrogram = sch.dendrogram(linkage, **self.dendrogram_kws)\n",
" else:\n",
" dendrogram = {'leaves': list(range(linkage.shape[0] + 1))}\n",
" return dendrogram\n",
"\n",
" def plot_dendrogram(self, ax, dendrogram, row=True):\n",
" \"\"\"Plots a dendrogram on the figure at the gridspec location using\n",
" the linkage matrix\n",
"\n",
" Both the computation and plotting must be in this same function because\n",
" scipy.cluster.hierarchy.dendrogram does ax = plt.gca() and cannot be\n",
" specified its own ax object.\n",
"\n",
" Parameters\n",
" ----------\n",
" ax : matplotlib.axes.Axes\n",
" Axes object upon which the dendrogram is plotted\n",
" \"\"\"\n",
" if row:\n",
" X = dendrogram['dcoord']\n",
" Y = dendrogram['icoord']\n",
" else:\n",
" X = dendrogram['icoord']\n",
" Y = dendrogram['dcoord']\n",
"\n",
" for x, y in zip(X, Y):\n",
" ax.plot(x, y, color='k', linewidth=0.5)\n",
"\n",
" if row:\n",
" ax.invert_xaxis()\n",
" ymax = min(map(min, Y)) + max(map(max, Y))\n",
" ax.set_ylim(0, ymax)\n",
" else:\n",
" xmax = min(map(min, X)) + max(map(max, X))\n",
" ax.set_xlim(0, xmax)\n",
"\n",
" utils.despine(ax=ax, bottom=True, left=True)\n",
" ax.set_axis_bgcolor('white')\n",
" ax.grid(False)\n",
" ax.set_yticks([])\n",
" ax.set_xticks([])\n",
"\n",
" def plot_side_colors(self, ax, kws, dendrogram, row=True):\n",
" \"\"\"Plots color labels between the dendrogram and the heatmap\n",
" Parameters\n",
" ----------\n",
" fig : matplotlib.figure.Figure\n",
" Matplotlib figure instance to plot onto\n",
" kws : dict\n",
" Keyword arguments for column or row plotting passed to clusterplot\n",
" gridspec : matplotlib.gridspec.gridspec\n",
" Indexed gridspec object for where to put the dendrogram plot\n",
" dendrogram : dict\n",
" Dendrogram with key-value 'leaves' as a list of indices in the\n",
" clustered order\n",
" edgecolor : matplotlib color\n",
" Color of the lines outlining each box of the heatmap\n",
" linewidth : float\n",
" Width of the lines outlining each box of the heatmap\n",
"\n",
" Returns\n",
" -------\n",
" ax : matplotlib.axes.Axes\n",
" Axes object, if plotted\n",
" \"\"\"\n",
" # TODO: Allow for array of color labels\n",
" # TODO: allow for groupby and then auto-selecting of colors\n",
" if ax is not None and kws['side_colors'] is not None:\n",
" side_matrix, cmap = self.color_list_to_matrix_and_cmap(\n",
" kws['side_colors'],\n",
" ind=dendrogram['leaves'],\n",
" row=row)\n",
" ax.pcolormesh(side_matrix, cmap=cmap, edgecolor=self.edgecolor,\n",
" linewidth=self.linewidth)\n",
" ax.set_xlim(0, side_matrix.shape[1])\n",
" ax.set_ylim(0, side_matrix.shape[0])\n",
" ax.set_yticks([])\n",
" ax.set_xticks([])\n",
" utils.despine(ax=ax, left=True, bottom=True)\n",
"\n",
" def label_dimension(self, dimension, kws, heatmap_ax, dendrogram_ax,\n",
" dendrogram):\n",
" \"\"\"Label either the rows or columns of a heatmap\n",
" Parameters\n",
" ----------\n",
" dimension : str\n",
" either \"row\" or \"col\", which dimension we are labeling\n",
" kws : dict\n",
" Keyword arguments for the dimension, either row_kws or col_kws from\n",
" clusterplot()\n",
" heatmap_ax : matplotlib.axes.Axes\n",
" Axes object where the heatmap is plotted\n",
" dendrogram_ax : matplotlib.axes.Axes\n",
" Axes object where this dimensions's dendrogram is plotted\n",
" dendrogram : dict\n",
" Dendrogram dictionary with key 'leaves' containing the reordered\n",
" columns or rows after clustering\n",
" data2d : pandas.DataFrame\n",
" Dataframe that we're plotting. Need access to the rownames\n",
" (data2d.index) and columns for the default labeling.\n",
" \"\"\"\n",
" if dimension not in ['row', 'col']:\n",
" raise ValueError('Argument \"dimension\" must be one of \"row\" or '\n",
" '\"col\", not \"{}\"'.format(dimension))\n",
" axis = 0 if dimension == 'row' else 1\n",
"\n",
" # import pdb ;pdb.set_trace()\n",
" # Remove all ticks from the other axes\n",
" dendrogram_ax_axis = dendrogram_ax.yaxis \\\n",
" if dimension == 'row' else dendrogram_ax.xaxis\n",
" dendrogram_ax_axis.set_ticks([])\n",
"\n",
" ax_axis = heatmap_ax.yaxis if dimension == 'row' else heatmap_ax.xaxis\n",
"\n",
" if isinstance(kws['label'], Iterable):\n",
" if len(kws['label']) == self.data2d.shape[axis]:\n",
" ticklabels = kws['label']\n",
" kws['label'] = True\n",
" else:\n",
" raise AssertionError(\n",
" \"Length of '{0}_kws['label']' must be the same as \"\n",
" \"data2d.shape[{1}] (len({0}_kws['label'])={2}, \"\n",
" \"data2d.shape[{1}]={3})\".format(dimension, axis,\n",
" len(kws['label']),\n",
" self.data2d.shape[axis]))\n",
" elif kws['label']:\n",
" ticklabels = self.data2d.index if dimension == 'row' else self \\\n",
" .data2d.columns\n",
" else:\n",
" ax_axis.set_ticklabels([])\n",
"\n",
" if kws['label']:\n",
"\n",
" # Need to set the position first, then the labels because of some\n",
" # odd matplotlib bug\n",
" if dimension == 'row':\n",
" # pass\n",
" ax_axis.set_ticks_position('right')\n",
"\n",
" ticklabels_ordered = [ticklabels[i] for i in\n",
" dendrogram['leaves']]\n",
" ticks = (np.arange(self.data2d.shape[axis]) + 0.5)\n",
" ax_axis.set_ticks(ticks)\n",
" ax_axis.set_ticklabels(ticklabels_ordered)\n",
" heatmap_ax.tick_params(labelsize=self.labelsize)\n",
" if dimension == 'col':\n",
" for label in ax_axis.get_ticklabels():\n",
" label.set_rotation(90)\n",
"\n",
" def plot_heatmap(self):\n",
" \"\"\"Plot the heatmap of the data.\n",
"\n",
" Specifically plots data_na_ok so that user can specify different\n",
" dataframes for the linkage calculation and the plotting.\n",
" \"\"\"\n",
" ax = self.heatmap_ax\n",
" rows_ordered = self.row_dendrogram['leaves']\n",
" cols_ordered = self.col_dendrogram['leaves']\n",
" data_ordered = self.data2d_na_ok.ix[rows_ordered, cols_ordered].values\n",
"\n",
" self.heatmap_ax_pcolormesh = ax.pcolormesh(data_ordered,\n",
" cmap=self.cmap,\n",
" norm=self.norm,\n",
" vmin=self.vmin,\n",
" vmax=self.vmax,\n",
" **self.pcolormesh_kws)\n",
" utils.despine(ax=ax, left=True, bottom=True)\n",
" ax.set_ylim(0, self.data2d.shape[0])\n",
" ax.set_xlim(0, self.data2d.shape[1])\n",
"\n",
" def set_title(self, title, title_fontsize=12):\n",
" \"\"\"Add title if there is one\n",
" \"\"\"\n",
" if title is not None:\n",
" self.col_dendrogram_ax.set_title(title, fontsize=title_fontsize)\n",
"\n",
" def colorbar(self):\n",
" \"\"\"Create the colorbar describing the hue-to-value in the heatmap\n",
" \"\"\"\n",
" ax = self.colorbar_ax\n",
" colorbar_ticklabel_fontsize = self.colorbar_kws.pop('fontsize')\n",
" cb = self.fig.colorbar(self.heatmap_ax_pcolormesh,\n",
" cax=ax, **self.colorbar_kws)\n",
"\n",
" tick_locator = mpl.ticker.MaxNLocator(nbins=2,\n",
" symmetric=self.divergent,\n",
" prune=None, trim=False)\n",
" if 'horizontal'.startswith(self.colorbar_kws['orientation']):\n",
" cb.ax.set_xticklabels(tick_locator.bin_boundaries(self.vmin,\n",
" self.vmax))\n",
" cb.ax.xaxis.set_major_locator(tick_locator)\n",
" else:\n",
" cb.ax.set_yticklabels(tick_locator.bin_boundaries(self.vmin,\n",
" self.vmax))\n",
" cb.ax.yaxis.set_major_locator(tick_locator)\n",
" cb.ax.yaxis.set_ticks_position('right')\n",
"\n",
" # move ticks to left side of colorbar to avoid problems with\n",
" # tight_layout\n",
" # cb.ax.yaxis.set_ticks_position('right')\n",
" if colorbar_ticklabel_fontsize is not None:\n",
" cb.ax.tick_params(labelsize=colorbar_ticklabel_fontsize)\n",
" cb.outline.set_linewidth(0)\n",
"\n",
" def plot_col_side(self):\n",
" \"\"\"Plot the dendrogram and potentially sidecolors for the column\n",
" dimension\n",
" \"\"\"\n",
" if self.col_kws['cluster']:\n",
" self.plot_dendrogram(self.col_dendrogram_ax, self.col_dendrogram,\n",
" row=False)\n",
" else:\n",
" self.col_dendrogram_ax.axis('off')\n",
"\n",
" self.plot_side_colors(self.col_side_colors_ax, self.col_kws,\n",
" self.col_dendrogram, row=False)\n",
"\n",
" def plot_row_side(self):\n",
" \"\"\"Plot the dendrogram and potentially sidecolors for the row dimension\n",
" \"\"\"\n",
" if self.row_kws['cluster']:\n",
" self.plot_dendrogram(self.row_dendrogram_ax, self.row_dendrogram,\n",
" row=True)\n",
" else:\n",
" self.row_dendrogram_ax.axis('off')\n",
" self.plot_side_colors(self.row_side_colors_ax, self.row_kws,\n",
" self.row_dendrogram, row=True)\n",
"\n",
" def label(self):\n",
" \"\"\"Label the rows and columns either at the dendrogram or heatmap\n",
" \"\"\"\n",
" self.label_dimension('row', self.row_kws, self.heatmap_ax,\n",
" self.row_dendrogram_ax, self.row_dendrogram)\n",
"\n",
" self.label_dimension('col', self.col_kws, self.heatmap_ax,\n",
" self.col_dendrogram_ax, self.col_dendrogram)\n",
"\n",
" def plot(self, fig=None, figsize=None, title=None, title_fontsize=12):\n",
" \"\"\"Plot the heatmap!\n",
"\n",
" Parameters\n",
" ----------\n",
" fig : None or matplotlib.figure.Figure instance\n",
" if None, create a new figure, or plot this onto the provided figure\n",
" figsize : None or tuple of ints\n",
" if None and no fig is provided, the new figure is created with a\n",
" size of 10 x 10\n",
" title : str\n",
" Title of the plot. Default is no title\n",
" title_fontsize : float\n",
" Fontsize of the title. Default is 12pt\n",
"\n",
" Returns\n",
" -------\n",
"\n",
"\n",
" Raises\n",
" ------\n",
"\n",
" \"\"\"\n",
" self.establish_axes(fig, figsize)\n",
" self.plot_row_side()\n",
" self.plot_col_side()\n",
" self.plot_heatmap()\n",
" self.set_title(title, title_fontsize)\n",
" self.label()\n",
" self.colorbar()\n",
" # Can't just do utils.despine() because it messes up the y-ticklabels\n",
" # utils.despine()\n",
"\n",
" def savefig(self, *args, **kwargs):\n",
" if 'bbox_inches' not in kwargs:\n",
" kwargs['bbox_inches'] \\\n",
" = 'tight'\n",
" self.fig.savefig(*args, **kwargs)\n",
"\n",
"\n",
"def clusteredheatmap(data, pivot_kws=None, title=None, title_fontsize=12,\n",
" color_scale='linear', linkage_method='median',\n",
" metric='euclidean', figsize=None, pcolormesh_kws=None,\n",
" dendrogram_kws=None, row_kws=None, col_kws=None,\n",
" colorbar_kws=None, data_na_ok=None, use_fastcluster=False,\n",
" fig=None, labelsize=18, z_score=None,\n",
" standard_scale=None):\n",
" \"\"\"Plot a hierarchically clustered heatmap of a pandas DataFrame\n",
"\n",
" This is liberally borrowed (with permission) from http://bit.ly/1eWcYWc\n",
" Many thanks to Christopher DeBoever and Mike Lovci for providing\n",
" heatmap/gridspec/colorbar positioning guidance.\n",
"\n",
" Parameters\n",
" ----------\n",
" data: DataFrame\n",
" Data for clustering. Should be a dataframe with no NAs. If you\n",
" still want to plot a dataframe with NAs, provide a non-NA dataframe\n",
" to data (with NAs replaced by 0 or something by your choice) and your\n",
" NA-full dataframe to data_na_ok\n",
" pivot_kws : dict\n",
" If the data is in \"tidy\" format, reshape the data with these pivot\n",
" keyword arguments\n",
" title: string, optional\n",
" Title of the figure. Default None\n",
" title_fontsize: int, optional\n",
" Size of the plot title. Default 12\n",
" color_scale: string, 'log' or 'linear'\n",
" How to scale the colors plotted in the heatmap. Default \"linear\". If\n",
" \"log\", any values 0 or less will become NaNs, as log(0) = -inf.\n",
" linkage_method : string\n",
" Which linkage method to use for calculating clusters.\n",
" See scipy.cluster.hierarchy.linkage documentation for more information:\n",
" http://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html\n",
" Default \"median\"\n",
" metric : string\n",
" Distance metric to use for the data. Default is \"euclidean.\" See\n",
" scipy.spatial.distance.pdist documentation for more options\n",
" http://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.pdist.html\n",
" figsize: tuple of two ints\n",
" Size of the figure to create. If None and no ``fig`` is given, a\n",
" 10 x 10 figure is created.\n",
" pcolormesh_kws : dict\n",
" Keyword arguments to pass to the heatmap pcolormesh plotter. E.g.\n",
" vmin, vmax, cmap, norm. If these are none, they are auto-detected\n",
" from your data. If the data is divergent, i.e. has values both\n",
" above and below zero, then the colormap is blue-red with blue\n",
" as negative values and red as positive. If the data is not\n",
" divergent, then the colormap is YlGnBu.\n",
" Default:\n",
" dict(vmin=None, vmax=None, edgecolor='white', linewidth=0.1, cmap=None,\n",
" norm=None)\n",
" {row,col}_kws : dict\n",
" Keyword arguments for rows and columns. Can turn off labeling altogether\n",
" with {'label': False}. Can specify your own linkage matrix via\n",
" linkage_matrix. Can also specify side color labels via\n",
" {'side_colors': colors}, which are useful for evaluating whether\n",
" samples within a group are clustered together.\n",
" Default:\n",
" dict(linkage_matrix=None, cluster=True, label=True, fontsize=18,\n",
" side_colors=None)\n",
" colorbar_kws : dict\n",
" Keyword arguments for the colorbar. The ticklabel fontsize is\n",
" extracted from this dict, then removed.\n",
" Default: dict(fontsize=14, label='', orientation='horizontal')\n",
" data_na_ok: Dataframe\n",
" If the \"data\" argument has NAs, can supply a separate dataframe to plot\n",
" use_fastcluster: bool\n",
" Whether or not to use the \"fastcluster\" module in Python,\n",
" which calculates linkage several times faster than\n",
" scipy.cluster.hierarchy\n",
" Default False, except that fastcluster is also tried when the number of\n",
" rows times the number of columns is 10000 or more.\n",
" labelsize : int\n",
" Size of the xticklabels and yticklabels\n",
" z_score : str or None\n",
" Either \"rows\" or \"cols\". Whether or not to calculate z-scores for the\n",
" rows or the columns. Z scores are: z = (x - mean)/std, so values\n",
" in each row (column) will get the mean of the row (column)\n",
" subtracted, then divided by the standard deviation of the row (\n",
" column). This ensures that each row (column) has a mean of 0 and a\n",
" variance of 1.\n",
" standard_scale : str or None\n",
" Either \"rows\" or \"cols\". Whether or not to \"standardize\" that\n",
" dimension, meaning to subtract the minimum of each row (column) and\n",
" then divide by its range (maximum minus minimum).\n",
"\n",
" Returns\n",
" -------\n",
" plotter : _ClusteredHeatmapPlotter\n",
" The plotter object that drew the figure. Its ``row_dendrogram`` and\n",
" ``col_dendrogram`` attributes are dicts with key 'leaves' and, when\n",
" clustering is enabled for that dimension, 'icoord' (coordinates of the\n",
" cluster nodes along the data) and 'dcoord' (coordinates of the cluster\n",
" nodes along the dendrogram height). For rows, 'icoord' holds the y-axis\n",
" coords and 'dcoord' the x-axis coords; for columns it is the other way\n",
" around.\n",
"\n",
" Notes\n",
" ----\n",
" To save the figure, make sure to specify \"bbox_inches='tight'\" to `fig\n",
" .savefig` and don't use `fig.tight_layout`, as this will mess up the\n",
" spacing. For example, here's how you can save the figure:\n",
"\n",
" # >>> import matplotlib.pyplot as plt\n",
" # >>> sns.clusteredheatmap(data);\n",
" # >>> fig = plt.gcf()\n",
" # >>> fig.savefig('clusteredheatmap.png', bbox_inches='tight')\n",
" \"\"\"\n",
" plotter = _ClusteredHeatmapPlotter(data, pivot_kws=pivot_kws,\n",
" color_scale=color_scale,\n",
" linkage_method=linkage_method,\n",
" metric=metric,\n",
" pcolormesh_kws=pcolormesh_kws,\n",
" dendrogram_kws=dendrogram_kws,\n",
" row_kws=row_kws,\n",
" col_kws=col_kws,\n",
" colorbar_kws=colorbar_kws,\n",
" use_fastcluster=use_fastcluster,\n",
" data_na_ok=data_na_ok,\n",
" labelsize=labelsize,\n",
" z_score=z_score,\n",
" standard_scale=standard_scale)\n",
"\n",
" plotter.plot(fig, figsize, title, title_fontsize)\n",
" return plotter\n",
" # return plotter.row_dendrogram, plotter.col_dendrogram\n"
],
"language": "python",
"metadata": {},
"outputs": []
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"import numpy as np\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import random\n",
"from six.moves import range\n",
"\n",
"import nose.tools as nt\n",
"import numpy.testing as npt\n",
"import pandas.util.testing as pdt\n",
"import scipy\n",
"from numpy.testing.decorators import skipif\n",
"\n",
"from .. import clustering as cl\n",
"from ..palettes import color_palette\n",
"\n",
"try:\n",
" import fastcluster\n",
"\n",
" _no_fastcluster = False\n",
"except ImportError:\n",
" _no_fastcluster = True\n",
"\n",
"\n",
"class TestMatrixPlotter(object):\n",
" shape = (10, 20)\n",
" np.random.seed(2013)\n",
" index = pd.Index(list('abcdefghij'), name='rownames')\n",
" columns = pd.Index(list('ABCDEFGHIJKLMNOPQRST'), name='colnames')\n",
" data2d = pd.DataFrame(np.random.randn(*shape), index=index,\n",
" columns=columns)\n",
" data2d.ix[0:5, 10:20] += 3\n",
" data2d.ix[5:10, 0:5] -= 3\n",
"\n",
" df = pd.melt(data2d.reset_index(), id_vars='rownames')\n",
"\n",
" def test_establish_variables_from_frame(self):\n",
" p = cl._MatrixPlotter()\n",
" p.establish_variables(self.df, pivot_kws=dict(index='rownames',\n",
" columns='colnames',\n",
" values='value'))\n",
" pdt.assert_frame_equal(p.data2d, self.data2d)\n",
" pdt.assert_frame_equal(p.data, self.df)\n",
"\n",
" def test_establish_variables_from_2d(self):\n",
" p = cl._MatrixPlotter()\n",
" p.establish_variables(self.data2d)\n",
" pdt.assert_frame_equal(p.data2d, self.data2d)\n",
" pdt.assert_frame_equal(p.data, self.data2d)\n",
"\n",
"\n",
"class TestClusteredHeatmapPlotter(object):\n",
" shape = (10, 20)\n",
" np.random.seed(2013)\n",
" index = pd.Index(list('abcdefghij'), name='rownames')\n",
" columns = pd.Index(list('ABCDEFGHIJKLMNOPQRST'), name='colnames')\n",
" data2d = pd.DataFrame(np.random.randn(*shape), index=index,\n",
" columns=columns)\n",
" data2d.ix[0:5, 10:20] += 3\n",
" data2d.ix[5:10, 0:5] -= 3\n",
"\n",
" df = pd.melt(data2d.reset_index(), id_vars='rownames')\n",
"\n",
" default_dim_kws = {'linkage_matrix': None, 'side_colors': None,\n",
" 'label': True,\n",
" 'cluster': True, 'fontsize': 18}\n",
"\n",
" default_pcolormesh_kws = {'linewidth': 0.1, 'edgecolor': 'white'}\n",
" default_colorbar_kws = {'fontsize': 14, 'label': 'values',\n",
" 'orientation': 'horizontal'}\n",
"\n",
" default_dendrogram_kws = {'color_threshold': np.inf,\n",
" 'color_list': ['k'],\n",
" 'no_plot': True}\n",
"\n",
" # Side colors for plotting\n",
" colors = color_palette(name='Set2', n_colors=3)\n",
" col_color_inds = np.arange(len(colors))\n",
" col_color_inds = [random.choice(col_color_inds) for _ in\n",
" range(data2d.shape[1])]\n",
" col_colors = [colors[i] for i in col_color_inds]\n",
"\n",
" row_color_inds = np.arange(len(colors))\n",
" row_color_inds = [random.choice(row_color_inds) for _ in\n",
" range(data2d.shape[0])]\n",
" row_colors = [colors[i] for i in row_color_inds]\n",
"\n",
" metric = 'euclidean'\n",
" method = 'median'\n",
"\n",
" def test_interpret_kws_from_none_divergent(self):\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" p.interpret_kws(row_kws=None, col_kws=None, pcolormesh_kws=None,\n",
" dendrogram_kws=None, colorbar_kws=None)\n",
" pdt.assert_dict_equal(p.row_kws, self.default_dim_kws)\n",
" pdt.assert_dict_equal(p.col_kws, self.default_dim_kws)\n",
"\n",
" nt.assert_equal(p.cmap, mpl.cm.RdBu_r)\n",
" pdt.assert_dict_equal(p.pcolormesh_kws, self.default_pcolormesh_kws)\n",
" pdt.assert_dict_equal(p.colorbar_kws, self.default_colorbar_kws)\n",
"\n",
" abs_max = np.abs(self.data2d.max().max())\n",
" abs_min = np.abs(self.data2d.min().min())\n",
" vmaxx = max(abs_max, abs_min)\n",
" nt.assert_almost_equal(p.vmin, -vmaxx)\n",
" nt.assert_almost_equal(p.vmax, vmaxx)\n",
"\n",
" def test_interpret_kws_from_none_positive(self):\n",
" p = cl._ClusteredHeatmapPlotter(np.abs(self.data2d))\n",
" p.interpret_kws(row_kws=None, col_kws=None, pcolormesh_kws=None,\n",
" dendrogram_kws=None, colorbar_kws=None)\n",
" nt.assert_equal(p.cmap, mpl.cm.YlGnBu)\n",
" nt.assert_is_none(p.norm)\n",
"\n",
" def test_interpret_kws_from_none_log(self):\n",
" p = cl._ClusteredHeatmapPlotter(np.log(self.data2d), color_scale='log')\n",
" p.interpret_kws(row_kws=None, col_kws=None, pcolormesh_kws=None,\n",
" dendrogram_kws=None, colorbar_kws=None)\n",
" nt.assert_is_instance(p.norm, mpl.colors.LogNorm)\n",
" nt.assert_equal(p.cmap, mpl.cm.YlGnBu)\n",
"\n",
" def test_interpret_kws_pcolormesh_cmap(self):\n",
" cmap = mpl.cm.PRGn\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" p.interpret_kws(row_kws=None, col_kws=None,\n",
" pcolormesh_kws={'cmap': cmap},\n",
" dendrogram_kws=None, colorbar_kws=None)\n",
" nt.assert_equal(p.cmap, cmap)\n",
"\n",
" def test_calculate_linkage_linear(self):\n",
" import scipy.spatial.distance as distance\n",
" import scipy.cluster.hierarchy as sch\n",
"\n",
" row_pairwise_dists = distance.squareform(\n",
" distance.pdist(self.data2d.values, metric=self.metric))\n",
" row_linkage = sch.linkage(row_pairwise_dists, method=self.method)\n",
"\n",
" col_pairwise_dists = distance.squareform(\n",
" distance.pdist(self.data2d.values.T, metric=self.metric))\n",
" col_linkage = sch.linkage(col_pairwise_dists, method=self.method)\n",
"\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d, use_fastcluster=False)\n",
" p.get_linkage()\n",
" npt.assert_array_almost_equal(p.row_linkage, row_linkage)\n",
" npt.assert_array_almost_equal(p.col_linkage, col_linkage)\n",
"\n",
" def test_calculate_linkage_log(self):\n",
" import scipy.spatial.distance as distance\n",
" import scipy.cluster.hierarchy as sch\n",
"\n",
" values = np.log10(self.data2d.values)\n",
" row_pairwise_dists = distance.squareform(\n",
" distance.pdist(values, metric=self.metric))\n",
" row_linkage = sch.linkage(row_pairwise_dists, method=self.method)\n",
"\n",
" col_pairwise_dists = distance.squareform(\n",
" distance.pdist(values.T, metric=self.metric))\n",
" col_linkage = sch.linkage(col_pairwise_dists, method=self.method)\n",
"\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d, color_scale='log')\n",
" p.get_linkage()\n",
" npt.assert_array_equal(p.row_linkage, row_linkage)\n",
" npt.assert_array_equal(p.col_linkage, col_linkage)\n",
"\n",
" def test_get_fig_width_ratios_side_colors_none(self):\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" width_ratios = p.get_fig_width_ratios(side_colors=None,\n",
" dimension='width')\n",
" height_ratios = p.get_fig_width_ratios(side_colors=None,\n",
" dimension='height')\n",
"\n",
" dendrogram = .25\n",
" true_height_ratios = [.5 * dendrogram, .5 * dendrogram, 1]\n",
" true_width_ratios = [dendrogram, 1]\n",
"\n",
" npt.assert_array_equal(width_ratios, true_width_ratios)\n",
" npt.assert_array_equal(height_ratios, true_height_ratios)\n",
"\n",
" def test_get_fig_width_ratios_side_colors(self):\n",
" p = cl._ClusteredHeatmapPlotter(\n",
" self.data2d, row_kws={'side_colors': self.row_colors},\n",
" col_kws={'side_colors': self.col_colors})\n",
"\n",
" width_ratios = p.get_fig_width_ratios(side_colors=p.col_kws[\n",
" 'side_colors'], dimension='width')\n",
" height_ratios = p.get_fig_width_ratios(side_colors=p.row_kws[\n",
" 'side_colors'], dimension='height')\n",
"\n",
" dendrogram = .25\n",
" true_height_ratios = [.5 * dendrogram, .5 * dendrogram, .05, 1]\n",
" true_width_ratios = [dendrogram, .05, 1]\n",
"\n",
" npt.assert_array_equal(width_ratios, true_width_ratios)\n",
" npt.assert_array_equal(height_ratios, true_height_ratios)\n",
"\n",
" def test_color_list_to_matrix_and_cmap_row(self):\n",
" import matplotlib as mpl\n",
"\n",
" colors = color_palette(name='Set2', n_colors=3)\n",
" np.random.seed(10)\n",
" n = 10\n",
" ind = np.arange(n)\n",
" color_inds = np.random.choice(np.arange(len(colors)), size=n).tolist()\n",
" color_list = [colors[i] for i in color_inds]\n",
" ind = np.random.shuffle(ind)\n",
"\n",
" colors_original = color_list\n",
" colors = set(colors_original)\n",
" col_to_value = dict((col, i) for i, col in enumerate(colors))\n",
" matrix = np.array([col_to_value[col] for col in colors_original])[ind]\n",
" new_shape = (len(colors_original), 1)\n",
" matrix = matrix.reshape(new_shape)\n",
" cmap = mpl.colors.ListedColormap(colors)\n",
"\n",
" chp = cl._ClusteredHeatmapPlotter\n",
" matrix2, cmap2 = chp.color_list_to_matrix_and_cmap(color_list, ind,\n",
" row=True)\n",
" npt.assert_array_equal(matrix, matrix2)\n",
" npt.assert_array_equal(cmap.colors, cmap2.colors)\n",
"\n",
" def test_color_list_to_matrix_and_cmap_col(self):\n",
" import matplotlib as mpl\n",
"\n",
" colors = color_palette(name='Set2', n_colors=3)\n",
" np.random.seed(10)\n",
" n = 10\n",
" ind = np.arange(n)\n",
" color_inds = np.random.choice(np.arange(len(colors)), size=n).tolist()\n",
" color_list = [colors[i] for i in color_inds]\n",
" ind = np.random.shuffle(ind)\n",
"\n",
" colors_original = color_list\n",
" colors = set(colors_original)\n",
" col_to_value = dict((col, i) for i, col in enumerate(colors))\n",
" matrix = np.array([col_to_value[col] for col in colors_original])[ind]\n",
" new_shape = (1, len(colors_original))\n",
" matrix = matrix.reshape(new_shape)\n",
" cmap = mpl.colors.ListedColormap(colors)\n",
"\n",
" chp = cl._ClusteredHeatmapPlotter\n",
" matrix2, cmap2 = chp.color_list_to_matrix_and_cmap(color_list, ind,\n",
" row=False)\n",
" npt.assert_array_equal(matrix, matrix2)\n",
" npt.assert_array_equal(cmap.colors, cmap2.colors)\n",
"\n",
" def test_get_linkage_function_scipy(self):\n",
" import scipy.cluster.hierarchy as sch\n",
"\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
"\n",
" linkage_function = p.get_linkage_function(shape=self.data2d.shape)\n",
" nt.assert_is_instance(linkage_function, type(sch.linkage))\n",
"\n",
" def test_get_linkage_function_large_data(self):\n",
" try:\n",
" import fastcluster\n",
"\n",
" linkage = fastcluster.linkage_vector\n",
" except ImportError:\n",
" import scipy.cluster.hierarchy as sch\n",
"\n",
" linkage = sch.linkage\n",
"\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" linkage_function = p.get_linkage_function(shape=(100, 100))\n",
" nt.assert_is_instance(linkage_function, type(linkage))\n",
"\n",
" @skipif(_no_fastcluster)\n",
" def test_get_linkage_function_fastcluster(self):\n",
" import fastcluster\n",
"\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d, use_fastcluster=True)\n",
" linkage_function = p.get_linkage_function(shape=self.data2d)\n",
" nt.assert_is_instance(linkage_function, type(fastcluster\n",
" .linkage_vector))\n",
"\n",
" def test_calculate_dendrogram(self):\n",
" import scipy.spatial.distance as distance\n",
" import scipy.cluster.hierarchy as sch\n",
"\n",
" sch.set_link_color_palette(['k'])\n",
"\n",
" row_pairwise_dists = distance.squareform(\n",
" distance.pdist(self.data2d.values, metric=self.metric))\n",
" row_linkage = sch.linkage(row_pairwise_dists, method=self.method)\n",
" dendrogram = sch.dendrogram(row_linkage, **self.default_dendrogram_kws)\n",
"\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" dendrogram2 = p.calculate_dendrogram(p.row_kws, row_linkage)\n",
" npt.assert_equal(dendrogram, dendrogram2)\n",
"\n",
" def test_plot_dendrogram_row(self):\n",
" f, ax = plt.subplots()\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" dendrogram = p.calculate_dendrogram(p.row_kws, p.row_linkage)\n",
" p.plot_dendrogram(ax, dendrogram)\n",
" npt.assert_equal(len(ax.get_lines()), self.data2d.shape[0] - 1)\n",
" plt.close(\"all\")\n",
"\n",
" def test_plot_dendrogram_col(self):\n",
" f, ax = plt.subplots()\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" dendrogram = p.calculate_dendrogram(p.col_kws, p.col_linkage)\n",
" p.plot_dendrogram(ax, dendrogram, row=False)\n",
" npt.assert_equal(len(ax.get_lines()), self.data2d.shape[1] - 1)\n",
" plt.close(\"all\")\n",
"\n",
" def test_plot_side_colors_none(self):\n",
" f, ax = plt.subplots()\n",
"\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d,\n",
" col_kws={'side_colors': None})\n",
" dendrogram = p.calculate_dendrogram(p.col_kws, p.col_linkage)\n",
" p.plot_side_colors(ax, p.col_kws, dendrogram, row=False)\n",
" npt.assert_equal(len(ax.collections), 0)\n",
" plt.close('all')\n",
"\n",
" def test_plot_side_colors_col(self):\n",
" f, ax = plt.subplots()\n",
" colors = color_palette(name='Set2', n_colors=3)\n",
" np.random.seed(10)\n",
" n = self.data2d.shape[1]\n",
" color_inds = np.random.choice(np.arange(len(colors)), size=n).tolist()\n",
" color_list = [colors[i] for i in color_inds]\n",
"\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d,\n",
" col_kws={'side_colors': color_list})\n",
" dendrogram = p.calculate_dendrogram(p.col_kws, p.col_linkage)\n",
" p.plot_side_colors(ax, p.col_kws, dendrogram, row=False)\n",
" npt.assert_equal(len(ax.collections), 1)\n",
" plt.close('all')\n",
"\n",
" def test_establish_axes_no_side_colors(self):\n",
" # width = min(self.data2d.shape[1] * 0.5, 40)\n",
" # height = min(self.data2d.shape[0] * 0.5, 40)\n",
" width, height = 10, 10\n",
" figsize = (width, height)\n",
"\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" p.establish_axes()\n",
" nt.assert_equal(len(p.fig.axes), 4)\n",
" nt.assert_equal(p.gs.get_geometry(), (3, 2))\n",
" npt.assert_array_equal(p.fig.get_size_inches(), figsize)\n",
" plt.close('all')\n",
"\n",
" def test_establish_axes_no_side_colors_figsize(self):\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
"\n",
" figsize = (8, 22)\n",
" p.establish_axes(figsize=figsize)\n",
"\n",
" nt.assert_equal(len(p.fig.axes), 4)\n",
" npt.assert_array_equal(p.fig.get_size_inches(), figsize)\n",
" plt.close('all')\n",
"\n",
" def test_establish_axes_col_side_colors(self):\n",
" colors = color_palette(name='Set2', n_colors=3)\n",
" np.random.seed(10)\n",
" n = self.data2d.shape[1]\n",
" color_inds = np.random.choice(np.arange(len(colors)), size=n).tolist()\n",
" color_list = [colors[i] for i in color_inds]\n",
"\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d,\n",
" col_kws={'side_colors': color_list})\n",
" p.establish_axes()\n",
"\n",
" nt.assert_equal(len(p.fig.axes), 5)\n",
" nt.assert_equal(p.gs.get_geometry(), (4, 2))\n",
" plt.close('all')\n",
"\n",
" def test_establish_axes_row_side_colors(self):\n",
" p = cl._ClusteredHeatmapPlotter(\n",
" self.data2d, row_kws={'side_colors': self.row_colors})\n",
" p.establish_axes()\n",
"\n",
" nt.assert_equal(len(p.fig.axes), 5)\n",
" nt.assert_equal(p.gs.get_geometry(), (3, 3))\n",
" plt.close('all')\n",
"\n",
" def test_establish_axes_both_side_colors(self):\n",
" p = cl._ClusteredHeatmapPlotter(\n",
" self.data2d, row_kws={'side_colors': self.row_colors},\n",
" col_kws={'side_colors': self.col_colors})\n",
" p.establish_axes()\n",
"\n",
" nt.assert_equal(len(p.fig.axes), 6)\n",
" nt.assert_equal(p.gs.get_geometry(), (4, 3))\n",
" plt.close('all')\n",
"\n",
" def test_plot_heatmap(self):\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" p.establish_axes()\n",
" p.plot_row_side()\n",
" p.plot_col_side()\n",
" p.plot_heatmap()\n",
" nt.assert_equal(len(p.heatmap_ax.collections), 1)\n",
" plt.close('all')\n",
"\n",
" def test_set_title(self):\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" p.establish_axes()\n",
" p.plot_row_side()\n",
" p.plot_col_side()\n",
" p.plot_heatmap()\n",
"\n",
" title = 'asdf'\n",
" p.set_title(title)\n",
"\n",
" nt.assert_equal(p.col_dendrogram_ax.get_title(), title)\n",
" plt.close('all')\n",
"\n",
" def test_label_dimension_none(self):\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d, col_kws=dict(\n",
" label=False), row_kws=dict(label=False))\n",
" p.establish_axes()\n",
" p.plot_row_side()\n",
" p.plot_col_side()\n",
" p.plot_heatmap()\n",
" p.label()\n",
"\n",
" nt.assert_equal(len(p.col_dendrogram_ax.get_yticklabels()), 0)\n",
" nt.assert_equal(len(p.row_dendrogram_ax.get_xticklabels()), 0)\n",
" nt.assert_equal(len(p.col_dendrogram_ax.get_xticklabels()), 0)\n",
" nt.assert_equal(len(p.row_dendrogram_ax.get_yticklabels()), 0)\n",
" plt.close('all')\n",
"\n",
" def test_label_dimension_both(self):\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d, col_kws=dict(label=True),\n",
" row_kws=dict(label=True))\n",
" p.establish_axes()\n",
" p.plot_row_side()\n",
" p.plot_col_side()\n",
" p.plot_heatmap()\n",
" p.label()\n",
"\n",
" # Make sure there aren't labels where there aren't supposed to be\n",
" nt.assert_equal(len(p.col_dendrogram_ax.get_xticklabels()), 0)\n",
" nt.assert_equal(len(p.col_dendrogram_ax.get_yticklabels()), 0)\n",
" nt.assert_equal(len(p.row_dendrogram_ax.get_xticklabels()), 0)\n",
" nt.assert_equal(len(p.row_dendrogram_ax.get_yticklabels()), 0)\n",
"\n",
" # Make sure the correct labels are where they're supposed to be\n",
" xticklabels = map(lambda x: x._text,\n",
" p.heatmap_ax.get_xticklabels())\n",
" col_reordered = self.data2d.columns[p.col_dendrogram['leaves']].values\n",
"\n",
" yticklabels = map(lambda x: x._text,\n",
" p.heatmap_ax.get_ymajorticklabels())\n",
" row_reordered = self.data2d.index[p.row_dendrogram['leaves']].values\n",
"\n",
" npt.assert_equal(xticklabels, col_reordered)\n",
" npt.assert_equal(yticklabels, row_reordered)\n",
" plt.close('all')\n",
"\n",
" def test_label_dimension_col(self):\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d, row_kws=dict(label=False))\n",
" p.establish_axes()\n",
" p.plot_row_side()\n",
" p.plot_col_side()\n",
" p.plot_heatmap()\n",
" p.label()\n",
"\n",
" # Make sure there aren't labels where there aren't supposed to be\n",
" nt.assert_equal(len(p.col_dendrogram_ax.get_xticklabels()), 0)\n",
" nt.assert_equal(len(p.col_dendrogram_ax.get_yticklabels()), 0)\n",
" nt.assert_equal(len(p.row_dendrogram_ax.get_xticklabels()), 0)\n",
" nt.assert_equal(len(p.row_dendrogram_ax.get_yticklabels()), 0)\n",
"\n",
" # Make sure the correct labels are where they're supposed to be\n",
" xticklabels = map(lambda x: x._text,\n",
" p.heatmap_ax.get_xticklabels())\n",
" col_reordered = self.data2d.columns[p.col_dendrogram['leaves']].values\n",
"\n",
" npt.assert_equal(xticklabels, col_reordered)\n",
" plt.close('all')\n",
"\n",
" def test_label_dimension_row(self):\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d, col_kws=dict(label=True),\n",
" row_kws=dict(label=True))\n",
" p.establish_axes()\n",
" p.plot_row_side()\n",
" p.plot_col_side()\n",
" p.plot_heatmap()\n",
" p.label()\n",
"\n",
" # Make sure there aren't labels where there aren't supposed to be\n",
" nt.assert_equal(len(p.col_dendrogram_ax.get_xticklabels()), 0)\n",
" nt.assert_equal(len(p.col_dendrogram_ax.get_yticklabels()), 0)\n",
" nt.assert_equal(len(p.row_dendrogram_ax.get_xticklabels()), 0)\n",
" nt.assert_equal(len(p.row_dendrogram_ax.get_yticklabels()), 0)\n",
"\n",
" # Make sure the correct labels are where they're supposed to be\n",
" yticklabels = map(lambda x: x._text,\n",
" p.heatmap_ax.get_ymajorticklabels())\n",
" row_reordered = self.data2d.index[p.row_dendrogram['leaves']].values\n",
"\n",
" npt.assert_equal(yticklabels, row_reordered)\n",
" plt.close('all')\n",
"\n",
" def test_colorbar(self):\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d, col_kws=dict(label=True),\n",
" row_kws=dict(label=True))\n",
" p.establish_axes()\n",
" p.plot_heatmap()\n",
" p.colorbar()\n",
"\n",
" nt.assert_equal(len(p.colorbar_ax.collections), 1)\n",
" plt.close('all')\n",
"\n",
" def test_plot_col_side(self):\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" p.establish_axes()\n",
" p.plot_col_side()\n",
"\n",
" nt.assert_equal(len(p.col_dendrogram_ax.lines),\n",
" len(p.col_dendrogram['icoord']))\n",
" nt.assert_equal(len(p.col_dendrogram_ax.lines),\n",
" len(p.col_dendrogram['dcoord']))\n",
" plt.close('all')\n",
"\n",
" def test_plot_col_side_colors(self):\n",
" p = cl._ClusteredHeatmapPlotter(\n",
" self.data2d, col_kws={'side_colors': self.col_colors})\n",
" p.establish_axes()\n",
" p.plot_col_side()\n",
"\n",
" nt.assert_equal(len(p.col_side_colors_ax.collections), 1)\n",
" # Make sure xlim was set correctly\n",
" nt.assert_equal(p.col_side_colors_ax.get_xlim(),\n",
" tuple(map(float, (0, p.data2d.shape[1]))))\n",
" plt.close('all')\n",
"\n",
" def test_plot_row_side(self):\n",
" p = cl._ClusteredHeatmapPlotter(\n",
" self.data2d, row_kws={'side_colors': self.row_colors})\n",
" p.establish_axes()\n",
" p.plot_row_side()\n",
"\n",
" nt.assert_equal(len(p.row_dendrogram_ax.lines),\n",
" len(p.row_dendrogram['icoord']))\n",
" nt.assert_equal(len(p.row_dendrogram_ax.lines),\n",
" len(p.row_dendrogram['dcoord']))\n",
" plt.close('all')\n",
"\n",
" def test_plot_row_side_colors(self):\n",
" p = cl._ClusteredHeatmapPlotter(\n",
" self.data2d, row_kws={'side_colors': self.row_colors})\n",
" p.establish_axes()\n",
" p.plot_row_side()\n",
"\n",
" nt.assert_equal(len(p.row_side_colors_ax.collections), 1)\n",
"\n",
" # Make sure ylim was set correctly\n",
" nt.assert_equal(p.row_side_colors_ax.get_ylim(),\n",
" tuple(map(float, (0, p.data2d.shape[0]))))\n",
" plt.close('all')\n",
"\n",
" def test_plot(self):\n",
" # Should this test ALL the commands that are in\n",
" # ClusteredHeatmapPlotter.plot? If so, how? Seems overkill to do all\n",
" # the tests for each individual part again.\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" p.plot()\n",
"\n",
" fig, figsize, title, title_fontsize = None, None, None, None\n",
" p2 = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" p2.establish_axes(fig, figsize)\n",
" p2.plot_row_side()\n",
" p2.plot_col_side()\n",
" p2.plot_heatmap()\n",
" p2.set_title(title, title_fontsize)\n",
" p2.label()\n",
" p2.colorbar()\n",
"\n",
" # Check that tight_layout was applied correctly\n",
" nt.assert_equal(p.gs.bottom, p2.gs.bottom)\n",
" nt.assert_equal(p.gs.hspace, p2.gs.hspace)\n",
" nt.assert_equal(p.gs.left, p2.gs.left)\n",
" nt.assert_equal(p.gs.right, p2.gs.right)\n",
" nt.assert_equal(p.gs.top, p2.gs.top)\n",
" nt.assert_equal(p.gs.wspace, p2.gs.wspace)\n",
"\n",
" @skipif(_no_fastcluster)\n",
" def test_plot_fastcluster(self):\n",
" p = cl._ClusteredHeatmapPlotter(self.data2d, use_fastcluster=True)\n",
" p.plot()\n",
"\n",
" fig, figsize, title, title_fontsize = None, None, None, None\n",
" p2 = cl._ClusteredHeatmapPlotter(self.data2d)\n",
" p2.establish_axes(fig, figsize)\n",
" p2.plot_row_side()\n",
" p2.plot_col_side()\n",
" p2.plot_heatmap()\n",
" p2.set_title(title, title_fontsize)\n",
" p2.label()\n",
" p2.colorbar()\n",
"\n",
" # Check that tight_layout was applied correctly\n",
" nt.assert_equal(p.gs.bottom, p2.gs.bottom)\n",
" nt.assert_equal(p.gs.hspace, p2.gs.hspace)\n",
" nt.assert_equal(p.gs.left, p2.gs.left)\n",
" nt.assert_equal(p.gs.right, p2.gs.right)\n",
" nt.assert_equal(p.gs.top, p2.gs.top)\n",
" nt.assert_equal(p.gs.wspace, p2.gs.wspace)\n",
"\n",
"labels = ['Games', 'Minutes', 'Points', 'Field goals made',\n",
" 'Field goal attempts', 'Field goal percentage', 'Free throws made',\n",
" 'Free throws attempts', 'Free throws percentage',\n",
" 'Three-pointers made', 'Three-point attempt',\n",
" 'Three-point percentage', 'Offensive rebounds', 'Defensive rebounds',\n",
" 'Total rebounds', 'Assists', 'Steals', 'Blocks', 'Turnover',\n",
" 'Personal foul']\n"
],
"language": "python",
"metadata": {},
"outputs": []
}
],
"metadata": {}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment