Skip to content

Instantly share code, notes, and snippets.

@markheckmann
Last active November 19, 2015 16:34
Show Gist options
  • Save markheckmann/0313362f0c84b21625bd to your computer and use it in GitHub Desktop.
Save markheckmann/0313362f0c84b21625bd to your computer and use it in GitHub Desktop.
A tweaked version of the partykit::node_barplot() panel function that can draw an additional horizontal line in each node's bar plot.
#' Bar-plot panel generator for partykit trees, with an optional horizontal line
#'
#' A tweaked copy of `partykit::node_barplot()`. Calling it with a fitted
#' party object returns a panel function `function(node)` that draws a bar
#' plot of the (weighted) response distribution in that node. If `hline` is
#' given, it also draws one horizontal line segment per element of `hline`
#' across the node's plotting region.
#'
#' @param obj A fitted party object. `obj$fitted[["(response)"]]` must be a
#'   factor, an integer-valued numeric vector, or a data.frame.
#' @param col Border colour(s) of the bars, recycled to the number of bars.
#' @param fill Fill colour(s) of the bars; defaults to `gray.colors()`.
#' @param beside Draw bars side by side (`TRUE`) or stacked (`FALSE`)?
#'   Defaults to stacked only for factor responses with fewer than 3 levels.
#' @param ymax Upper limit of the y scale.
#' @param ylines Text lines reserved to the left and right of the plot.
#' @param widths Bar width(s), recycled.
#' @param gap Gap between bars, as a fraction of the total bar width.
#' @param reverse Reverse the order of bars and labels? Defaults to `!beside`.
#' @param id Include the node id in the default title?
#' @param mainlab `NULL` (default title), a character title, or a
#'   `function(id, nobs)` returning the title.
#' @param gp `gpar()` settings for the node viewport.
#' @param hline `NULL` (no line) or numeric y-position(s) of horizontal line(s).
#' @param h.gp `gpar()` settings for the horizontal line(s).
#' @param h.units Units in which `hline` is interpreted. The default `"npc"`
#'   (0 = bottom, 1 = top of the plotting region) matches the original
#'   behaviour; use `"native"` to place the line at a value on the y scale.
#' @return A panel function taking a single `node` argument and drawing into
#'   the current grid viewport.
node_barplot2 <- function(obj,
                          col = "black",
                          fill = NULL,
                          beside = NULL,
                          ymax = NULL,
                          ylines = NULL,
                          widths = 1,
                          gap = NULL,
                          reverse = NULL,
                          id = TRUE,
                          mainlab = NULL,
                          gp = gpar(),
                          hline = NULL,    # y-position(s) of horizontal line(s)
                          h.gp = gpar(),   # style of the horizontal line(s)
                          h.units = "npc") # units in which hline is given
{
  ## extract response
  y <- obj$fitted[["(response)"]]
  stopifnot(is.factor(y) || isTRUE(all.equal(round(y), y)) || is.data.frame(y))
  if (!is.null(hline)) stopifnot(is.numeric(hline))

  ## FIXME: This could be avoided by
  ## predict_party(obj, nodeids(obj, terminal = TRUE), type = "prob")
  ## but only for terminal nodes ^^^^
  ## Per node: weighted relative frequency of each response level / column,
  ## followed by the total weight in element "nobs".
  probs_and_n <- function(x) {
    y1 <- x$fitted[["(response)"]]
    if (!is.factor(y1)) {
      if (is.data.frame(y1)) {
        ## one row per response column, one column per observation
        y1 <- t(as.matrix(y1))
      } else {
        y1 <- factor(y1, levels = min(y):max(y))
      }
    }
    w <- x$fitted[["(weights)"]]
    ## Default unit weights need one entry per observation. For the matrix
    ## case that is ncol(y1); length(y1) would be columns x observations and
    ## make `y1 %*% w` non-conformable for multi-column responses.
    if (is.null(w)) {
      w <- rep.int(1L, if (is.factor(y1)) length(y1) else ncol(y1))
    }
    sumw <- if (is.factor(y1)) tapply(w, y1, sum) else drop(y1 %*% w)
    sumw[is.na(sumw)] <- 0
    prob <- c(sumw / sum(w), sum(w))
    names(prob) <- c(if (is.factor(y1)) levels(y1) else rownames(y1), "nobs")
    prob
  }
  probs <- do.call("rbind", nodeapply(obj, nodeids(obj), probs_and_n, by_node = FALSE))
  nobs <- probs[, "nobs"]
  probs <- probs[, -ncol(probs), drop = FALSE]

  if (is.factor(y)) {
    ylevels <- levels(y)
    if (is.null(beside)) beside <- length(ylevels) >= 3L
    if (is.null(ymax)) ymax <- if (beside) 1.1 else 1
  } else {
    if (is.null(beside)) beside <- TRUE
    if (is.null(ymax)) ymax <- if (beside) max(probs) * 1.1 else max(probs)
    ylevels <- colnames(probs)
    if (length(ylevels) < 2) ylevels <- ""
  }
  if (is.null(gap)) gap <- if (beside) 0.1 else 0
  if (is.null(reverse)) reverse <- !beside
  if (is.null(fill)) fill <- gray.colors(length(ylevels))
  if (is.null(ylines)) ylines <- if (beside) c(3, 2) else c(1.5, 2.5)

  ## default title builder; called later with (node id, number of obs)
  if (is.null(mainlab)) {
    mainlab <- if (id) {
      function(id, nobs) sprintf("Node %s (n = %s)", id, nobs)
    } else {
      function(id, nobs) sprintf("n = %s", nobs)
    }
  }

  ### panel function for barplots in nodes
  rval <- function(node) {
    nid <- id_node(node)

    ## parameter setup
    pred <- probs[nid, ]
    if (reverse) {
      pred <- rev(pred)
      ylevels <- rev(ylevels)
    }
    np <- length(pred)
    nc <- if (beside) np else 1
    fill <- rep(fill, length.out = np)
    widths <- rep(widths, length.out = nc)
    ## One border colour per bar or stacked segment. Recycling only to nc (= 1
    ## when stacked) left col[i] = NA, i.e. no border, for segments i > 1.
    col <- rep(col, length.out = np)
    ylines <- rep(ylines, length.out = 2)
    gap <- gap * sum(widths)
    yscale <- c(0, ymax)
    xscale <- c(0, sum(widths) + (nc + 1) * gap)

    ## terminal region: title row on top, plot row below, side margins
    top_vp <- viewport(
      layout = grid.layout(
        nrow = 2, ncol = 3,
        widths = unit(c(ylines[1], 1, ylines[2]), c("lines", "null", "lines")),
        heights = unit(c(1, 1), c("lines", "null"))
      ),
      width = unit(1, "npc"),
      height = unit(1, "npc") - unit(2, "lines"),
      name = paste0("node_barplot", nid),
      gp = gp
    )
    pushViewport(top_vp)

    ## main title panel
    pushViewport(viewport(layout.pos.col = 2, layout.pos.row = 1))
    label <- if (is.function(mainlab)) {
      mainlab(names(obj)[nid], nobs[nid])
    } else {
      mainlab
    }
    grid.text(label)
    popViewport()

    ## Barchart region. Named via nid, like the outer viewport. The original
    ## used node$nodeID, which is NULL for partynode objects (their id element
    ## is "id" -- NOTE(review): confirm), so every node got the same name.
    pushViewport(viewport(
      layout.pos.col = 2, layout.pos.row = 2,
      xscale = xscale, yscale = yscale,
      name = paste0("node_barplot", nid, "plot"),
      clip = FALSE
    ))

    if (beside) {
      xcenter <- cumsum(widths + gap) - widths / 2
      if (length(xcenter) > 1) grid.xaxis(at = xcenter, label = FALSE)
      grid.text(ylevels, x = xcenter, y = unit(-1, "lines"),
                just = c("center", "top"),
                default.units = "native", check.overlap = TRUE)
      grid.yaxis()
      grid.rect(gp = gpar(fill = "transparent"))
      grid.clip()
      for (i in seq_len(np)) {
        grid.rect(x = xcenter[i], y = 0, height = pred[i],
                  width = widths[i],
                  just = c("center", "bottom"), default.units = "native",
                  gp = gpar(col = col[i], fill = fill[i]))
      }
    } else {
      ## bottom edge of each stacked segment
      ycenter <- cumsum(pred) - pred
      if (np > 1) {
        grid.text(ylevels[1], x = unit(-1, "lines"), y = 0,
                  just = c("left", "center"), rot = 90,
                  default.units = "native", check.overlap = TRUE)
        grid.text(ylevels[np], x = unit(-1, "lines"), y = ymax,
                  just = c("right", "center"), rot = 90,
                  default.units = "native", check.overlap = TRUE)
      }
      if (np > 2) {
        grid.text(ylevels[-c(1, np)], x = unit(-1, "lines"), y = ycenter[-c(1, np)],
                  just = "center", rot = 90,
                  default.units = "native", check.overlap = TRUE)
      }
      grid.yaxis(main = FALSE)
      grid.clip()
      grid.rect(gp = gpar(fill = "transparent"))
      for (i in seq_len(np)) {
        ## clamp each segment so it never extends past ymax
        grid.rect(x = xscale[2] / 2, y = ycenter[i],
                  height = min(pred[i], ymax - ycenter[i]),
                  width = widths[1],
                  just = c("center", "bottom"), default.units = "native",
                  gp = gpar(col = col[i], fill = fill[i]))
      }
    }
    grid.rect(gp = gpar(fill = "transparent"))

    ## Horizontal line(s) across the full plot width at y = hline, measured
    ## in h.units (default "npc", i.e. relative to the plotting region).
    if (!is.null(hline)) {
      grid.segments(y0 = unit(hline, h.units),
                    y1 = unit(hline, h.units),
                    gp = h.gp)
    }
    upViewport(2)
  }
  rval
}
class(node_barplot2) <- "grapcon_generator"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment