Skip to content

Instantly share code, notes, and snippets.

@goldingn
Created May 11, 2017 05:22
Show Gist options
  • Save goldingn/005f52ac488ff9796e82df4d9dcb4514 to your computer and use it in GitHub Desktop.
Save goldingn/005f52ac488ff9796e82df4d9dcb4514 to your computer and use it in GitHub Desktop.
# devtools::install_github('goldingn/greta')
devtools::load_all()
# Specify fixed values for some greta arrays, for use with `with()`.
#
# Each argument in `...` must be named after a greta array that exists in the
# calling environment; its value is what that greta array should be fixed to.
# Each value is recycled (via `array[] <- value`) into an array with the same
# dimensions as the greta array's current node value.
#
# Returns an object of class 'fixed_values': a list with elements
#   values              - a list of arrays, named after the arguments
#   target_greta_arrays - the corresponding greta arrays, in the same order
fixed_values <- function (...) {

  # capture the caller's environment up front (rather than passing a lazily
  # evaluated parent.frame() through lapply) so the greta arrays are always
  # looked up where the user called fixed_values()
  env <- parent.frame()

  # get the values and their names; every value must be named
  values <- list(...)
  value_names <- names(values)
  if (length(values) > 0 &&
      (is.null(value_names) || !all(nzchar(value_names)))) {
    stop("all arguments to fixed_values() must be named after greta arrays",
         call. = FALSE)
  }

  # get the corresponding greta arrays from the calling environment
  target_greta_arrays <- lapply(value_names, get, envir = env)

  # make sure that's what they are
  are_greta_arrays <- vapply(target_greta_arrays,
                             is.greta_array,
                             FUN.VALUE = FALSE)
  stopifnot(all(are_greta_arrays))

  # make sure the values have the correct dimensions, by recycling each value
  # into an array shaped like the greta array's current node value
  assign_dim <- function (value, greta_array) {
    array <- strip_unknown_class(greta_array$node$value())
    array[] <- value
    array
  }

  # Map() always returns a list; mapply()'s default SIMPLIFY = TRUE would
  # collapse same-sized values (e.g. several scalars) into a vector or matrix,
  # losing their dimensions and breaking the later do.call() on the values
  values <- Map(assign_dim, values, target_greta_arrays)

  ans <- list(values = values,
              target_greta_arrays = target_greta_arrays)
  class(ans) <- 'fixed_values'
  ans
}
# Evaluate a target greta array with some other greta arrays fixed at values.
#
# `data` is an object created by `fixed_values()`; `expr` should evaluate to
# the target greta array. A greta DAG containing the target and the fixed
# greta arrays is built and its TensorFlow graph defined. The fixed values are
# then fed in via `feed_dict`, keyed by node name, and the target's tensor is
# computed with `sess$run()`. Returns whatever `sess$run()` returns. An error
# from TensorFlow while running (e.g. a badly shaped fixed value) is re-raised
# with a more informative message.
with.fixed_values <- function (data, expr, ...) {

  # get the target (forcing `expr`; eval() returns a greta array unchanged)
  target <- eval(expr)
  stopifnot(is.greta_array(target))
  target_greta_arrays <- c(data$target_greta_arrays, list(target))

  # build a dag containing all of these, and define the TF graph
  dag <- dag_class$new(target_greta_arrays)
  check_tf_version("error")
  dag$define_tf()
  tf_env <- dag$tf_environment

  # name the fixed values after their nodes
  node_values <- data$values
  names(node_values) <- vapply(data$target_greta_arrays,
                               member,
                               FUN.VALUE = '',
                               'node$name')

  # build the feed dict inside the TF environment
  # NOTE(review): presumably evaluated there so that dict() can resolve each
  # node name to the corresponding tensor in that environment - confirm
  assign("with_list", node_values, envir = tf_env)
  eval(quote(with_dict <- do.call(dict, with_list)), envir = tf_env)

  # look up the session, the target's tensor and the dict directly, rather
  # than building and parsing a string of R code
  sess <- get("sess", envir = tf_env)
  target_tensor <- get(target$node$name, envir = tf_env)
  with_dict <- get("with_dict", envir = tf_env)

  # evaluate the target, turning a failure (e.g. a bad dict) into a clear error
  tryCatch(
    sess$run(target_tensor, feed_dict = with_dict),
    error = function (e) {
      stop("could not evaluate the target greta array with the fixed ",
           "values provided; check that each value has the correct ",
           "dimensions:\n",
           conditionMessage(e),
           call. = FALSE)
    }
  )
}
# example: a logistic regression on the iris data, modelling whether each
# flower is of the species setosa using the four measurement columns
iris$y <- iris$Species == "setosa"
a <- normal(0, 1)
b <- normal(0, 3, dim = 4)
mu <- a + iris[, 1:4] %*% b
p <- ilogit(mu)
distribution(iris$y) <- bernoulli(p)

# evaluate p with the intercept and coefficients fixed at given values
p_eval_ab <- with(fixed_values(a = 0, b = 0.2), p)
summary(p_eval_ab[, 1])

# evaluate p with the linear predictor mu fixed directly; since p depends on a
# only through mu, fixing a as well presumably has no further effect
p_eval <- with(fixed_values(mu = 0, a = 1), p)
summary(p_eval[, 1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment