Skip to content

Instantly share code, notes, and snippets.

@statgeek
Last active May 12, 2023 16:33
Show Gist options
  • Save statgeek/b5073195950d9d82d11d627b660accb2 to your computer and use it in GitHub Desktop.
Save statgeek/b5073195950d9d82d11d627b660accb2 to your computer and use it in GitHub Desktop.
SAS - Macro to create a Sankey diagram
/*
Author: Jeff Myers
Source: https://communities.sas.com/t5/Graphics-Programming/Sankey-Diagram-Decision-Tree-etc/m-p/719812#
*/
/*Example data: 100 IDs, each observed at cycles 1..5 with a random integer
  grade 0-5 (floor of 6*Uniform). Once an ID draws grade 5, no further
  cycles are written for that ID.*/
data random;
  call streaminit(123);
  do id = 1 to 100;
    do cycle = 1 to 5;
      u = rand("Uniform");
      grade = floor(6*u);
      output;
      /*Grade 5 ends follow-up for this ID*/
      if grade = 5 then leave;
    end;
  end;
  drop u:;
  label cycle='Cycle' grade='Maximum Grade AE';
run;
option mprint;
/*---------------------------------------------------------------------------
  SANKEY: draw a Sankey diagram with PROC SGPLOT POLYGON statements.

  For every ID, each observed node is a rectangle segment sized by the
  within-node percentage of its GROUP. Consecutive observed nodes of an ID
  are joined by a Bezier band. If an ID has no row at a node, the band skips
  over that node.

  Requirements: NODES and GROUP must be numeric. GROUP is stored in numeric
  arrays when the paths are built.
  Side effects: creates or overwrites the WORK datasets _temp, _frq,
  _levels, _rectangles, _paths, _frq2, _paths2, _node_labels, _paths3 and
  _plot, and issues ODS GRAPHICS / RESET with the size and name options.
---------------------------------------------------------------------------*/
%macro sankey(
/*Required*/
data= /*Dataset to input*/,
id= /*Patient/traveler ID*/,
nodes= /*Nodes are the different points or time-points that the curves connect between. Needs to be numeric*/,
group= /*Different subgroups for the node sections. Needs to be numeric*/,
/*Optional for tweaking the graph*/
barwidth=5 /*Controls width of the Nodes*/,
bargap=5 /*Allocates a percentage of the y-space for gaps between GROUPs*/,
points=20 /*Number of segments used to draw each Bezier edge (positive integer). More=smoother lines but more memory*/,
curve_rectangle_gap=0 /*Gap as a percentage between the ends of the connecting curves and the Nodes*/,
antialias=200000 /*When a lot of points are used antialias will need to be increased to get smooth curves*/,
width=16in /*Determines the width of the image*/,
height=8in /*Determines the height of the image*/,
plotname=_sankey /*Determines the name of the image. ODS LISTING should be turned on to save image*/,
font_size=12pt /*Determines the font size*/);
  %local group_label nodes_label n_nodes n_grps i null;

  /**Rename variables and create a temporary dataset.
     A one-to-one MERGE (no BY statement) places the three variables side by
     side. Each one is read and renamed through its own dataset option.**/
  data _temp;
    merge &data (keep=&id rename=(&id=id))
          &data (keep=&nodes rename=(&nodes=nodes))
          &data (keep=&group rename=(&group=group));
  run;

  /**Grab labels for the GROUP and NODES variables (VLABEL returns the name when there is no label)**/
  data _null_;
    set _temp (obs=1);
    call symput('nodes_label',strip(vlabel(nodes)));
    call symput('group_label',strip(vlabel(group)));
  run;

  /**Cross-tab frequency of groups within nodes: drives the rectangle heights**/
  proc freq data=_temp noprint;
    table nodes*group / outpct out=_frq (keep=nodes group count pct_row);
  run;

  /**Number of distinct (non-missing) node values**/
  proc sql noprint;
    select count(distinct nodes) into :n_nodes separated by ''
      from _frq;
  quit;

  /*Nodes are spread over 0-100 on the x-axis, which needs at least two of them*/
  %if &n_nodes < 2 %then %do;
    %put ERROR: SANKEY needs at least two distinct non-missing values of NODES=&nodes (found &n_nodes).;
    %goto exit;
  %end;

  %do i = 1 %to &n_nodes;
    %local node&i n_group&i;
  %end;

  /**Assign a node index (1..n_nodes, ascending node value) and a group level
     (1..k within each node). Both are used as array subscripts and join keys later.**/
  data _levels;
    set _frq;
    by nodes group;
    if first.nodes then do;
      node_idx+1;
      group_lvl=0;
    end;
    group_lvl+1;
  run;

  proc sql noprint;
    /*BEST32. keeps decimal node values; 12. would round e.g. 1.5 to 2 and break the match below*/
    select distinct nodes format=best32. into :node1-
      from _frq;
    /**Count how many groups are in each node**/
    select nodes,count(distinct group) into :null,:n_group1-
      from _frq group by nodes;
  quit;

  /*Create coordinates for rectangles*/
  /*BARWIDTH controls the width of rectangles, BARGAP assigns a percentage of the y-space for white space*/
  data _rectangles;
    set _frq;
    by nodes group;
    array _node_n {&n_nodes}
      (%do i = 1 %to &n_nodes;
         %if &i>1 %then %do; , %end;
         &&n_group&i
       %end;);
    if first.nodes then do;
      last_y=0;nodes_count+1;group_count=0;
    end;
    if first.group then do;
      /*Nodes are spaced evenly across 0-100*/
      last_x=100*(nodes_count-1)/(&n_nodes-1);
      group_count+1;
    end;
    retain last_x last_y;
    rectangle_id=catx('-',nodes_count,group_count);
    x=last_x;y=last_y;output;
    x=last_x+&barwidth;y=last_y;output;
    x=last_x+&barwidth;y=last_y+((100-&bargap)/100)*pct_row;output;
    x=last_x;y=last_y+((100-&bargap)/100)*pct_row;output;
    /*The node's BARGAP is divided evenly among its groups*/
    last_y=y+&bargap/_node_n(nodes_count);
    drop _node:;
  run;

  /*For each ID, find the path segments between consecutive observed nodes.
    A segment skips over nodes where the ID has no value.*/
  proc sort data=_temp;
    by id;
  run;
  data _paths;
    set _temp;
    by id;
    array node_ {&n_nodes};
    retain node_;
    if first.id then call missing(of node_(*));
    %do i=1 %to &n_nodes;
      %if &i>1 %then %do; else %end;
      if nodes=&&node&i then node_lvl=&i;
    %end;
    /*Skip rows whose node is missing or unmatched instead of hitting an array subscript error*/
    if not missing(node_lvl) then node_(node_lvl)=group;
    if last.id then do;
      do i=1 to dim(node_)-1;
        if ^missing(node_(i)) then do;
          start=node_(i);
          starting_node=i;
          do j = i+1 to dim(node_);
            if ^missing(node_(j)) then do;
              end=node_(j);
              ending_node=j;
              output;
              leave;
            end;
          end;
        end;
      end;
    end;
    /*STARTING_NODE/ENDING_NODE are node indexes (1..n_nodes), not node values*/
    keep id start starting_node end ending_node;
  run;

  /**Get counts for each path**/
  proc sort data=_paths;
    by starting_node ending_node start end;
  run;
  proc freq data=_paths noprint;
    by starting_node ending_node;
    table start*end / list missing out=_frq2;
  run;

  /**Grab counts and values needed to create the connecting curves.
     All joins use the node INDEX (_levels.node_idx, _rectangles.nodes_count)
     because _paths stores indexes, not node values.**/
  proc sql noprint;
    create table _paths2 as
    select
      /**Numbers for starting groups**/
      a.starting_node,a.start,c.count as start_group_n,
      /*Location and height of the starting rectangle*/
      c2.y_min as start_group_min,c2.y_max as start_group_max,c2.y_max-c2.y_min as start_group_diff,c2.x_max+&curve_rectangle_gap as start_group_x,
      c.group_lvl as start_index, /*Used for arrays*/
      /**Numbers for ending groups**/
      a.ending_node,a.end,e.count as end_group_n,
      /*Location and height of the ending rectangle*/
      e2.y_min as end_group_min,e2.y_max as end_group_max,e2.y_max-e2.y_min as end_group_diff,e2.x_min-&curve_rectangle_gap as end_group_x,
      e.group_lvl as end_index, /*Used for arrays*/
      a.count as n_move /*Number of IDs in the current path*/
    from _frq2 a
      left join _levels c on a.starting_node=c.node_idx and a.start=c.group
      left join _levels e on a.ending_node=e.node_idx and a.end=e.group
      left join (select nodes_count,group,min(y) as y_min,max(y) as y_max,max(x) as x_max
                 from _rectangles group by nodes_count,group) as c2
        on a.starting_node=c2.nodes_count and a.start=c2.group
      left join (select nodes_count,group,min(y) as y_min,max(y) as y_max,min(x) as x_min
                 from _rectangles group by nodes_count,group) as e2
        on a.ending_node=e2.nodes_count and a.end=e2.group
    order by starting_node,ending_node,start,end;
    /**START_INDEX/END_INDEX are group levels, so the largest level bounds the array's second dimension**/
    select max(group_lvl) into :n_grps separated by ''
      from _levels;
    /**Grab x-axis location for node labels**/
    create table _node_labels as
    select nodes, (max(x)+min(x))/2 as x_label from _rectangles group by nodes;
  quit;

  data _paths3;
    set _paths2;
    by starting_node ending_node start end;
    /**Running totals of N already drawn out of / into each node-group (temporary arrays are retained)**/
    array start_n {&n_nodes,&n_grps} _temporary_ (%eval(&n_nodes*&n_grps)*0);
    array end_n {&n_nodes,&n_grps} _temporary_ (%eval(&n_nodes*&n_grps)*0);
    /***Build the Bezier curve connectors with the cubic Bezier equation:
        B(t)=(1-t)^3*P0 + 3(1-t)^2*t*P1 + 3(1-t)*t^2*P2 + t^3*P3, 0 <= t <= 1***/
    length path_index $40;
    /*Include the ending node so that paths skipping a node get their own polygon ID*/
    path_index=catx('-',starting_node,start_index,ending_node,end_index);
    /*Find y-axis values for start/end corners of Bezier curves*/
    start_y1=start_group_min+start_group_diff*(start_n(starting_node,start_index)/start_group_n);
    start_y2=start_y1+start_group_diff*(n_move/start_group_n);
    end_y1=end_group_min+end_group_diff*(end_n(ending_node,end_index)/end_group_n)+end_group_diff*(n_move/end_group_n);
    end_y2=end_y1-end_group_diff*(n_move/end_group_n);
    /**Bezier Curve 1: from left group to right group.
       An integer counter guarantees both endpoints t=0 and t=1 are drawn.**/
    do _pt = 0 to &points;
      t=_pt/&points;
      x=((1-t)**3)*start_group_x+
        3*((1-t)**2)*t*(start_group_x+(end_group_x-start_group_x)/3)+
        3*(1-t)*(t**2)*(start_group_x+2*(end_group_x-start_group_x)/3)+
        (t**3)*end_group_x;
      y=((1-t)**3)*start_y2+3*((1-t)**2)*t*start_y2+3*(1-t)*(t**2)*end_y1+(t**3)*end_y1;
      output;
    end;
    /**Bezier Curve 2: from right group back to left group**/
    do _pt = 0 to &points;
      t=_pt/&points;
      x=((1-t)**3)*end_group_x+
        3*((1-t)**2)*t*(start_group_x+2*(end_group_x-start_group_x)/3)+
        3*(1-t)*(t**2)*(start_group_x+(end_group_x-start_group_x)/3)+
        (t**3)*start_group_x;
      y=((1-t)**3)*end_y2+3*((1-t)**2)*t*end_y2+3*(1-t)*(t**2)*start_y1+(t**3)*start_y1;
      output;
    end;
    /**Increase running total for each group's N**/
    start_n(starting_node,start_index)=start_n(starting_node,start_index)+n_move;
    end_n(ending_node,end_index)=end_n(ending_node,end_index)+n_move;
    keep x y path_index start;
  run;

  /**Combine data for plot**/
  data _plot;
    set _rectangles (keep=x y rectangle_id group rename=(rectangle_id=id))
        _paths3 (keep=x y path_index start rename=(path_index=id start=group))
        _node_labels;
  run;

  ods graphics /reset width=&width height=&height ANTIALIASMAX=&antialias imagename="&plotname";
  proc sgplot data=_plot noborder;
    polygon x=x y=y id=id / fill nooutline group=group transparency=0.3 name='p';
    xaxistable nodes / x=x_label location=inside position=top nolabel title="&nodes_label" valueattrs=(size=&font_size) titleattrs=(size=&font_size) ;
    xaxis min=0 max=%sysevalf(100+&barwidth) values=(0 to 100 by 10) display=none valueshint;
    yaxis min=0 max=100 reverse values=(0 to 100 by 10) display=none;
    keylegend 'p'/ title="&group_label" noborder location=outside position=bottom exclude=('') valueattrs=(size=&font_size) titleattrs=(size=&font_size);
  run;
%exit:
%mend sankey;
/*Example call: CYCLE values form the nodes and GRADE values form the groups within each node*/
%sankey(data=random,id=id,nodes=cycle,group=grade);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment