Created
May 11, 2017 05:22
-
-
Save goldingn/005f52ac488ff9796e82df4d9dcb4514 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# devtools::install_github('goldingn/greta') | |
devtools::load_all() | |
# Collect values at which to fix greta arrays, for use with `with()`.
#
# Each argument is named after a greta array that can be found from the
# calling environment, and gives the value that greta array should take, e.g.
# `with(fixed_values(a = 0, b = 0.2), p)`. Each value is written into an
# array shaped like the greta array's current node value (recycling it if
# needed), so it carries the dimensions that greta array expects.
#
# @param ... named values; every name must refer to a greta array.
# @return an object of class 'fixed_values': a list with elements
#   `values` (a named list of arrays, one per argument) and
#   `target_greta_arrays` (the matching greta arrays, in the same order).
fixed_values <- function (...) {

  # greta arrays are looked up from the caller's environment; capture it
  # explicitly rather than relying on lazy evaluation of parent.frame()
  # inside lapply()
  env <- parent.frame()

  values <- list(...)
  value_names <- names(values)

  # the names are used to find the greta arrays, so every value needs one
  # (names() gives NULL when nothing is named and '' for unnamed elements)
  all_named <- !is.null(value_names) && all(nzchar(value_names))
  if (length(values) > 0 && !all_named) {
    stop('all arguments to fixed_values() must be named after greta arrays',
         call. = FALSE)
  }

  # get the corresponding greta arrays and make sure that's what they are
  target_greta_arrays <- lapply(value_names, get, envir = env)
  are_greta_arrays <- vapply(target_greta_arrays,
                             is.greta_array,
                             FUN.VALUE = logical(1))
  if (!all(are_greta_arrays)) {
    stop('these are not greta arrays: ',
         paste(value_names[!are_greta_arrays], collapse = ', '),
         call. = FALSE)
  }

  # give each value the dimensions of its greta array
  assign_dim <- function (value, greta_array) {
    array <- strip_unknown_class(greta_array$node$value())
    array[] <- value
    array
  }

  # SIMPLIFY = FALSE: the default would collapse equally-sized values into a
  # plain vector or matrix, discarding the dimensions just assigned
  values <- mapply(assign_dim, values, target_greta_arrays,
                   SIMPLIFY = FALSE)

  ans <- list(values = values,
              target_greta_arrays = target_greta_arrays)
  class(ans) <- 'fixed_values'
  ans
}
# Evaluate a greta array with some other greta arrays fixed at given values.
#
# S3 method for `with()` on a 'fixed_values' object (as created by
# `fixed_values()`). Builds a greta dag containing the fixed greta arrays and
# the target, defines its TensorFlow graph, then runs the target's tensor in
# the session, feeding in the fixed values.
#
# @param data a 'fixed_values' object.
# @param expr the greta array to evaluate.
# @param ... ignored; present to match the `with()` generic.
# @return the result of running the target's tensor with `sess$run()`.
with.fixed_values <- function (data, expr, ...) {

  # `expr` is an ordinary lazily-evaluated argument, so forcing it already
  # yields the greta array; a further eval() is unnecessary
  target <- expr
  stopifnot(is.greta_array(target))

  target_greta_arrays <- c(data$target_greta_arrays, list(target))

  # build a dag containing all of these, and define the TF graph
  dag <- dag_class$new(target_greta_arrays)
  check_tf_version('error')
  dag$define_tf()
  tf_env <- dag$tf_environment

  # name each fixed value after its greta array's node
  node_values <- data$values
  names(node_values) <- vapply(data$target_greta_arrays,
                               function (greta_array) greta_array$node$name,
                               FUN.VALUE = character(1))

  # build the feed dict with the TF environment as dict()'s calling frame,
  # as evaluating it in that environment did before. NOTE(review): reticulate's
  # dict() can resolve argument names to Python objects found in its calling
  # frame, presumably how node names become tensor keys here - confirm
  dict_fun <- get('dict', envir = tf_env, mode = 'function')
  feed_dict <- do.call(dict_fun, node_values, envir = tf_env)

  # evaluate the target's tensor in the session, with the values fed in
  sess <- get('sess', envir = tf_env)
  target_tensor <- get(target$node$name, envir = tf_env)
  tryCatch(
    sess$run(target_tensor, feed_dict = feed_dict),
    error = function (e) {
      stop('could not evaluate the target greta array with the fixed ',
           'values (check their names and dimensions): ',
           conditionMessage(e),
           call. = FALSE)
    }
  )
}
# Example: a logistic regression of whether an iris is setosa, then the
# probabilities p evaluated with some greta arrays fixed at chosen values.

iris$y <- iris$Species == 'setosa'

# use a standalone response vector: with `distribution(iris$y) <- ...`, R
# assigns the replacement function's return value back into the `iris`
# data frame column
y <- iris$y

a <- normal(0, 1)
b <- normal(0, 3, dim = 4)
mu <- a + iris[, 1:4] %*% b
p <- ilogit(mu)
distribution(y) <- bernoulli(p)

# p with the parameters a and b fixed (kept, rather than overwritten below)
p_eval_ab <- with(fixed_values(a = 0, b = 0.2), p)

# p with mu fixed directly; p is computed from mu alone (p <- ilogit(mu)),
# so the value given for a does not affect this result
p_eval <- with(fixed_values(mu = 0, a = 1), p)
summary(p_eval[, 1])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment