Skip to content

Instantly share code, notes, and snippets.

@jayyonamine
Created March 14, 2015 17:00
Show Gist options
  • Save jayyonamine/64c70860dfe4400eec48 to your computer and use it in GitHub Desktop.
Save jayyonamine/64c70860dfe4400eec48 to your computer and use it in GitHub Desktop.
plot_partial.r
# Plotting and model-fitting dependencies used throughout this script.
library("ggplot2")
library("randomForest")
# Seed the RNG so the random draws that follow are reproducible.
set.seed(2014)
# Score `data` with a fitted forest and return a plain prediction vector.
# Non-classification models: whatever predict() returns, untouched.
# Classification models: the second column of the class-probability matrix,
# stripped of names/dimensions via as.vector().
rf_predict <- function(rf_object, data) {
  if (rf_object$type != "classification") {
    return(predict(rf_object, data))
  }
  class_probs <- predict(rf_object, data, type = "prob")
  as.vector(class_probs[, 2])
}
# Partial-dependence plot for one predictor of a fitted forest.
#
# For each candidate value of `iv`, every row of `data` gets `iv` overwritten
# by that value and is re-scored with rf_predict(). The resulting spread of
# predictions is then plotted.
#
# Arguments:
#   rf           model accepted by rf_predict() (presumably a randomForest fit).
#   data         data.frame of predictors used for scoring.
#   dv, iv       response / predictor names: a string, a variable holding a
#                string, or a bare column name. `dv` is only used in labels.
#   conf_int_lb, conf_int_ub
#                quantiles in [0, 1] of the per-value predictions drawn as the
#                ribbon (numeric `iv` only).
#   range_low, range_high
#                optional quantiles of the evaluated `iv` values; only values
#                strictly inside them are kept. Either may be given alone
#                (numeric `iv` only).
#   delta        if TRUE, plot prediction minus the unmodified prediction.
#   num_sample   optional number of `iv` values to sample (numeric `iv` only).
#
# Returns a ggplot object: one boxplot per level for a factor `iv`, otherwise
# a mean line with a quantile ribbon and a rug of the evaluated values.
# Errors if `iv` is not a column of `data`, or if no values remain to evaluate.
plot_partial <- function(rf, data, dv, iv, conf_int_lb = .25,
                         conf_int_ub = .75, range_low = NULL,
                         range_high = NULL, delta = FALSE,
                         num_sample = NULL) {
  dv_expr <- substitute(dv)
  iv_expr <- substitute(iv)
  dv_name <- .partial_column_name(dv_expr,
                                  tryCatch(dv, error = function(e) NULL))
  iv_name <- .partial_column_name(iv_expr,
                                  tryCatch(iv, error = function(e) NULL))
  if (!(iv_name %in% names(data))) {
    stop("`iv` (\"", iv_name, "\") is not a column of `data`.", call. = FALSE)
  }
  # Same truthiness as the former `delta == TRUE` test, but NA counts as FALSE.
  delta <- isTRUE(as.logical(delta))
  if (is.factor(data[[iv_name]])) {
    .plot_partial_factor(rf, data, dv_name, iv_name, delta)
  } else {
    .plot_partial_numeric(rf, data, dv_name, iv_name, conf_int_lb,
                          conf_int_ub, range_low, range_high, delta,
                          num_sample)
  }
}

# Turn a `dv`/`iv` argument into a column-name string. `value` is the
# evaluated argument (NULL when evaluation failed, e.g. a bare column name);
# `expr` is the unevaluated argument captured with substitute().
.partial_column_name <- function(expr, value) {
  if (is.character(value) && length(value) == 1L) {
    return(value)
  }
  paste(deparse(expr), collapse = "")
}

# Factor predictor: one boxplot of the predictions for each level.
.plot_partial_factor <- function(rf, data, dv_name, iv_name, delta) {
  level_names <- levels(data[[iv_name]])
  baseline <- rf_predict(rf, data)
  preds <- lapply(level_names, function(level) {
    # Rebuild the column with the column's full level set, so the model sees
    # the same factor encoding for every level. No padding rows are needed.
    data[[iv_name]] <- factor(rep(level, nrow(data)), levels = level_names)
    y_hat <- rf_predict(rf, data)
    if (delta) y_hat - baseline else y_hat
  })
  y_hat_df <- data.frame(
    level = factor(rep(level_names, times = lengths(preds)),
                   levels = level_names),
    y_hat = unlist(preds, use.names = FALSE)
  )
  # No UB/LB needed here: the boxplot already shows the spread.
  ggplot(y_hat_df, aes(x = level, y = y_hat)) +
    geom_boxplot() +
    ggtitle(paste("Partial Dependence of", iv_name, "on", dv_name)) +
    ylab(bquote("Predicted values of" ~ .(dv_name))) +
    xlab(iv_name)
}

# Numeric predictor: mean prediction per evaluated value, a ribbon between the
# conf_int_lb and conf_int_ub quantiles, and a rug of the evaluated values.
.plot_partial_numeric <- function(rf, data, dv_name, iv_name, conf_int_lb,
                                  conf_int_ub, range_low, range_high,
                                  delta, num_sample) {
  conf_int <- round((conf_int_ub - conf_int_lb) * 100, 2)
  values <- sort(data[[iv_name]]) # sort() also drops NA values
  if (!is.null(num_sample)) {
    # Sample by index: sample(x, n) treats a length-1 numeric x as 1:x, and
    # requesting more values than exist would error.
    picked <- sample.int(length(values), min(num_sample, length(values)))
    values <- sort(values[picked])
  }
  # Both cut-offs are quantiles of the same (sampled) values.
  keep <- rep(TRUE, length(values))
  if (!is.null(range_low)) {
    keep <- keep & values > quantile(values, range_low)
  }
  if (!is.null(range_high)) {
    keep <- keep & values < quantile(values, range_high)
  }
  values <- values[keep]
  if (length(values) == 0L) {
    stop("No values of '", iv_name, "' left to evaluate; check ",
         "range_low/range_high and num_sample.", call. = FALSE)
  }
  baseline <- rf_predict(rf, data)
  # Predictions depend only on the substituted value. Score each distinct
  # value once, then expand the summaries back to every (repeated) value.
  grid <- unique(values)
  summaries <- vapply(seq_along(grid), function(k) {
    data[[iv_name]] <- rep(grid[k], nrow(data))
    y_hat <- rf_predict(rf, data)
    if (delta) {
      y_hat <- y_hat - baseline
    }
    c(mean(y_hat),
      quantile(y_hat, c(conf_int_lb, conf_int_ub), names = FALSE))
  }, numeric(3))
  idx <- match(values, grid)
  df_new <- data.frame(
    iv_value = values,
    y_hat_mean = summaries[1, idx],
    y_hat_lb = summaries[2, idx],
    y_hat_ub = summaries[3, idx]
  )
  ggplot(df_new, aes(x = iv_value)) +
    geom_line(aes(y = y_hat_mean), colour = "blue") +
    geom_ribbon(aes(ymin = y_hat_lb, ymax = y_hat_ub), alpha = 0.2) +
    geom_rug() +
    xlab(iv_name) +
    ylab(bquote("Predicted values of" ~ .(dv_name))) +
    ggtitle(paste("Partial Dependence of", iv_name, "on", dv_name,
                  "\n with", conf_int, "% Confidence Intervals"))
}
# Some examples ----
# Continuous dependent variable: regression forest for Ozone.
data(airquality)
airquality <- na.omit(airquality)
set.seed(2014)
rf_1 <- randomForest(Ozone ~ ., airquality)
# Out-of-the-box partialPlot() from randomForest, for comparison.
partialPlot(rf_1, airquality, Temp)
partialPlot(rf_1, airquality, Wind)
# plot_partial() examples with quantile bands.
plot_partial(rf = rf_1, data = airquality, dv = "Ozone", iv = "Temp",
             conf_int_lb = .25, conf_int_ub = .75)
plot_partial(rf = rf_1, data = airquality, dv = "Ozone", iv = "Wind",
             conf_int_lb = .25, conf_int_ub = .75, num_sample = 100)
plot_partial(rf = rf_1, data = airquality, dv = "Ozone", iv = "Wind",
             conf_int_lb = .25, conf_int_ub = .75, num_sample = 100,
             delta = TRUE)
# Binary dependent variable: plot_partial() detects classification forests on
# its own. Ozone is recoded in place so that `Ozone ~ .` keeps the response
# out of the predictors. With `as.factor(airquality$Ozone) ~ .`, the `.` can
# still pull the raw Ozone column in as a predictor.
airquality$Ozone <- factor(ifelse(airquality$Ozone < 60, 0, 1))
rf_2 <- randomForest(Ozone ~ ., airquality)
plot_partial(rf = rf_2, data = airquality, dv = "Ozone", iv = "Wind",
             conf_int_lb = .25, conf_int_ub = .75)
# Continuous dependent variable with a factor predictor (one boxplot per level).
data(iris)
rf_iris <- randomForest(Sepal.Length ~ ., iris)
partialPlot(rf_iris, iris, Species)
plot_partial(rf_iris, iris, dv = "Sepal.Length", iv = "Species",
             conf_int_lb = .27, conf_int_ub = .75)
plot_partial(rf_iris, iris, "Sepal.Length", "Species",
             conf_int_lb = .27, conf_int_ub = .75, delta = TRUE)
# Examples with a larger dataset, timing both approaches.
# NOTE(review): this remote file may no longer be served; confirm the URL.
data_big <- read.table("http://www.unt.edu/rss/class/Jon/R_SC/Module3/ExampleData3.txt",
                       header = TRUE, sep = "", na.strings = "NA", dec = ".",
                       strip.white = TRUE)
rf_data <- randomForest(as.factor(marital) ~ ., data = data_big)
ptm <- proc.time()
partialPlot(rf_data, data_big, age)
proc.time() - ptm # elapsed time is end minus start
ptm <- proc.time()
plot_partial(rf_data, data_big, "marital", "age", conf_int_lb = .45,
             conf_int_ub = .55, delta = FALSE, num_sample = 1000)
proc.time() - ptm
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment