Created
May 6, 2017 13:11
-
-
Save goldingn/cf9602cec8d140ef526f6e41814e256a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Coerce `x` to an array. A one-dimensional result (e.g. from a plain vector)
# is promoted to a single-column n x 1 array; arrays with two or more
# dimensions are returned as produced by as.array().
as_column_array <- function(x) {
  arr <- as.array(x)
  n_dims <- length(dim(arr))
  if (n_dims != 1) {
    return(arr)
  }
  # append a trailing dimension of extent 1 to make it a column
  dim(arr) <- c(dim(arr), 1)
  arr
}
# Return `object` with its attribute called `name` set to `attribute`.
# Argument order (object, value, name) suits use with mapply()/Map(), where
# `name` is supplied through MoreArgs.
add_attribute <- function(object, attribute, name) {
  `attr<-`(object, which = name, value = attribute)
}
# Build a set of placeholder values for greta arrays, for use with `with()`.
#
# Each argument must be named after a greta array found in the calling
# environment, and its value is the placeholder for that array, e.g.
# `fixed_value(a = 2, b = 1)`. Values are coerced with as_column_array(), so a
# plain vector of length n is treated as an n x 1 array. Each value must then
# have the same dimensions as the greta array it stands in for.
#
# Returns, invisibly, a list of the matched greta arrays (named as in the
# call). Each array carries its placeholder in a "fixed_value" attribute, and
# the list has class "fixed_value_list", so `with()` dispatches to
# with.fixed_value_list().
# Errors if an argument is unnamed, names no visible greta array, or has the
# wrong dimensions.
fixed_value <- function(...) {
  # capture the caller's environment up front, rather than passing a lazily
  # evaluated parent.frame() call into a greta internal
  caller_env <- parent.frame()

  values <- list(...)
  value_names <- names(values)

  # an unnamed value cannot be matched to a greta array
  if (length(values) > 0 && (is.null(value_names) || any(value_names == ""))) {
    stop("all arguments to fixed_value must be named after greta arrays",
         call. = FALSE)
  }

  # greta arrays visible from the caller, named by their variable names
  candidates <- greta:::all_greta_arrays(env = caller_env)
  candidate_names <- names(candidates)

  unknown <- !value_names %in% candidate_names
  if (any(unknown)) {
    # `collapse` belongs to paste(), so the names are comma-separated
    stop("could not find greta arrays: ",
         paste(value_names[unknown], collapse = ", "),
         call. = FALSE)
  }

  # greta arrays being fixed, in the same order as `values`
  targets <- candidates[match(value_names, candidate_names)]

  values <- lapply(values, as_column_array)

  value_dims <- lapply(values, dim)
  target_dims <- lapply(targets, dim)
  # compare dimension values rather than using identical(), so integer and
  # double dimension vectors (e.g. 3L vs 3) still match
  same_dims <- function(x, y) length(x) == length(y) && all(x == y)
  good_dims <- vapply(seq_along(values),
                      function(i) same_dims(value_dims[[i]], target_dims[[i]]),
                      logical(1))

  if (!all(good_dims)) {
    # format each array's dims on its own, e.g. "3 x 1"
    format_dims <- function(d) paste(d, collapse = " x ")
    problems <- sprintf("\t%s (need %s, got %s)",
                        value_names[!good_dims],
                        vapply(target_dims[!good_dims], format_dims, character(1)),
                        vapply(value_dims[!good_dims], format_dims, character(1)))
    # one line per offending greta array
    stop("the proposed values for the following greta arrays had the wrong ",
         "dimensions:\n",
         paste(problems, collapse = "\n"),
         call. = FALSE)
  }

  ans <- mapply(add_attribute,
                targets,
                values,
                MoreArgs = list(name = "fixed_value"),
                SIMPLIFY = FALSE)
  class(ans) <- "fixed_value_list"
  invisible(ans)
}
# S3 `with()` method for the "fixed_value_list" objects made by fixed_value().
#
# `expr` must evaluate to a greta array; otherwise an error is raised. The aim
# is to compute that array's value with the greta arrays in `data` held at
# their placeholder values. The result is copied into an array taken from the
# target node's value() and returned.
#
# PROTOTYPE: per the TODO notes below, dag_class does not yet have an
# add_fixed_value_list() method, and node define_tf behaviour still needs
# changing, before this can work.
with.fixed_value_list <- function(data, expr, ...) {
  # forcing the lazily evaluated argument evaluates the user's expression in
  # their own environment; eval() on the result could re-evaluate a returned
  # language object in the wrong frame
  result <- force(expr)

  # make sure it's a greta array
  if (!greta:::is.greta_array(result)) {
    stop("the second argument to with must be a greta array, instead got a ",
         class(result)[1],
         call. = FALSE)
  }

  # dag_class is not exported by greta, so like the other greta internals
  # used here it must be reached with `:::`
  dag <- greta:::dag_class$new(target_greta_arrays = list(result))

  # add the fixed value list to the tf environment
  # TODO(greta): dag_class needs this method, with something like:
  #   dag$tf_environment$fixed_value_list <- data
  # and node$define_tf needs changing so that, if the node is in the fixed
  # value list in the tf environment, it doesn't make its children define
  # themselves and just writes itself into the tensorflow graph as a fixed
  # value
  dag$add_fixed_value_list(data)

  # define the tf graph, then get the value of the target node
  dag$define_tf()
  value <- dag$trace_values()

  # copy the values into an array taken from the target node
  target_array <- result$node$value()
  target_array[] <- value
  target_array
}
# Example usage ----
library(greta)

x <- seq(0, 1, length.out = 100)
a <- normal(0, 1)
b <- exponential(3)
c <- log(a * x^b)
d <- ilogit(2 * c)

# Intended usage, once the greta changes listed below are in place:
# with(fixed_value(a = 2, b = 1), c)
# with(fixed_value(c = -2), d)

# This needs three small changes to greta before it can be implemented:
# 1. Add an 'add_fixed_value_list' method to dag_class that does this:
#      function (fixed_value_list)
#        self$tf_environment$fixed_value_list <- fixed_value_list
# 2. Change the dag_class 'build_dag' method to ignore the children of nodes
#    that are in the fixed value list.
# 3. Change the node_class 'define_tf' method to look in the tf environment
#    for a fixed value list and, if there is one, to look for its own name.
#    If its name is there, it should not make its children define
#    themselves. Instead it should write itself into the tensorflow graph
#    with the fixed value stated in the list (converting from an array to a
#    tensor).
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment