Skip to content

Instantly share code, notes, and snippets.

@fawda123
Last active May 11, 2022 00:20
Show Gist options
  • Save fawda123/7471137 to your computer and use it in GitHub Desktop.
Save fawda123/7471137 to your computer and use it in GitHub Desktop.
nnet_plot_update
# Plot a neural network as a node-and-line diagram, or return its weights.
#
# Arguments:
#   mod.in     model to plot: a numeric weight vector (requires 'struct'), or
#              an object of class 'nn' (neuralnet), 'nnet' (nnet), 'mlp'
#              (RSNNS), or 'train' whose finalModel is an 'nnet' object
#   nid        if TRUE, lines are coloured by weight sign (pos.col/neg.col)
#              and their width follows the rescaled weight magnitude;
#              if FALSE, plain default lines are drawn
#   all.out    TRUE to draw connections to every output node, or the name of
#              a single output variable
#   all.in     TRUE to draw connections from every input node, or the name of
#              a single input variable
#   bias       draw bias nodes and their connections (forced to FALSE for
#              'mlp' objects)
#   wts.only   if TRUE, return the named list of weights and do not plot
#   rel.rsc    upper bound of the line-width range; absolute weights are
#              rescaled to c(1, rel.rsc)
#   circle.cex size (cex) of the node symbols
#   node.labs  label nodes (I1, H1, O1, B1, ...)
#   var.labs   label input and output nodes with variable names
#   x.lab      optional replacement names for the input variables
#   y.lab      optional replacement names for the output variables
#   line.stag  horizontal gap between node centre and line end, as a
#              proportion of the x range; defaults to 0.011 * circle.cex / 2
#   struct     three-element vector (inputs, hidden, outputs); required when
#              mod.in is a numeric weight vector
#   cex.val    text size for labels
#   alpha.val  transparency applied to pos.col and neg.col
#   circle.col node fill colour, or a two-element list whose first element
#              is used for the input layer and second for all other nodes
#   pos.col    line colour for weights that are not negative
#   neg.col    line colour for negative weights
#   bord.col   node border colour
#   max.sp     if TRUE, node spacing is computed per layer instead of from
#              the largest layer
#   ...        passed on to plot()
#
# Returns the list of weights when wts.only is TRUE, otherwise invisible NULL
# after drawing on the active graphics device.
plot.nnet <- function(mod.in, nid = TRUE, all.out = TRUE, all.in = TRUE,
                      bias = TRUE, wts.only = FALSE, rel.rsc = 5,
                      circle.cex = 5, node.labs = TRUE, var.labs = TRUE,
                      x.lab = NULL, y.lab = NULL, line.stag = NULL,
                      struct = NULL, cex.val = 1, alpha.val = 1,
                      circle.col = 'lightblue', pos.col = 'black',
                      neg.col = 'grey', bord.col = 'lightblue',
                      max.sp = FALSE, ...) {

  # stop with a clear message when a needed package is not installed
  need.pkg <- function(pkg) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      stop('Package "', pkg, '" is required but is not installed',
           call. = FALSE)
    }
  }

  # sanity checks
  if (inherits(mod.in, 'mlp')) {
    warning('Bias layer not applicable for rsnns object')
    bias <- FALSE
  }
  # is.numeric() also accepts integer weight vectors such as 1:9
  if (is.numeric(mod.in)) {
    if (is.null(struct) || length(struct) != 3) {
      stop('Three-element vector required for struct')
    }
    n.wts <- (struct[1] * struct[2] + struct[2] * struct[3]) +
      (struct[3] + struct[2])
    if (length(mod.in) != n.wts) {
      stop('Incorrect length of weight matrix for given network structure')
    }
  }
  if (inherits(mod.in, 'train')) {
    if (!inherits(mod.in$finalModel, 'nnet')) {
      stop('Only nnet method can be used with train object')
    }
    mod.in <- mod.in$finalModel
    warning('Using best nnet model from train output')
  }

  # gets weights for neural network, output is a list with one element per
  # hidden/output node (first value of each element is the bias weight)
  # if nid is TRUE, absolute weights rescaled to c(1, rel.rsc) are returned
  # side effect: stores the network structure in 'struct' of plot.nnet
  nnet.vals <- function(mod.in, nid, rel.rsc, struct.out = struct) {
    wts <- NULL
    if (is.numeric(mod.in)) {
      struct.out <- struct
      wts <- mod.in
    }
    # neuralnet package
    if (inherits(mod.in, 'nn')) {
      struct.out <- unlist(lapply(mod.in$weights[[1]], ncol))
      struct.out <- struct.out[-length(struct.out)]
      struct.out <- c(
        length(mod.in$model.list$variables),
        struct.out,
        length(mod.in$model.list$response)
      )
      wts <- unlist(mod.in$weights[[1]])
    }
    # nnet package
    if (inherits(mod.in, 'nnet')) {
      struct.out <- mod.in$n
      wts <- mod.in$wts
    }
    # RSNNS package
    if (inherits(mod.in, 'mlp')) {
      need.pkg('reshape')
      struct.out <- c(mod.in$nInputs, mod.in$archParams$size, mod.in$nOutputs)
      hid.num <- length(struct.out) - 2
      wts <- mod.in$snnsObject$getCompleteWeightMatrix()
      # prepend a row of NA (placeholder where other models hold the bias
      # weight) and flatten the block with melt
      flat.blk <- function(blk) {
        reshape::melt(rbind(rep(NA, ncol(blk)), blk))$value
      }
      # get all input-hidden and hidden-hidden wts
      inps <- wts[grep('Input', rownames(wts)),
                  grep('Hidden_2', colnames(wts)), drop = FALSE]
      inps <- flat.blk(inps)
      uni.hids <- paste0('Hidden_', 1 + seq(1, hid.num))
      for (i in seq_along(uni.hids)) {
        if (is.na(uni.hids[i + 1])) break
        tmp <- wts[grep(uni.hids[i], rownames(wts)),
                   grep(uni.hids[i + 1], colnames(wts)), drop = FALSE]
        inps <- c(inps, flat.blk(tmp))
      }
      # get connections from last hidden to output layers
      outs <- wts[grep(paste0('Hidden_', hid.num + 1), rownames(wts)),
                  grep('Output', colnames(wts)), drop = FALSE]
      # weight vector for all
      wts <- c(inps, flat.blk(outs))
    }
    if (is.null(wts)) {
      stop('Unsupported model class: ', paste(class(mod.in), collapse = ', '),
           call. = FALSE)
    }

    if (nid) wts <- scales::rescale(abs(wts), c(1, rel.rsc))

    # convert wts to list with appropriate names
    hid.struct <- struct.out[-length(struct.out)]
    row.nms <- NULL
    for (i in seq_along(hid.struct)) {
      if (is.na(hid.struct[i + 1])) break
      row.nms <- c(
        row.nms,
        rep(paste('hidden', i, seq_len(hid.struct[i + 1])),
            each = 1 + hid.struct[i])
      )
    }
    row.nms <- c(
      row.nms,
      rep(paste('out', seq_len(struct.out[length(struct.out)])),
          each = 1 + struct.out[length(struct.out) - 1])
    )
    out.ls <- data.frame(wts, row.nms)
    out.ls$row.nms <- factor(row.nms, levels = unique(row.nms),
                             labels = unique(row.nms))
    out.ls <- split(out.ls$wts, f = out.ls$row.nms)
    assign('struct', struct.out, envir = environment(nnet.vals))
    out.ls
  }

  # raw weights; this call also fills in 'struct' for non-numeric models
  wts.raw <- nnet.vals(mod.in, nid = FALSE)
  if (wts.only) return(wts.raw)

  # everything below draws, which needs scales for rescale() and alpha()
  need.pkg('scales')
  # rescaled weights are computed once and reused by all line helpers
  wts.scl <- nnet.vals(mod.in, nid = TRUE, rel.rsc = rel.rsc)
  pos.a <- scales::alpha(pos.col, alpha.val)
  neg.a <- scales::alpha(neg.col, alpha.val)

  # circle colors for input, if desired, must be two-vector list,
  # first vector is for input layer
  if (is.list(circle.col)) {
    circle.col.inp <- circle.col[[1]]
    circle.col <- circle.col[[2]]
  } else {
    circle.col.inp <- circle.col
  }

  # initiate plotting
  x.range <- c(0, 100)
  y.range <- c(0, 100)
  # these are all proportions from 0-1
  if (is.null(line.stag)) line.stag <- 0.011 * circle.cex / 2
  layer.x <- seq(0.17, 0.9, length.out = length(struct))
  bias.x <- layer.x[-length(layer.x)] + diff(layer.x) / 2
  bias.y <- 0.95

  # get variable names from mod.in object
  if (is.numeric(mod.in)) {
    x.names <- paste0('X', seq_len(struct[1]))
    y.names <- paste0('Y', seq_len(struct[3]))
  }
  if (inherits(mod.in, 'mlp')) {
    all.names <- mod.in$snnsObject$getUnitDefinitions()
    x.names <- all.names[grep('Input', all.names$unitName), 'unitName']
    y.names <- all.names[grep('Output', all.names$unitName), 'unitName']
  }
  if (inherits(mod.in, 'nn')) {
    x.names <- mod.in$model.list$variables
    y.names <- mod.in$model.list$response
  }
  if ('xNames' %in% names(mod.in)) {
    x.names <- mod.in$xNames
    y.names <- attr(terms(mod.in), 'factors')
    y.names <- rownames(y.names)[!rownames(y.names) %in% x.names]
  }
  if (!'xNames' %in% names(mod.in) && inherits(mod.in, 'nnet')) {
    if (is.null(mod.in$call$formula)) {
      x.names <- colnames(eval(mod.in$call$x))
      y.names <- colnames(eval(mod.in$call$y))
    } else {
      forms <- eval(mod.in$call$formula)
      x.names <- mod.in$coefnames
      y.check <- mod.in$fitted.values
      if (ncol(y.check) > 1) y.names <- colnames(y.check)
      else y.names <- as.character(forms)[2]
    }
  }

  # change variables names to user sub
  if (!is.null(x.lab)) {
    if (length(x.names) != length(x.lab)) {
      stop('x.lab length not equal to number of input variables')
    }
    x.names <- x.lab
  }
  if (!is.null(y.lab)) {
    if (length(y.names) != length(y.lab)) {
      stop('y.lab length not equal to number of output variables')
    }
    y.names <- y.lab
  }

  # initiate plot
  plot(x.range, y.range, type = 'n', axes = FALSE, ylab = '', xlab = '', ...)

  # function for getting y locations for input, hidden, output layers
  # input is integer value from 'struct'
  get.ys <- function(lyr, max_space = max.sp) {
    if (max_space) {
      spacing <- diff(c(0 * diff(y.range), 0.9 * diff(y.range))) / lyr
    } else {
      spacing <- diff(c(0 * diff(y.range), 0.9 * diff(y.range))) / max(struct)
    }
    seq(0.5 * (diff(y.range) + spacing * (lyr - 1)),
        0.5 * (diff(y.range) - spacing * (lyr - 1)),
        length.out = lyr)
  }

  # function for plotting nodes
  # 'layer' is the number of nodes in the layer, a value from 'struct'
  # 'x.loc' is the x location for the layer, a value from 'layer.x'
  # 'layer.name' is the prefix of the text put in each node
  # 'in.col' is the fill colour of the nodes
  layer.points <- function(layer, x.loc, layer.name, in.col) {
    x <- rep(x.loc * diff(x.range), layer)
    y <- get.ys(layer)
    points(x, y, pch = 21, cex = circle.cex, col = bord.col, bg = in.col)
    if (node.labs) {
      text(x, y, paste0(layer.name, seq_len(layer)), cex = cex.val)
    }
    if (layer.name == 'I' && var.labs) {
      text(x - line.stag * diff(x.range), y, x.names, pos = 2, cex = cex.val)
    }
    if (layer.name == 'O' && var.labs) {
      text(x + line.stag * diff(x.range), y, y.names, pos = 4, cex = cex.val)
    }
  }

  # function for plotting bias points
  # 'bias.x' is vector of values for x locations
  # 'bias.y' is the single y location
  # 'layer.name' is the prefix of the text put in each node
  # 'in.col' is the fill colour of the nodes
  bias.points <- function(bias.x, bias.y, layer.name, in.col) {
    x <- diff(x.range) * bias.x
    y <- rep(bias.y * diff(y.range), length(bias.x))
    points(x, y, pch = 21, col = bord.col, bg = in.col, cex = circle.cex)
    if (node.labs) {
      text(x, y, paste0(layer.name, seq_along(bias.x)), cex = cex.val)
    }
  }

  # function creates lines colored by direction and width as proportion of
  # magnitude, between node 'h.layer' and the facing layer
  # out.layer = TRUE: all nodes of layer1 -> node 'h.layer' of layer2
  # out.layer = FALSE: node 'h.layer' of layer1 -> all nodes of layer2
  # use 'all.in' to plot connection lines for only a single input node
  layer.lines <- function(h.layer, layer1 = 1, layer2 = 2, out.layer = FALSE,
                          all.in = TRUE, pos.col, neg.col) {
    # scalar x positions; segments() recycles them along the y vectors
    x0 <- layer.x[layer1] * diff(x.range) + line.stag * diff(x.range)
    x1 <- layer.x[layer2] * diff(x.range) - line.stag * diff(x.range)
    if (out.layer) {
      y0 <- get.ys(struct[layer1])
      y1 <- get.ys(struct[layer2])[h.layer]
      node.nm <- paste('out', h.layer)
      # drop the first weight, which belongs to the bias node
      wts <- wts.raw[[node.nm]][-1]
      wts.rs <- wts.scl[[node.nm]][-1]
    } else {
      if (!is.logical(all.in)) h.layer <- which(x.names == all.in)
      y0 <- get.ys(struct[layer1])[h.layer]
      y1 <- get.ys(struct[layer2])
      # exact prefix match so that layer 1 is not confused with layer 10+
      sel <- startsWith(names(wts.raw), paste0('hidden ', layer1, ' '))
      # position 1 of each element is the bias weight, hence the + 1
      wts <- unlist(lapply(wts.raw[sel], function(x) x[h.layer + 1]))
      wts.rs <- unlist(lapply(wts.scl[sel], function(x) x[h.layer + 1]))
    }
    cols <- rep(pos.col, length(wts))
    cols[wts < 0] <- neg.col
    if (nid) segments(x0, y0, x1, y1, col = cols, lwd = wts.rs)
    else segments(x0, y0, x1, y1)
  }

  # function for lines from each bias node to the layer on its right
  # 'all.out' restricts the lines of the last bias node to one output node
  bias.lines <- function(bias.x, all.out, pos.col, neg.col) {
    if (is.logical(all.out)) all.out <- seq_len(struct[length(struct)])
    else all.out <- which(y.names == all.out)
    is.out <- grepl('out', names(wts.raw))
    for (val in seq_along(bias.x)) {
      is.last <- val == length(bias.x)
      if (is.last) {
        sel <- is.out
      } else {
        sel <- !is.out &
          startsWith(names(wts.raw), paste0('hidden ', val, ' '))
      }
      n.to <- struct[val + 1]
      # first weight of each target node is its bias weight
      b.wts <- unlist(lapply(wts.raw[sel], function(x) x[1]))
      cols <- rep(pos.col, length(b.wts))
      cols[b.wts < 0] <- neg.col
      lwds <- unlist(lapply(wts.scl[sel], function(x) x[1]))
      if (!nid) {
        lwds <- rep(1, n.to)
        cols <- rep('black', n.to)
      }
      keep <- if (is.last) all.out else seq_len(n.to)
      # scalar start/end x and start y are recycled, so each line is drawn
      # exactly once even when a single output node is selected
      segments(
        diff(x.range) * bias.x[val] + diff(x.range) * line.stag,
        bias.y * diff(y.range),
        diff(x.range) * layer.x[val + 1] - diff(x.range) * line.stag,
        get.ys(n.to)[keep],
        lwd = lwds[keep],
        col = cols[keep]
      )
    }
  }

  n.lay <- length(struct)

  # bias lines
  if (bias) bias.lines(bias.x, all.out = all.out, pos.col = pos.a,
                       neg.col = neg.a)

  # input-hidden lines, for all input nodes or a single named node
  if (is.logical(all.in)) {
    for (node in seq_len(struct[1])) {
      layer.lines(node, layer1 = 1, layer2 = 2, all.in = all.in,
                  pos.col = pos.a, neg.col = neg.a)
    }
  } else {
    node.in <- which(x.names == all.in)
    layer.lines(node.in, layer1 = 1, layer2 = 2, all.in = all.in,
                pos.col = pos.a, neg.col = neg.a)
  }

  # connections between hidden layers (layers 2 .. n.lay - 1)
  for (lay in seq_len(max(n.lay - 3, 0)) + 1) {
    for (node in seq_len(struct[lay])) {
      layer.lines(node, layer1 = lay, layer2 = lay + 1, all.in = TRUE,
                  pos.col = pos.a, neg.col = neg.a)
    }
  }

  # hidden-output lines, for all output nodes or a single named node
  if (is.logical(all.out)) {
    for (node in seq_len(struct[n.lay])) {
      layer.lines(node, layer1 = n.lay - 1, layer2 = n.lay, out.layer = TRUE,
                  all.in = all.in, pos.col = pos.a, neg.col = neg.a)
    }
  } else {
    node.out <- which(y.names == all.out)
    # alpha.val is applied here too, as in every other branch
    layer.lines(node.out, layer1 = n.lay - 1, layer2 = n.lay,
                out.layer = TRUE, pos.col = pos.a, neg.col = neg.a)
  }

  # nodes are drawn last so that they sit on top of the lines
  for (i in seq_along(struct)) {
    in.col <- circle.col
    layer.name <- 'H'
    if (i == 1) {
      layer.name <- 'I'
      in.col <- circle.col.inp
    }
    if (i == n.lay) layer.name <- 'O'
    layer.points(struct[i], layer.x[i], layer.name, in.col)
  }
  if (bias) bias.points(bias.x, bias.y, 'B', circle.col)

  invisible(NULL)
}
@leedrake5
Copy link

leedrake5 commented Jan 28, 2019

I love this function - but why are train objects restricted to the nnet method? Would it be difficult to adapt it to handle neuralnet models fitted via caret::train?

#edit

I got it to work! Very simple fix. terms() does not work on the finalModel, but instead on the whole model. Switching it to that made other train models possible.

I don't have time to do it now (almost midnight), but if you use a placeholder (say mod.o <- mod.in) at the beginning for train objects, you can refer back to that for terms.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment