@jdthorpe
Last active October 30, 2017 16:50
Hierarchical version of dcast

Introduction to Hierarchical Casting with dhcast()

The h in dhcast stands for hierarchy. dhcast() is useful when you need to (A) cast your data (go from long to wide), (B) have the casting respect a given variable hierarchy, and (optionally) aggregate your data (summarize related records in the long dataset) at the same time.

Let's say, for example, that we want to predict some time-varying attribute of grocery store customers based on their produce purchases. Furthermore, we want to aggregate total spending by type of fruit, and also model the effect of certain attributes (e.g. organic vs. conventional) within each type of fruit. The effect of purchasing organic produce may vary by type of fruit, so we nest the "organic" indicator within the type of fruit.

You can't just call model.matrix(~ fruit / is_organic) to create a hierarchical dataset because (A) customers may purchase both conventional and organic apples in the same week, and (B) customers may make multiple purchases in the same week, so you may need to aggregate records as well as cast them.
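As a small illustration of point (B), the sketch below runs model.matrix() on a made-up purchase log. The `purchases` data frame and its values are assumptions invented for this example; the point is only that model.matrix() returns one row per purchase record and performs no aggregation per customer.

```r
# Hypothetical purchase log for a single week (illustrative data only).
purchases <- data.frame(
  Customer   = c("A", "A", "A", "B"),
  fruit      = factor(c("apple", "apple", "banana", "apple")),
  is_organic = c(FALSE, TRUE, FALSE, TRUE)
)

# model.matrix() expands the nested formula row by row.
mm <- model.matrix(~ fruit / is_organic, data = purchases)

# One design-matrix row per purchase, not one per customer.
nrow(mm)
length(unique(purchases$Customer))
```

Customer A contributes three rows here, so the records would still need to be collapsed to one row per customer and week before modelling.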

Note that it's important to understand data.table's dcast(), as dhcast() is built around the dcast() API and simply adds the / (nesting) and : (interaction) operators to the right-hand side of the casting formula.
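As a refresher on plain dcast(), here is a minimal sketch that casts and aggregates in one step. The `purchases` table is a hypothetical example, not part of the gist's toy dataset, and the call uses only the standard dcast() arguments (formula, fun.aggregate, value.var).

```r
library(data.table)

# Hypothetical long-format purchases (illustrative data only).
purchases <- data.table(
  Customer = c("A", "A", "A", "B"),
  item     = c("apple", "apple", "banana", "apple"),
  spending = c(1.32, 3.65, 5.09, 1.32)
)

# Long to wide: one row per Customer, one column per item,
# with repeated purchases summed within each cell.
wide <- dcast(purchases, Customer ~ item,
              fun.aggregate = sum, value.var = "spending")
wide
```

Customer B never bought bananas, so that cell is filled by applying sum() to an empty vector, which gives 0.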

An example

Note that in this toy dataset, potatoes are never organic and broccoli is always organic. As a result, we don't need variables for organic within potatoes or broccoli:

library(data.table)

items  <-  
as.data.table(read.csv(textConnection(c(
"item,is_organic
apple,FALSE
apple,TRUE
banana,FALSE
banana,TRUE
potato,FALSE
broccoli,TRUE"))))

set.seed(10101L)
DT  <-  cbind(
            as.data.table(
                expand.grid(
                    Customer=LETTERS[1:5],
                    Date=seq.Date(Sys.Date(),length.out=5,by="week"),
                    spending=c(1.32, 3.65, 5.09),
                    stringsAsFactors=FALSE)),
            items[sample(.N,75,replace=TRUE)])[sample(.N,15)]

Now for the real work:

# CREATE THE INTERCEPT FIELDS 
intercepts  <-  
    dhcast(DT,
           Customer + Date ~ item/is_organic,
           value.var="spending")

# CREATE THE VALUE FIELDS 
values  <-  
    dhcast(DT,
           Customer + Date ~ item/is_organic,
           sum,
           value.var="spending")

# MERGE THE TWO DATASETS.  This is possible because the result of dhcast is
# key()'d on the left hand side variables (Customer and Date).
data <- intercepts [ values ]
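The keyed join in the last step can be illustrated on two tiny tables. This is only a sketch: `intercepts_demo`, `values_demo`, and their columns are hypothetical stand-ins for the two dhcast() results, keyed the same way.

```r
library(data.table)

# Hypothetical stand-ins for the two casts, both keyed on the
# left-hand-side variables.
intercepts_demo <- data.table(
  Customer = c("A", "B"), Week = c(1L, 1L),
  bought_apple = c(1, 0),
  key = c("Customer", "Week")
)
values_demo <- data.table(
  Customer = c("A", "B"), Week = c(1L, 1L),
  apple_spending = c(4.97, 0),
  key = c("Customer", "Week")
)

# X[Y] joins on the shared key when both tables are keyed.
combined <- intercepts_demo[values_demo]
names(combined)
```

Because both tables share the same key, no explicit `on=` argument is needed.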

Inspecting the variable names of the returned intercepts data.table via cat(names(intercepts),sep = "\n"), we see that the nested term is_organic only appears for apple and banana, as expected:

Customer
Date
(intercept)
item=apple:(intercept)
item=banana:(intercept)
item=broccoli:(intercept)
item=potato:(intercept)
item=apple:is_organic=FALSE:(intercept)
item=apple:is_organic=TRUE:(intercept)
item=banana:is_organic=FALSE:(intercept)
item=banana:is_organic=TRUE:(intercept)

A similar inspection of the data in intercepts shows that the nested indicators for banana (conventional or organic) are only true (1) when banana is also true:

intercepts[,
    .(banana = `item=banana:(intercept)`,
      conventional_banana = `item=banana:is_organic=FALSE:(intercept)`,
      organic_banana = `item=banana:is_organic=TRUE:(intercept)`
      )]

>>     banana conventional_banana organic_banana
>>  1:      1                   1              0
>>  2:      1                   0              1
>>  3:      1                   1              0
>>  4:      0                   0              0
>>  5:      0                   0              0
>>  6:      1                   0              1
>>  7:      0                   0              0
>>  8:      0                   0              0
>>  9:      1                   0              1
>> 10:      1                   1              0
>> 11:      0                   0              0
>> 12:      1                   1              0

Finally, note that the name attached to the value returned by the aggregation function defaults to value.var when fun.aggregate is supplied, and to (intercept) otherwise. This can be overridden via the value.name parameter.
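To see why the default produces indicator ("intercept") columns, the sketch below applies the default aggregator from the dhcast() signature to two made-up groups of records and compares it with sum. The name `indicator` and the two spending vectors are assumptions for illustration only.

```r
# Default aggregator, copied from the dhcast() signature.
indicator <- function(x) if (length(x)) 1 else 0

apple_spending  <- c(1.32, 3.65)  # two apple purchases in one customer-week
banana_spending <- numeric(0)     # no banana purchases that week

indicator(apple_spending)   # 1: at least one record exists
indicator(banana_spending)  # 0: no records
sum(apple_spending)         # total spent when sum is supplied instead
sum(banana_spending)        # 0
```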

dhcast <- function(data,
                   formula,
                   value.var,
                   fun.aggregate=function(x) if(length(x)) 1 else 0,
                   collapse = ":",
                   value.name = NULL , # name assigned to the value returned by fun.aggregate(x)
                   ..., # additional arguments to dcast.data.table
                   verbose=FALSE){
    if(!is.data.table(data))
        stop("For now, dhcast is only defined for data.tables")
    if(missing(value.var))
        stop("missing required parameter 'value.var'")
    if(!is.character(value.var))
        stop("value.var must be a string")
    if(missing(value.name)){
        value.name <- if(missing(fun.aggregate)) "(intercept)" else value.var
    }
    # ==============================================
    # HELPER FUNCTIONS
    # ==============================================
    .format <- function(x){
        if(is.name(x)) as.character(x)
        else if(is.list(x)) do.call("sprintf",
                                    c("%s and %s",
                                      lapply(x,.format)))
        else format(x)
    }
    .expand_model <- function(m,is_toplevel=TRUE){
        if(verbose)
            .m <- .format(m)
        if(!is.call(m)){
            return(m)
        }else if(is_toplevel && identical(m[[1]],as.name("+"))){
            out <- c(.expand_model(m[[2]]),
                     .expand_model(m[[3]]))
        }else if(identical(m[[1]],as.name("/"))){
            m[[1]] <- as.name("+")
            tmp <- .expand_model(m[[2]],FALSE)
            m[[2]] <- if(is.list(tmp)) tmp[[length(tmp)]] else tmp
            m[[3]] <- .expand_model(m[[3]],FALSE)
            out <- c(tmp, m)
        }else if(identical(m[[1]],as.name(":"))){
            m[[1]] <- as.name("+")
            m[[2]] <- .expand_model(m[[2]],FALSE)
            m[[3]] <- .expand_model(m[[3]],FALSE)
            out <- m
        }else {
            stop(sprintf("unknown function / operator '%s'",.format(m[[1]])))
        }
        if(verbose)
            cat(sprintf("%s -> %s\n", .m, paste0(sapply(out,.format),collapse = ' and ')))
        return(out)
    }
    .get_names <- function(s){
        if(is.call(s)){
            stopifnot(identical(s[[1]],as.name("+")))
            return(c(.get_names(s[[2]]),.get_names(s[[3]])))
        }else{
            return(as.character(s))
        }
    }
    # ==============================================
    # ACTUAL WORK
    # ==============================================
    # CONSTANTS
    .fmla <- . ~ .
    H <- list()
    RHSides <- .expand_model(formula[[3]])
    LHS <- formula[[2]]
    LHS_names <- .get_names(LHS)
    formula[[3]] <- quote(.)
    out <- do.call("dcast.data.table",
                   list(data=data,
                        formula,
                        value.var=value.var,
                        fun.aggregate=substitute(fun.aggregate),
                        ...))
    setkeyv(out,LHS_names)
    # human-readable variable names
    setnames(out, ".",value.name)
    for(i in seq_along(RHSides)){
        formula[[3]] <- RHSides[[1]]
        .fmla[[2]] <- as.call(list(as.name("+"),LHS,RHSides[[i]]))
        .LONG <- do.call("dcast.data.table",
                         list(data=data,
                              .fmla,
                              fun.aggregate=substitute(fun.aggregate),
                              value.var=value.var,
                              ...))
        setnames(.LONG,".",value.var)
        .fmla[[2]] <- RHSides[[i]]
        .levels <- dcast.data.table(.LONG,.fmla,fun.aggregate=length,value.var=value.var)
        if(!(nrow(.levels) > 1)){
            warning(sprintf("only a single value for factor '%s'; Skipping this factor.",
                            .format(RHSides[[i]])))
            next
        }
        if(is.call(RHSides[[i]])){
            .fmla[[2]] <- RHSides[[i]][[2]]
            setnames(.levels,".",".count")
            level_counts <- dcast.data.table(.levels,.fmla,fun.aggregate=length,value.var='.count')[`.`>1]
            # skip when no parent level has more than one nested value
            if(!(nrow(level_counts) > 0)){
                warning(sprintf("only a single value for factor '%s' conditional on previous levels (%s); Skipping this factor.",
                                .format(RHSides[[i]][[3]]),
                                .format(RHSides[[i]][[2]])))
                next
            }
            level_counts[,`.`:=NULL]
            .LONG <- merge(level_counts ,.LONG)
        }
        # cast the .LONG data to wide format
        formula[[3]] <- RHSides[[i]]
        .WIDE <- do.call("dcast.data.table",
                         list(.LONG,
                              formula,
                              fun.aggregate=substitute(fun.aggregate),
                              value.var=value.var,
                              ...,
                              sep = "@"))
        # create human-readable variable names
        nms <- setdiff(names(.WIDE),LHS_names)
        RHS_names <- .get_names(RHSides[[i]])
        setnames(.WIDE,
                 nms,
                 sapply(strsplit(nms,"@"),
                        function(x)
                            sprintf("%s:%s",
                                    paste(RHS_names,x,sep="=",collapse = ":"),
                                    value.name)))
        # merge the .WIDE in with the main output
        setkeyv(.WIDE,LHS_names)
        out <- out[.WIDE]
    }
    out
}
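
The nesting expansion done by the .expand_model helper above can be sketched in isolation with base R. This is a simplified illustration rather than the helper itself: `expand_nesting` is a hypothetical name, and it handles only the / operator (not + or :). It turns item/is_organic into the sequence of casts "item", then "item + is_organic".

```r
# Simplified sketch: expand a nested expression a/b/c into the list
# of cumulative terms a, a + b, a + b + c.
expand_nesting <- function(expr) {
  # A bare symbol such as `item` is already a single term.
  if (!is.call(expr)) return(list(expr))
  if (as.character(expr[[1]]) != "/")
    stop("this sketch only handles the '/' operator")
  parents <- expand_nesting(expr[[2]])
  deepest <- parents[[length(parents)]]
  # The nested term is cast together with all of its parents.
  child <- call("+", deepest, expr[[3]])
  c(parents, list(child))
}

f <- Customer + Date ~ item / is_organic
rhs_terms <- expand_nesting(f[[3]])
labels <- vapply(rhs_terms,
                 function(e) paste(deparse(e), collapse = ""),
                 character(1))
labels
```

Each element of the result corresponds to one cast whose columns are merged into the output, which is why the parent columns (item=apple) appear alongside the nested ones (item=apple:is_organic=TRUE).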