Skip to content

Instantly share code, notes, and snippets.

@geotheory
Last active April 2, 2016 02:15
Show Gist options
  • Save geotheory/472b67ad869fa8d6e83d to your computer and use it in GitHub Desktop.
Save geotheory/472b67ad869fa8d6e83d to your computer and use it in GitHub Desktop.
A function to make alluvial-style plots of simple categorical time-series data
require(grid)
require(scales)
require(reshape2)
# Notes:
# dat       should be a 3-column data.frame with fields in the order item, time, value (as per the example below)
# wave      waviness of the curves, defined in terms of the x-axis data range - experiment to get this right
# ygap      gap between items on each y axis
# col       a single colour, or a vector of colours for the categories when listed in alphabetical order
# leg.mode  if TRUE the legend is plotted on the largest data observation; otherwise at custom coordinates (leg.x / leg.y, each in [0,1])
##############
## Function ##
##############
# Draw an alluvial-style plot of simple categorical time-series data using
# base graphics. Every item becomes one ribbon whose thickness at each time
# point equals its value; within a time point the ribbons are stacked in value
# order, so ribbons cross when the ranking changes.
#
# Arguments:
#   dat        data.frame with exactly three columns, in the order
#              item, time, value. 'time' must be numeric or a factor.
#   wave       horizontal reach of the curved section either side of a time
#              point, in x-axis units. NA picks half the mean period length.
#   ygap       multiplier for the vertical gap between stacked items (the gap
#              itself is 10% of the mean value times 'ygap').
#   col        NA (rainbow palette), one colour, or one colour per item in
#              sorted item order.
#   plotdir    'up', 'down' or 'centred': orientation of the y axis.
#   rankup     if TRUE items are stacked in decreasing value order from the
#              y origin, otherwise in increasing order.
#   xmargin    x-axis extension factor; the data range is extended by
#              (xmargin - 1) / 2 on each side to leave room for labels.
#   lab.cex, title.cex, xaxis.cex, leg.cex   text sizes.
#   title      plot title; NA builds one from the column names of 'dat'.
#   grid, grid.col, grid.lwd   optional vertical lines at each time point.
#   leg.mode   TRUE: annotate the largest observation in place;
#              FALSE: draw a scale bar at (leg.x, leg.y), given as fractions
#              of the x and y plot limits.
#   leg.col, leg.lty, leg.lwd  legend styling (NA selects a per-mode default).
#   leg.max    length of the scale bar when leg.mode is FALSE
#              (NA uses a value derived from the data maximum).
#   lwd        ribbon border width(s): one value, or one per item in sorted
#              item order; must be positive.
#   ...        passed on to polygon() (e.g. border = NA).
#
# Returns NULL invisibly; called for its plotting side effect.
# Invalid arguments raise an error via stop(), and all validation happens
# before the plotting device is touched.
alluvial <- function(dat, wave=NA, ygap=1, col=NA, plotdir='up', rankup=TRUE, xmargin=1.1, lab.cex=1,
                     title=NA, title.cex=1, xaxis.cex=1, grid=FALSE, grid.col='grey80', grid.lwd=1,
                     leg.mode=TRUE, leg.x=.1, leg.y=.9, leg.cex=1, leg.col='black', leg.lty=NA, leg.lwd=NA,
                     leg.max=NA, lwd=1, ...) {

  # ---- input validation (nothing is drawn until all of this has passed) ----
  if (!is.data.frame(dat) || ncol(dat) != 3) {
    stop("'dat' must be a data.frame with three columns: item, time, value", call. = FALSE)
  }
  orig.names <- names(dat)
  names(dat) <- c("item", "time", "val")

  # Factor time is plotted at its integer codes and labelled with its levels.
  time.levels <- NULL
  if (is.factor(dat$time)) {
    time.levels <- levels(dat$time)
    dat$time <- as.numeric(dat$time)
  } else if (!is.numeric(dat$time)) {
    stop("time variable must be numeric, factor, or ordered", call. = FALSE)
  }
  times <- sort(unique(dat$time))
  if (length(times) < 2) {
    stop("at least two distinct time periods are required", call. = FALSE)
  }
  # Index the levels by the codes actually present so that labels and tick
  # positions always have the same length, even with unused factor levels.
  axis.labs <- if (is.null(time.levels)) times else time.levels[times]

  if (length(plotdir) != 1 || !plotdir %in% c("up", "down", "centred")) {
    stop("Incorrect specification for plotdir: please select 'up', 'down' or 'centred'", call. = FALSE)
  }

  dat <- dat[order(dat$item), ]
  datsum <- aggregate(val ~ item, dat, mean)
  plotorder <- order(datsum$val, decreasing = TRUE)  # smallest last (on top)
  maxval <- pretty(max(dat$val))[2]                   # default scale-bar length

  items <- as.character(unique(dat$item))
  n <- length(items)

  # colours
  if (all(is.na(col))) col <- substr(rainbow(n), 1, 7)
  if (length(col) == 1) col <- rep(col, n)
  if (length(col) != n) {
    stop("'col' length must equal the number of unique data elements", call. = FALSE)
  }

  # line widths
  if (any(lwd <= 0)) stop("'lwd' must be greater than zero", call. = FALSE)
  if (length(lwd) == 1) lwd <- rep(lwd, n)
  if (length(lwd) != n) {
    stop("'lwd' length must equal the number of unique data elements", call. = FALSE)
  }

  # ---- layout ----
  # vertical gap between items, relative to the mean value
  ygap <- mean(dat$val) * .1 * ygap
  # default waviness if not specified (but it really should be)
  if (is.na(wave)) wave <- .5 * (times[length(times)] - times[1]) / length(times)

  plot.y.max <- 0  # tallest stack over all periods; sets the y scale
  # one entry per item, keyed by name, collecting spline control points
  d <- setNames(rep(list(list()), n), items)

  for (i in seq_along(times)) {
    t.now <- times[i]
    has.prev <- i > 1
    has.next <- i < length(times)

    dat.t <- dat[which(dat$time == t.now), ]
    dat.t <- dat.t[order(dat.t$val, decreasing = as.logical(rankup)), ]

    y.sum <- sum(dat.t$val) + nrow(dat.t) * ygap
    plot.y.max <- max(plot.y.max, y.sum)

    # Control points left of the axis (absent for the first period) and right
    # of it (absent for the last); identical for every item in this period.
    spline.x <- c(if (has.prev) c(t.now - wave, t.now),
                  if (has.next) c(t.now, t.now + wave))
    n.pts <- length(spline.x)

    y <- if (plotdir == "centred") -y.sum / 2 else 0  # running stack height
    for (j in seq_len(nrow(dat.t))) {
      y0 <- y + ygap
      y1 <- y0 + dat.t$val[j]
      y <- y1
      item <- as.character(dat.t$item[j])
      d[[item]]$x  <- c(d[[item]]$x, spline.x)
      d[[item]]$y0 <- c(d[[item]]$y0, rep(y0, n.pts))
      d[[item]]$y1 <- c(d[[item]]$y1, rep(y1, n.pts))
    }
  }

  # xspline output length varies, so resample every curve to a fixed length
  resize <- function(v, n = 500) {
    approx(seq_along(v), v, xout = seq(1, length(v), length.out = n))$y
  }

  # ---- drawing ----
  plot.new()  # required by xspline

  for (item in names(d)) {
    lower <- lapply(xspline(d[[item]]$x, d[[item]]$y0, shape = 1, draw = FALSE), resize)
    upper <- lapply(xspline(d[[item]]$x, d[[item]]$y1, shape = 1, draw = FALSE), resize)
    # bottom edge left-to-right, then top edge right-to-left, closes the ribbon
    d[[item]]$poly <- data.frame(x = c(lower$x, rev(upper$x)),
                                 y = c(lower$y, rev(upper$y)))
  }

  # label positions: vertical midpoint of each ribbon at its first/last point
  label.frame <- function(pick) {
    data.frame(item = names(d),
               y = vapply(d, function(e) mean(c(pick(e$y0), pick(e$y1))),
                          numeric(1), USE.NAMES = FALSE),
               stringsAsFactors = FALSE)
  }
  labs.l <- label.frame(function(v) v[1])
  labs.r <- label.frame(function(v) v[length(v)])
  labs.l$col <- col[match(labs.l$item, datsum$item)]
  labs.r$col <- col[match(labs.r$item, datsum$item)]

  # scale and orientation of axes
  ylim <- switch(plotdir,
                 up      = c(0, plot.y.max),
                 down    = c(plot.y.max, 0),
                 centred = c(-plot.y.max / 2, plot.y.max / 2))
  xlim <- extendrange(r = range(times), f = (xmargin - 1) / 2)

  # reorder so the largest ribbons are drawn first, keeping colours/widths aligned
  d <- d[plotorder]
  col <- col[plotorder]
  lwd <- lwd[plotorder]

  plot.window(xlim = xlim, ylim = ylim)
  axis(1, times, cex.axis = xaxis.cex, labels = axis.labs)
  main.text <- if (is.na(title)) {
    paste('Alluvial plot of', orig.names[1], 'vs', orig.names[3])
  } else {
    title
  }
  graphics::title(main.text, cex.main = title.cex)
  if (grid) abline(v = times, col = grid.col, lwd = grid.lwd)
  for (k in seq_along(d)) {
    polygon(d[[k]]$poly$x, d[[k]]$poly$y, col = col[k], lwd = lwd[k], ...)
  }
  text(times[1], labs.l$y, labels = labs.l$item, col = labs.l$col, pos = 2, cex = lab.cex)
  text(times[length(times)], labs.r$y, labels = labs.r$item, adj = 0, col = labs.r$col, pos = 4, cex = lab.cex)

  # ---- legend ----
  top.row <- which.max(dat$val)  # first occurrence of the maximum value
  topval <- dat$val[top.row]
  topitem <- as.character(dat$item[top.row])
  toptime <- dat$time[top.row]

  if (leg.mode) {
    # annotate the largest observation in place: two line segments with the
    # formatted value between them
    if (is.na(leg.lty)) leg.lty <- "dotted"
    if (is.na(leg.lwd)) leg.lwd <- 2
    val_ind <- match(toptime, d[[topitem]]$x)
    leg_y0 <- d[[topitem]]$y0[val_ind]
    leg_y1 <- d[[topitem]]$y1[val_ind]
    leg_ym <- mean(c(leg_y0, leg_y1))
    lines(rep(toptime, 2), c(leg_y0, leg_ym - (topval * .08)), lwd = leg.lwd, lend = 'butt', col = leg.col, lty = leg.lty)
    lines(rep(toptime, 2), c(leg_y1, leg_ym + (topval * .08)), lwd = leg.lwd, lend = 'butt', col = leg.col, lty = leg.lty)
    text(toptime, leg_ym, labels = formatC(topval, format = "d", big.mark = ','), pos = NULL, cex = leg.cex, col = leg.col)
  } else {
    # scale bar at a custom position given as fractions of the plot limits
    if (!is.na(leg.max)) maxval <- leg.max
    if (is.na(leg.lty)) leg.lty <- "solid"
    if (is.na(leg.lwd)) leg.lwd <- 10
    leg_x <- (xlim[2] - xlim[1]) * leg.x + xlim[1]
    leg_y <- (ylim[2] - ylim[1]) * leg.y + ylim[1]
    lines(data.frame(x = c(leg_x, leg_x), y = c(leg_y, leg_y + maxval)), lwd = leg.lwd, lend = 'butt', col = leg.col, lty = leg.lty)
    max.lab <- formatC(maxval, format = "d", big.mark = ',')
    text(rep(leg_x, 2), c(leg_y, leg_y + maxval), labels = c('0', max.lab), cex = leg.cex, pos = 4, offset = 1, col = leg.col)
  }

  invisible(NULL)
}
##############
## Examples ##
##############
# Build a demonstration data set from the built-in USPersonalExpenditure matrix
d <- USPersonalExpenditure
set.seed(49)
# shuffle the values within each year (column) so category rankings change over time
for (yr in seq_len(ncol(d))) {
  d[, yr] <- sample(d[, yr], nrow(d))
}
# reshape the matrix to long format: one row per category/year pair
d <- melt(d)
colnames(d) <- c('category', 'year', 'expenditure')
# for legend clarity, put the maximum data point on the 1955 axis (rather than 1960)
d[16, 'expenditure'] <- 91
head(d)

# default settings
alluvial(d)
# wave / gap variations; extra arguments such as 'border' are passed to polygon()
alluvial(d, wave = 1, ygap = 5, border = NA)
alluvial(d, wave = 1, ygap = 5, rankup = FALSE)
# alternative y-axis orientations
alluvial(d, wave = 1, ygap = 5, plotdir = 'down')
alluvial(d, wave = 1, ygap = 5, plotdir = 'centred')
# no gaps, white ribbon borders and a white in-place legend
alluvial(d, wave = 2.5, ygap = 0, border = 'white', lwd = 2, leg.col = 'white')
# vertical grid lines at each time point
alluvial(d, wave = 2, ygap = 2, grid = TRUE, grid.col = 'grey90', grid.lwd = 6)
# wider x margins, smaller text and a custom title
alluvial(d, wave = 2, ygap = 2, xmargin = 1.6)
alluvial(d, wave = 2, ygap = 2, xmargin = 1.3, lab.cex = .7, xaxis.cex = .7, title = "My alluvial plot")
# scale-bar legend at custom coordinates
alluvial(d, leg.mode = FALSE, leg.x = .05, leg.y = .4)
alluvial(d, leg.mode = FALSE, leg.x = .05, leg.y = .6, leg.col = 'purple', leg.max = 50)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment