Last active
February 20, 2018 05:37
-
-
Save goldingn/79859cbdd4496f64015fb9c21244d259 to your computer and use it in GitHub Desktop.
hack greta v0.2.4 to use tensorflow HMC
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Get greta working with bayesflow's HMC implementation, driven through
# tensorflow's session `run()` syntax.
# Define the whole TensorFlow graph for `dag`: first every node's tensors,
# then the overall joint log density (with adjustment).
#
# The dag's float type is published via options(greta_tf_float = ...) while
# the nodes are defined, so nodes can read it without it being passed to them
# explicitly. The previous value of the option (possibly unset) is restored
# when the function exits, including on error.
#
# dag: the greta dag object whose nodes and joint density are defined.
# Returns the value of dag$on_graph(dag$define_joint_density()).
build_function <- function(dag) {
  # options() hands back the previous setting as a named list (NULL element
  # when unset); restoring from it avoids the partial matching that
  # `options()$greta_tf_float` would do.
  old_opts <- options(greta_tf_float = dag$tf_float)
  on.exit(options(old_opts), add = TRUE)

  # define all nodes
  dag$on_graph(lapply(dag$node_list, function(node) node$define_tf(dag)))

  # define an overall log density with adjustment
  dag$on_graph(dag$define_joint_density())
}
# Log-density function handed to bayesflow's HMC sampler.
#
# Splits the flat free-state tensor into one tensor per model parameter and
# stores each in dag$tf_environment as "<name>_free". It then rebuilds the
# rest of the graph with build_function() and returns the joint log density
# with adjustment (dag$tf_environment$joint_density_adj).
#
# NOTE: `dag` is not an argument, because the sampler calls this with the
# state only. It is looked up in this function's enclosing environment, so
# the caller must bind `dag` there before sampling.
foo <- function(free_state) {
  tf_env <- dag$tf_environment

  # flush the tf environment, keeping only the session-level objects
  stale <- setdiff(ls(tf_env), c("config", "sess", "n_cores"))
  rm(list = stale, envir = tf_env)

  # number of free-state elements belonging to each parameter
  params <- dag$parameters_example
  sizes <- vapply(
    params,
    function(x) as.integer(prod(dim(x))),
    FUN.VALUE = 1L
  )

  # Pass the sizes as an unnamed R list so reticulate always produces a
  # Python list. A length-1 atomic vector would become a Python scalar,
  # which tf$split treats as "number of equal pieces" rather than "size of
  # the single piece". Names are dropped so the list is not converted to a
  # dict.
  chunks <- dag$on_graph(tf$split(free_state, as.list(unname(sizes))))

  # put these tensors in the tf environment, with the correct names
  free_names <- paste0(names(params), "_free")
  for (k in seq_along(free_names)) {
    assign(free_names[k], chunks[[k]], envir = tf_env)
  }

  # define everything else in the tf environment, on top of the free state
  # variables placed there above
  build_function(dag)

  # return the log density
  tf_env$joint_density_adj
}
# --- temporary hack -------------------------------------------------------
# Monkey-patch a variable node's tf() method in place. The patched method
# does not define the node's free state variable. Instead, it fetches the
# "<name>_free" tensor that is expected to already be in dag$tf_environment.
# A full implementation would add an option to skip defining the free state
# variable rather than patching the method.
#
# variable_node: the greta variable node whose `tf` binding is replaced.
# Returns NULL invisibly; called for its side effect on `variable_node`.
replace_tf <- function(variable_node) {
  patched_tf <- function(dag) {
    node_name <- dag$tf_name(self)

    # store the log jacobian adjustment for the free state as "<name>_adj"
    assign(
      sprintf("%s_adj", node_name),
      self$tf_adjustment(dag),
      envir = dag$tf_environment
    )

    # map the pre-supplied free tensor to the constrained state, stored
    # under the node's own tf name
    free_tensor <- get(sprintf("%s_free", node_name),
                       envir = dag$tf_environment)
    constrained <- self$tf_from_free(free_tensor, dag$tf_environment)
    assign(node_name, constrained, envir = dag$tf_environment)
  }

  # run the patch inside the node's enclosing environment so `self` resolves
  environment(patched_tf) <- variable_node$.__enclos_env__

  # the `tf` binding is expected to be locked: unlock only for the swap
  unlockBinding("tf", variable_node)
  variable_node$tf <- patched_tf
  lockBinding("tf", variable_node)

  invisible(NULL)
}
# --- /temporary hack ------------------------------------------------------
# Build a fresh greta model from `targets` and draw samples with
# tf$contrib$bayesflow$hmc$chain, using foo() as the log density.
#
# i:          chain index; supplied by sfLapply() and otherwise unused.
# targets:    named list of greta arrays (e.g. model$target_greta_arrays).
# n_samples:  number of HMC iterations (default 10000).
# step_size:  leapfrog step size (default 0.05).
# n_leapfrog: number of leapfrog steps per iteration (default 15).
#
# Returns the result of running the chain tensor in the dag's session.
run_mcmc <- function(i = 1, targets, n_samples = 10000, step_size = 0.05,
                     n_leapfrog = 15L) {
  # greta::model() names its targets from the argument expressions, so call
  # it with *symbols* bound in a dedicated environment. This gives the same
  # call as the original `greta::model(x, z, n_cores = 1)` text, without
  # eval(parse()) and without target names colliding with local variables.
  target_env <- list2env(targets, parent = environment())
  target_symbols <- lapply(names(targets), as.name)
  greta_model <- do.call(
    greta::model,
    c(target_symbols, list(n_cores = 1)),
    envir = target_env
  )
  dag <- greta_model$dag

  # foo() looks `dag` up in its enclosing environment. Give a copy of it a
  # private environment holding this dag, instead of writing `dag` into
  # foo's (global) environment.
  log_density <- foo
  environment(log_density) <- list2env(
    list(dag = dag),
    parent = environment(foo)
  )

  # modify nodes so that variable nodes don't define their free states
  # (temporary hack)
  is_variable <- dag$node_types == "variable"
  for (node in dag$node_list[is_variable]) {
    replace_tf(node)
  }

  # create the free state
  n_free <- length(dag$example_parameters())
  free_state <- dag$on_graph(tf$zeros(n_free))

  # Step size and leapfrog count, fed again at run time so they can be tuned
  # externally. reticulate's dict() resolves the keys `ss` and `lf` to the
  # Python objects of those names in this frame, so these variable names
  # must match the dict() keys below.
  ss <- dag$on_graph(tf$constant(step_size))
  lf <- dag$on_graph(tf$constant(as.integer(n_leapfrog)))

  draws_tensor <- dag$on_graph(
    tf$contrib$bayesflow$hmc$chain(
      as.integer(n_samples), ss, lf,
      free_state,
      log_density,
      event_dims = 0L
    )
  )

  dag$tf_run(sess$run(tf$global_variables_initializer()))
  dag$tf_environment$sess$run(
    draws_tensor,
    feed_dict = dict(ss = step_size, lf = as.integer(n_leapfrog))
  )
}
# Demo: fit a small greta model using two parallel chains of TensorFlow HMC.
library(tensorflow)
library(greta)
library(snowfall)

# define a model (named so it doesn't shadow greta's model() function)
z <- normal(0, 1, 2)
x <- normal(0, 1, 3)
demo_model <- model(x, z, n_cores = 1)
targets <- demo_model$target_greta_arrays

# run multiple mcmc chains in parallel, each with tensorflow HMC
n_chains <- 2
sfInit(parallel = TRUE, cpus = n_chains)
sfLibrary(greta)
sfLibrary(tensorflow)
# run_mcmc itself is passed to sfLapply() directly; the globals it calls
# are exported to the workers here
sfExport("foo", "replace_tf", "build_function")

# 2x 10,000 iterations in ~ 18 seconds
# stop the cluster even if sampling fails, so worker processes aren't leaked
tryCatch(
  system.time(draws <- sfLapply(seq_len(n_chains), run_mcmc, targets)),
  finally = sfStop()
)

# trace plots of column 4 of each chain's samples, on a shared y-range
trace_1 <- draws[[1]][[1]][, 4]
trace_2 <- draws[[2]][[1]][, 4]
plot(trace_1, type = "l", ylim = range(trace_1, trace_2))
lines(trace_2, col = "light blue")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment