Last active
November 19, 2015 16:34
-
-
Save markheckmann/0313362f0c84b21625bd to your computer and use it in GitHub Desktop.
A tweaked partykit::node_barplot function that adds an additional horizontal line
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Barplot panel generator with an optional horizontal reference line
#'
#' A tweaked version of `partykit::node_barplot`. It precomputes, for every
#' node of `obj`, the weighted relative frequencies of the response and
#' returns a panel function that draws them as a bar chart for a given node.
#' In addition to the original behaviour, a horizontal line can be drawn
#' across the bar chart region via `hline` / `h.gp`.
#'
#' @param obj Tree object. Its `fitted` component must contain a
#'   `"(response)"` entry (factor, integer-valued numeric, or data frame) and
#'   may contain a `"(weights)"` entry (unit weights are used otherwise).
#' @param col Border colour(s) of the bars (recycled).
#' @param fill Fill colour(s) of the bars; defaults to
#'   `gray.colors(<number of levels>)`.
#' @param beside Draw bars side by side (`TRUE`) or stacked (`FALSE`).
#'   Default: for factor responses `FALSE` with fewer than 3 levels and
#'   `TRUE` otherwise; `TRUE` for non-factor responses.
#' @param ymax Upper limit of the y scale (lower limit is 0).
#' @param ylines Widths, in lines, of the left and right margin columns of
#'   the panel layout.
#' @param widths Bar width(s) (recycled).
#' @param gap Gap between bars, as a fraction of the summed bar widths.
#' @param reverse Reverse the order of the levels? Defaults to `!beside`.
#' @param id Include the node id in the default main label?
#' @param mainlab `NULL` (default label), a function of `(id, nobs)`
#'   returning the label, or a value passed directly to `grid.text()`.
#' @param gp Graphical parameters (`gpar`) for the panel's top viewport.
#' @param hline Position(s) of the horizontal line. A plain numeric is
#'   interpreted in NPC units of the bar chart viewport (0 = bottom,
#'   1 = top); a grid `unit` object is used as is, e.g.
#'   `unit(0.5, "native")` to place the line on the y scale. `NULL` (the
#'   default) draws no line.
#' @param h.gp Graphical parameters (`gpar`) for the horizontal line.
#' @return A panel function taking a single `node` argument. The generator
#'   itself carries class `"grapcon_generator"`.
node_barplot2 <- function(obj,
                          col = "black",
                          fill = NULL,
                          beside = NULL,
                          ymax = NULL,
                          ylines = NULL,
                          widths = 1,
                          gap = NULL,
                          reverse = NULL,
                          id = TRUE,
                          mainlab = NULL,
                          gp = gpar(),
                          hline = NULL,   # draw horizontal line at this value
                          h.gp = gpar())  # style the horizontal line
{
  ## extract response
  y <- obj$fitted[["(response)"]]
  stopifnot(is.factor(y) || isTRUE(all.equal(round(y), y)) || is.data.frame(y))

  ## FIXME: This could be avoided by
  ##   predict_party(obj, nodeids(obj, terminal = TRUE), type = "prob")
  ## but only for terminal nodes                  ^^^^

  ## Weighted relative frequencies of the response in one node, with the
  ## total weight appended as the last element ("nobs").
  probs_and_n <- function(x) {
    y1 <- x$fitted[["(response)"]]
    if (!is.factor(y1)) {
      if (is.data.frame(y1)) {
        ## one row per response column, so that y1 %*% w sums per column
        y1 <- t(as.matrix(y1))
      } else {
        ## integer-valued response: use the range of the full response so
        ## that every node gets the same set of levels
        y1 <- factor(y1, levels = min(y):max(y))
      }
    }
    w <- x$fitted[["(weights)"]]
    if (is.null(w)) w <- rep.int(1L, length(y1))
    sumw <- if (is.factor(y1)) tapply(w, y1, sum) else drop(y1 %*% w)
    sumw[is.na(sumw)] <- 0  # levels not observed in this node
    prob <- c(sumw / sum(w), sum(w))
    names(prob) <- c(if (is.factor(y1)) levels(y1) else rownames(y1), "nobs")
    prob
  }

  ## one row per node: probabilities plus trailing "nobs" column
  probs <- do.call("rbind",
                   nodeapply(obj, nodeids(obj), probs_and_n, by_node = FALSE))
  nobs <- probs[, "nobs"]
  probs <- probs[, -ncol(probs), drop = FALSE]

  ## defaults that depend on the type of the response
  if (is.factor(y)) {
    ylevels <- levels(y)
    if (is.null(beside)) beside <- if (length(ylevels) < 3L) FALSE else TRUE
    if (is.null(ymax)) ymax <- if (beside) 1.1 else 1
    if (is.null(gap)) gap <- if (beside) 0.1 else 0
  } else {
    if (is.null(beside)) beside <- TRUE
    if (is.null(ymax)) ymax <- if (beside) max(probs) * 1.1 else max(probs)
    ylevels <- colnames(probs)
    if (length(ylevels) < 2) ylevels <- ""
    if (is.null(gap)) gap <- if (beside) 0.1 else 0
  }
  if (is.null(reverse)) reverse <- !beside
  if (is.null(fill)) fill <- gray.colors(length(ylevels))
  if (is.null(ylines)) ylines <- if (beside) c(3, 2) else c(1.5, 2.5)

  ### panel function for barplots in nodes
  rval <- function(node) {
    ## id
    nid <- id_node(node)

    ## parameter setup
    pred <- probs[nid, ]
    if (reverse) {
      pred <- rev(pred)
      ylevels <- rev(ylevels)
    }
    np <- length(pred)
    nc <- if (beside) np else 1

    fill <- rep(fill, length.out = np)
    widths <- rep(widths, length.out = nc)
    col <- rep(col, length.out = nc)
    ylines <- rep(ylines, length.out = 2)

    gap <- gap * sum(widths)
    yscale <- c(0, ymax)
    xscale <- c(0, sum(widths) + (nc + 1) * gap)

    ## terminal region, i.e. for all terminals
    top_vp <- viewport(
      layout = grid.layout(
        nrow = 2, ncol = 3,
        widths = unit(c(ylines[1], 1, ylines[2]), c("lines", "null", "lines")),
        heights = unit(c(1, 1), c("lines", "null"))
      ),
      width = unit(1, "npc"),
      height = unit(1, "npc") - unit(2, "lines"),
      name = paste0("node_barplot", nid),
      gp = gp
    )
    pushViewport(top_vp)

    ## main title panel
    top <- viewport(layout.pos.col = 2, layout.pos.row = 1)
    pushViewport(top)

    ## mainlab may be NULL (default label), a function, or a ready-made label
    if (is.null(mainlab)) {
      mainlab <- if (id) {
        function(id, nobs) sprintf("Node %s (n = %s)", id, nobs)
      } else {
        function(id, nobs) sprintf("n = %s", nobs)
      }
    }
    if (is.function(mainlab)) {
      mainlab <- mainlab(names(obj)[nid], nobs[nid])
    }
    grid.text(mainlab)
    popViewport()

    ## Barchart region.
    ## The viewport name uses nid (as the top viewport does) so that every
    ## node gets its own, addressable plot viewport.
    plot <- viewport(layout.pos.col = 2, layout.pos.row = 2,
                     xscale = xscale, yscale = yscale,
                     name = paste0("node_barplot", nid, "plot"),
                     clip = FALSE)
    pushViewport(plot)

    if (beside) {
      ## bars side by side
      xcenter <- cumsum(widths + gap) - widths / 2
      if (length(xcenter) > 1) grid.xaxis(at = xcenter, label = FALSE)
      grid.text(ylevels, x = xcenter, y = unit(-1, "lines"),
                just = c("center", "top"),
                default.units = "native", check.overlap = TRUE)
      grid.yaxis()
      grid.rect(gp = gpar(fill = "transparent"))
      grid.clip()
      for (i in seq_len(np)) {
        grid.rect(x = xcenter[i], y = 0, height = pred[i],
                  width = widths[i],
                  just = c("center", "bottom"), default.units = "native",
                  gp = gpar(col = col[i], fill = fill[i]))
      }
    } else {
      ## stacked bars: ycenter holds the lower edge of each segment
      ycenter <- cumsum(pred) - pred
      if (np > 1) {
        grid.text(ylevels[1], x = unit(-1, "lines"), y = 0,
                  just = c("left", "center"), rot = 90,
                  default.units = "native", check.overlap = TRUE)
        grid.text(ylevels[np], x = unit(-1, "lines"), y = ymax,
                  just = c("right", "center"), rot = 90,
                  default.units = "native", check.overlap = TRUE)
      }
      if (np > 2) {
        grid.text(ylevels[-c(1, np)], x = unit(-1, "lines"),
                  y = ycenter[-c(1, np)],
                  just = "center", rot = 90,
                  default.units = "native", check.overlap = TRUE)
      }
      grid.yaxis(main = FALSE)
      grid.clip()
      grid.rect(gp = gpar(fill = "transparent"))
      for (i in seq_len(np)) {
        ## segment height is capped so that it does not exceed ymax
        grid.rect(x = xscale[2] / 2, y = ycenter[i],
                  height = min(pred[i], ymax - ycenter[i]),
                  width = widths[1],
                  just = c("center", "bottom"), default.units = "native",
                  gp = gpar(col = col[i], fill = fill[i]))
      }
    }
    grid.rect(gp = gpar(fill = "transparent"))

    ##############################################
    # NEW: add horizontal line at the y-location given by hline.
    # Plain numerics are NPC coordinates of the plot viewport; unit objects
    # are passed through unchanged (e.g. unit(0.5, "native")).
    if (!is.null(hline)) {
      hline_y <- if (is.unit(hline)) hline else unit(hline, "npc")
      grid.segments(y0 = hline_y, y1 = hline_y, gp = h.gp)
    }

    ## leave both pushed viewports (plot and top_vp) without removing them
    upViewport(2)
  }

  return(rval)
}
class(node_barplot2) <- "grapcon_generator"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment