-
-
Save fawda123/7471137 to your computer and use it in GitHub Desktop.
plot.nnet<-function(mod.in,nid=T,all.out=T,all.in=T,bias=T,wts.only=F,rel.rsc=5, | |
circle.cex=5,node.labs=T,var.labs=T,x.lab=NULL,y.lab=NULL, | |
line.stag=NULL,struct=NULL,cex.val=1,alpha.val=1, | |
circle.col='lightblue',pos.col='black',neg.col='grey', | |
bord.col='lightblue', max.sp = F,...){ | |
require(scales) | |
#sanity checks | |
if('mlp' %in% class(mod.in)) warning('Bias layer not applicable for rsnns object') | |
if('numeric' %in% class(mod.in)){ | |
if(is.null(struct)) stop('Three-element vector required for struct') | |
if(length(mod.in) != ((struct[1]*struct[2]+struct[2]*struct[3])+(struct[3]+struct[2]))) | |
stop('Incorrect length of weight matrix for given network structure') | |
} | |
if('train' %in% class(mod.in)){ | |
if('nnet' %in% class(mod.in$finalModel)){ | |
mod.in<-mod.in$finalModel | |
warning('Using best nnet model from train output') | |
} | |
else stop('Only nnet method can be used with train object') | |
} | |
#gets weights for neural network, output is list | |
#if rescaled argument is true, weights are returned but rescaled based on abs value | |
nnet.vals<-function(mod.in,nid,rel.rsc,struct.out=struct){ | |
require(scales) | |
require(reshape) | |
if('numeric' %in% class(mod.in)){ | |
struct.out<-struct | |
wts<-mod.in | |
} | |
#neuralnet package | |
if('nn' %in% class(mod.in)){ | |
struct.out<-unlist(lapply(mod.in$weights[[1]],ncol)) | |
struct.out<-struct.out[-length(struct.out)] | |
struct.out<-c( | |
length(mod.in$model.list$variables), | |
struct.out, | |
length(mod.in$model.list$response) | |
) | |
wts<-unlist(mod.in$weights[[1]]) | |
} | |
#nnet package | |
if('nnet' %in% class(mod.in)){ | |
struct.out<-mod.in$n | |
wts<-mod.in$wts | |
} | |
#RSNNS package | |
if('mlp' %in% class(mod.in)){ | |
struct.out<-c(mod.in$nInputs,mod.in$archParams$size,mod.in$nOutputs) | |
hid.num<-length(struct.out)-2 | |
wts<-mod.in$snnsObject$getCompleteWeightMatrix() | |
#get all input-hidden and hidden-hidden wts | |
inps<-wts[grep('Input',row.names(wts)),grep('Hidden_2',colnames(wts)),drop=F] | |
inps<-melt(rbind(rep(NA,ncol(inps)),inps))$value | |
uni.hids<-paste0('Hidden_',1+seq(1,hid.num)) | |
for(i in 1:length(uni.hids)){ | |
if(is.na(uni.hids[i+1])) break | |
tmp<-wts[grep(uni.hids[i],rownames(wts)),grep(uni.hids[i+1],colnames(wts)),drop=F] | |
inps<-c(inps,melt(rbind(rep(NA,ncol(tmp)),tmp))$value) | |
} | |
#get connections from last hidden to output layers | |
outs<-wts[grep(paste0('Hidden_',hid.num+1),row.names(wts)),grep('Output',colnames(wts)),drop=F] | |
outs<-rbind(rep(NA,ncol(outs)),outs) | |
#weight vector for all | |
wts<-c(inps,melt(outs)$value) | |
assign('bias',F,envir=environment(nnet.vals)) | |
} | |
if(nid) wts<-rescale(abs(wts),c(1,rel.rsc)) | |
#convert wts to list with appropriate names | |
hid.struct<-struct.out[-c(length(struct.out))] | |
row.nms<-NULL | |
for(i in 1:length(hid.struct)){ | |
if(is.na(hid.struct[i+1])) break | |
row.nms<-c(row.nms,rep(paste('hidden',i,seq(1:hid.struct[i+1])),each=1+hid.struct[i])) | |
} | |
row.nms<-c( | |
row.nms, | |
rep(paste('out',seq(1:struct.out[length(struct.out)])),each=1+struct.out[length(struct.out)-1]) | |
) | |
out.ls<-data.frame(wts,row.nms) | |
out.ls$row.nms<-factor(row.nms,levels=unique(row.nms),labels=unique(row.nms)) | |
out.ls<-split(out.ls$wts,f=out.ls$row.nms) | |
assign('struct',struct.out,envir=environment(nnet.vals)) | |
out.ls | |
} | |
wts<-nnet.vals(mod.in,nid=F) | |
if(wts.only) return(wts) | |
#circle colors for input, if desired, must be two-vector list, first vector is for input layer | |
if(is.list(circle.col)){ | |
circle.col.inp<-circle.col[[1]] | |
circle.col<-circle.col[[2]] | |
} | |
else circle.col.inp<-circle.col | |
#initiate plotting | |
x.range<-c(0,100) | |
y.range<-c(0,100) | |
#these are all proportions from 0-1 | |
if(is.null(line.stag)) line.stag<-0.011*circle.cex/2 | |
layer.x<-seq(0.17,0.9,length=length(struct)) | |
bias.x<-layer.x[-length(layer.x)]+diff(layer.x)/2 | |
bias.y<-0.95 | |
circle.cex<-circle.cex | |
#get variable names from mod.in object | |
#change to user input if supplied | |
if('numeric' %in% class(mod.in)){ | |
x.names<-paste0(rep('X',struct[1]),seq(1:struct[1])) | |
y.names<-paste0(rep('Y',struct[3]),seq(1:struct[3])) | |
} | |
if('mlp' %in% class(mod.in)){ | |
all.names<-mod.in$snnsObject$getUnitDefinitions() | |
x.names<-all.names[grep('Input',all.names$unitName),'unitName'] | |
y.names<-all.names[grep('Output',all.names$unitName),'unitName'] | |
} | |
if('nn' %in% class(mod.in)){ | |
x.names<-mod.in$model.list$variables | |
y.names<-mod.in$model.list$respons | |
} | |
if('xNames' %in% names(mod.in)){ | |
x.names<-mod.in$xNames | |
y.names<-attr(terms(mod.in),'factor') | |
y.names<-row.names(y.names)[!row.names(y.names) %in% x.names] | |
} | |
if(!'xNames' %in% names(mod.in) & 'nnet' %in% class(mod.in)){ | |
if(is.null(mod.in$call$formula)){ | |
x.names<-colnames(eval(mod.in$call$x)) | |
y.names<-colnames(eval(mod.in$call$y)) | |
} | |
else{ | |
forms<-eval(mod.in$call$formula) | |
x.names<-mod.in$coefnames | |
facts<-attr(terms(mod.in),'factors') | |
y.check<-mod.in$fitted | |
if(ncol(y.check)>1) y.names<-colnames(y.check) | |
else y.names<-as.character(forms)[2] | |
} | |
} | |
#change variables names to user sub | |
if(!is.null(x.lab)){ | |
if(length(x.names) != length(x.lab)) stop('x.lab length not equal to number of input variables') | |
else x.names<-x.lab | |
} | |
if(!is.null(y.lab)){ | |
if(length(y.names) != length(y.lab)) stop('y.lab length not equal to number of output variables') | |
else y.names<-y.lab | |
} | |
#initiate plot | |
plot(x.range,y.range,type='n',axes=F,ylab='',xlab='',...) | |
#function for getting y locations for input, hidden, output layers | |
#input is integer value from 'struct' | |
get.ys<-function(lyr, max_space = max.sp){ | |
if(max_space){ | |
spacing <- diff(c(0*diff(y.range),0.9*diff(y.range)))/lyr | |
} else { | |
spacing<-diff(c(0*diff(y.range),0.9*diff(y.range)))/max(struct) | |
} | |
seq(0.5*(diff(y.range)+spacing*(lyr-1)),0.5*(diff(y.range)-spacing*(lyr-1)), | |
length=lyr) | |
} | |
#function for plotting nodes | |
#'layer' specifies which layer, integer from 'struct' | |
#'x.loc' indicates x location for layer, integer from 'layer.x' | |
#'layer.name' is string indicating text to put in node | |
layer.points<-function(layer,x.loc,layer.name,cex=cex.val){ | |
x<-rep(x.loc*diff(x.range),layer) | |
y<-get.ys(layer) | |
points(x,y,pch=21,cex=circle.cex,col=bord.col,bg=in.col) | |
if(node.labs) text(x,y,paste(layer.name,1:layer,sep=''),cex=cex.val) | |
if(layer.name=='I' & var.labs) text(x-line.stag*diff(x.range),y,x.names,pos=2,cex=cex.val) | |
if(layer.name=='O' & var.labs) text(x+line.stag*diff(x.range),y,y.names,pos=4,cex=cex.val) | |
} | |
#function for plotting bias points | |
#'bias.x' is vector of values for x locations | |
#'bias.y' is vector for y location | |
#'layer.name' is string indicating text to put in node | |
bias.points<-function(bias.x,bias.y,layer.name,cex,...){ | |
for(val in 1:length(bias.x)){ | |
points( | |
diff(x.range)*bias.x[val], | |
bias.y*diff(y.range), | |
pch=21,col=bord.col,bg=in.col,cex=circle.cex | |
) | |
if(node.labs) | |
text( | |
diff(x.range)*bias.x[val], | |
bias.y*diff(y.range), | |
paste(layer.name,val,sep=''), | |
cex=cex.val | |
) | |
} | |
} | |
#function creates lines colored by direction and width as proportion of magnitude | |
#use 'all.in' argument if you want to plot connection lines for only a single input node | |
layer.lines<-function(mod.in,h.layer,layer1=1,layer2=2,out.layer=F,nid,rel.rsc,all.in,pos.col, | |
neg.col,...){ | |
x0<-rep(layer.x[layer1]*diff(x.range)+line.stag*diff(x.range),struct[layer1]) | |
x1<-rep(layer.x[layer2]*diff(x.range)-line.stag*diff(x.range),struct[layer1]) | |
if(out.layer==T){ | |
y0<-get.ys(struct[layer1]) | |
y1<-rep(get.ys(struct[layer2])[h.layer],struct[layer1]) | |
src.str<-paste('out',h.layer) | |
wts<-nnet.vals(mod.in,nid=F,rel.rsc) | |
wts<-wts[grep(src.str,names(wts))][[1]][-1] | |
wts.rs<-nnet.vals(mod.in,nid=T,rel.rsc) | |
wts.rs<-wts.rs[grep(src.str,names(wts.rs))][[1]][-1] | |
cols<-rep(pos.col,struct[layer1]) | |
cols[wts<0]<-neg.col | |
if(nid) segments(x0,y0,x1,y1,col=cols,lwd=wts.rs) | |
else segments(x0,y0,x1,y1) | |
} | |
else{ | |
if(is.logical(all.in)) all.in<-h.layer | |
else all.in<-which(x.names==all.in) | |
y0<-rep(get.ys(struct[layer1])[all.in],struct[2]) | |
y1<-get.ys(struct[layer2]) | |
src.str<-paste('hidden',layer1) | |
wts<-nnet.vals(mod.in,nid=F,rel.rsc) | |
wts<-unlist(lapply(wts[grep(src.str,names(wts))],function(x) x[all.in+1])) | |
wts.rs<-nnet.vals(mod.in,nid=T,rel.rsc) | |
wts.rs<-unlist(lapply(wts.rs[grep(src.str,names(wts.rs))],function(x) x[all.in+1])) | |
cols<-rep(pos.col,struct[layer2]) | |
cols[wts<0]<-neg.col | |
if(nid) segments(x0,y0,x1,y1,col=cols,lwd=wts.rs) | |
else segments(x0,y0,x1,y1) | |
} | |
} | |
bias.lines<-function(bias.x,mod.in,nid,rel.rsc,all.out,pos.col,neg.col,...){ | |
if(is.logical(all.out)) all.out<-1:struct[length(struct)] | |
else all.out<-which(y.names==all.out) | |
for(val in 1:length(bias.x)){ | |
wts<-nnet.vals(mod.in,nid=F,rel.rsc) | |
wts.rs<-nnet.vals(mod.in,nid=T,rel.rsc) | |
if(val != length(bias.x)){ | |
wts<-wts[grep('out',names(wts),invert=T)] | |
wts.rs<-wts.rs[grep('out',names(wts.rs),invert=T)] | |
sel.val<-grep(val,substr(names(wts.rs),8,8)) | |
wts<-wts[sel.val] | |
wts.rs<-wts.rs[sel.val] | |
} | |
else{ | |
wts<-wts[grep('out',names(wts))] | |
wts.rs<-wts.rs[grep('out',names(wts.rs))] | |
} | |
cols<-rep(pos.col,length(wts)) | |
cols[unlist(lapply(wts,function(x) x[1]))<0]<-neg.col | |
wts.rs<-unlist(lapply(wts.rs,function(x) x[1])) | |
if(nid==F){ | |
wts.rs<-rep(1,struct[val+1]) | |
cols<-rep('black',struct[val+1]) | |
} | |
if(val != length(bias.x)){ | |
segments( | |
rep(diff(x.range)*bias.x[val]+diff(x.range)*line.stag,struct[val+1]), | |
rep(bias.y*diff(y.range),struct[val+1]), | |
rep(diff(x.range)*layer.x[val+1]-diff(x.range)*line.stag,struct[val+1]), | |
get.ys(struct[val+1]), | |
lwd=wts.rs, | |
col=cols | |
) | |
} | |
else{ | |
segments( | |
rep(diff(x.range)*bias.x[val]+diff(x.range)*line.stag,struct[val+1]), | |
rep(bias.y*diff(y.range),struct[val+1]), | |
rep(diff(x.range)*layer.x[val+1]-diff(x.range)*line.stag,struct[val+1]), | |
get.ys(struct[val+1])[all.out], | |
lwd=wts.rs[all.out], | |
col=cols[all.out] | |
) | |
} | |
} | |
} | |
#use functions to plot connections between layers | |
#bias lines | |
if(bias) bias.lines(bias.x,mod.in,nid=nid,rel.rsc=rel.rsc,all.out=all.out,pos.col=alpha(pos.col,alpha.val), | |
neg.col=alpha(neg.col,alpha.val)) | |
#layer lines, makes use of arguments to plot all or for individual layers | |
#starts with input-hidden | |
#uses 'all.in' argument to plot connection lines for all input nodes or a single node | |
if(is.logical(all.in)){ | |
mapply( | |
function(x) layer.lines(mod.in,x,layer1=1,layer2=2,nid=nid,rel.rsc=rel.rsc, | |
all.in=all.in,pos.col=alpha(pos.col,alpha.val),neg.col=alpha(neg.col,alpha.val)), | |
1:struct[1] | |
) | |
} | |
else{ | |
node.in<-which(x.names==all.in) | |
layer.lines(mod.in,node.in,layer1=1,layer2=2,nid=nid,rel.rsc=rel.rsc,all.in=all.in, | |
pos.col=alpha(pos.col,alpha.val),neg.col=alpha(neg.col,alpha.val)) | |
} | |
#connections between hidden layers | |
lays<-split(c(1,rep(2:(length(struct)-1),each=2),length(struct)), | |
f=rep(1:(length(struct)-1),each=2)) | |
lays<-lays[-c(1,(length(struct)-1))] | |
for(lay in lays){ | |
for(node in 1:struct[lay[1]]){ | |
layer.lines(mod.in,node,layer1=lay[1],layer2=lay[2],nid=nid,rel.rsc=rel.rsc,all.in=T, | |
pos.col=alpha(pos.col,alpha.val),neg.col=alpha(neg.col,alpha.val)) | |
} | |
} | |
#lines for hidden-output | |
#uses 'all.out' argument to plot connection lines for all output nodes or a single node | |
if(is.logical(all.out)) | |
mapply( | |
function(x) layer.lines(mod.in,x,layer1=length(struct)-1,layer2=length(struct),out.layer=T,nid=nid,rel.rsc=rel.rsc, | |
all.in=all.in,pos.col=alpha(pos.col,alpha.val),neg.col=alpha(neg.col,alpha.val)), | |
1:struct[length(struct)] | |
) | |
else{ | |
node.in<-which(y.names==all.out) | |
layer.lines(mod.in,node.in,layer1=length(struct)-1,layer2=length(struct),out.layer=T,nid=nid,rel.rsc=rel.rsc, | |
pos.col=pos.col,neg.col=neg.col,all.out=all.out) | |
} | |
#use functions to plot nodes | |
for(i in 1:length(struct)){ | |
in.col<-circle.col | |
layer.name<-'H' | |
if(i==1) { layer.name<-'I'; in.col<-circle.col.inp} | |
if(i==length(struct)) layer.name<-'O' | |
layer.points(struct[i],layer.x[i],layer.name) | |
} | |
if(bias) bias.points(bias.x,bias.y,'B') | |
} |
No license, use at will! The idea was inspired by Ozesmi and Ozesmi 1999, cited in my original blog:
http://beckmw.wordpress.com/2013/03/04/visualizing-neural-networks-from-the-nnet-package/
Beautiful, thanks for your work.
I modified the function to allow custom border colors for the neurons (which is important if, for example, you want your neurons to have white background and just a black border for printing). See changes in my fork: https://gist.github.com/Peque/41a9e20d6687f2f3108d/revisions
Peque, thanks for the suggestion. I completely forgot about bw printing. I merged the changes you made.
Hi, I have a nnet model trained using carat with 10 layers and model averaging. The function returns the error stop("'to' must be length 1). Does this function work when usign model averaging?
Glenn, try it with the NeuralNetTools CRAN release.
Thanks for this. This is great!
I love this function - but why restrict train to nnet for train functions? Would it be difficult to adapt it to handle neuralnet via caret::train?
#edit
I got it to work! Very simple fix. terms() does not work on the finalModel, but instead on the whole model. Switching it to that made other train models possible.
I don't have time to do it now (almost midnight), but if you use a placeholder (say mod.o <- mod.in) at the beginning for train objects, you can refer back to that for terms.
Hello, a nice function. Could you tell me what is the license for this code?