Skip to content

Instantly share code, notes, and snippets.

@nicoguaro
Last active July 30, 2016 16:46
Show Gist options
  • Select an option

  • Save nicoguaro/7aa43a8950868b9d0f36ed2a64da15fc to your computer and use it in GitHub Desktop.

Select an option

Save nicoguaro/7aa43a8950868b9d0f36ed2a64da15fc to your computer and use it in GitHub Desktop.
Plot a matrix with different labels in each cell as a series of circles with colors that represent each label.
# -*- coding: utf-8 -*-
"""
Plot a matrix with different labels in each cell as a series
of circles with colors that represent each label.
@author: Nicolas Guarin-Zapata
"""
from __future__ import division, print_function
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names, so "seaborn-white" alone fails on current
# releases. Prefer the new name and fall back to the old one on older installs.
if "seaborn-v0_8-white" in plt.style.available:
    plt.style.use("seaborn-v0_8-white")
else:
    plt.style.use("seaborn-white")
fig = plt.figure()
ax = plt.subplot(111)
# Data and labels loading.
# Each ';'-separated field of a line becomes one matrix cell; a cell may hold
# several comma-separated labels (split apart when the circles are plotted).
# ndmin=2 keeps the result 2-D even for a one-row or one-column file, which
# loadtxt would otherwise squeeze to 1-D and break the shape unpacking below.
data = np.loadtxt("Vieira_data.csv", dtype=str, delimiter=";", ndmin=2)
nrows, ncols = data.shape
# Marker colour for every label that may appear in a cell.
colors = {'COA': "#e41a1c",
          'RAG': "#377eb8",
          'CHK': "#4daf4a",
          'VAR': "#984ea3",
          'DAT': "#ff7f00",
          'EXE': "#ffff33"}
# Plot circles: every label in a cell is drawn as a coloured circle, spaced
# evenly along the cell's horizontal midline.
for row in range(nrows):
    for col in range(ncols):
        # Build a real list: its length is needed below, and on Python 3
        # ``map`` returns an iterator that has no ``len``. Empty entries
        # (blank cell, trailing comma) are skipped rather than raising a
        # KeyError in ``colors``.
        cell_content = [item.strip() for item in data[row, col].split(',')
                        if item.strip()]
        # Divide the unit-wide cell into len + 1 gaps so circles sit inside it.
        hor_space = 1/(len(cell_content) + 1)
        for cont, label in enumerate(cell_content):
            # Row 0 is placed at the top (y = nrows - 0.5), centred in its cell.
            ax.plot(col + (cont + 1)*hor_space, nrows - row - 0.5,
                    'o', ms=40, color=colors[label], label=label)
# Plot grid: one horizontal line per row boundary and one vertical line per
# column boundary. Separate loops draw each line exactly once; a nested loop
# would stack ncols copies of every horizontal line and nrows copies of every
# vertical one. The lines at y = nrows and x = ncols are not drawn here
# (the axis limits end exactly there, presumably leaving them to the frame).
for row in range(nrows):
    ax.plot([0, ncols], [row, row], color="black")
for col in range(ncols):
    ax.plot([col, col], [0, nrows], color="black")
# Labels: centre one tick in every cell and number the cells from 1.
# Columns read left to right; rows read top to bottom, so the first row's
# tick sits at y = nrows - 0.5.
xticks = [num - 0.5 for num in range(1, ncols + 1)]
xlabels = ["Section %i" % num for num in range(1, ncols + 1)]
yticks = [nrows - num + 0.5 for num in range(1, nrows + 1)]
ylabels = ["Student %i" % num for num in range(1, nrows + 1)]
plt.xticks(xticks, xlabels)
plt.yticks(yticks, ylabels)
# Show the column captions above the matrix rather than below it.
plt.tick_params(labeltop=True, labelbottom=False)
# Legend: narrow the axes to 80 % of their width to make room on the right,
# then list each label once.
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
handles, labels = plt.gca().get_legend_handles_labels()
# Every circle carries its own label, so labels repeat. Keep one entry per
# label in order of first appearance; a later handle for the same label
# replaces the earlier one.
by_label = OrderedDict()
for handle, text in zip(handles, labels):
    by_label[text] = handle
# markerscale shrinks the legend circles relative to the plotted markers.
plt.legend(list(by_label.values()), list(by_label.keys()), markerscale=0.4,
           loc='center left', bbox_to_anchor=(1, 0.5))
# Clip the view exactly to the matrix, write a raster (300 dpi PNG) and a
# vector (SVG) copy of the figure, then open the interactive window.
plt.xlim(0, ncols)
plt.ylim(0, nrows)
for path, extra in (("Matrix_plot.png", {"dpi": 300}),
                    ("Matrix_plot.svg", {})):
    plt.savefig(path, bbox_inches="tight", **extra)
plt.show()
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
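We can make this file beautiful and searchable if this error is corrected: It looks like row 2 should actually have 3 columns, instead of 2 in line 1. (Note: GitHub's CSV preview splits on commas, so it counts 3 fields in line 1 but only 2 in line 2. The file is not malformed: columns are separated by ';' and the labels inside one cell by ',', which is how the script reads it with delimiter=";".)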
COA, RAG, EXE; CHK
VAR; COA, VAR
VAR, DAT; EXE
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment