Skip to content

Instantly share code, notes, and snippets.

@goldingn
Created May 6, 2017 13:11
Show Gist options
  • Save goldingn/cf9602cec8d140ef526f6e41814e256a to your computer and use it in GitHub Desktop.
Save goldingn/cf9602cec8d140ef526f6e41814e256a to your computer and use it in GitHub Desktop.
as_column_array <- function (x) {
  # Coerce `x` to an array. A one-dimensional result (e.g. from a plain
  # vector or scalar) is promoted to a single-column array with dimensions
  # n x 1; arrays that already have two or more dimensions are returned as-is.
  arr <- as.array(x)
  n_dims <- length(dim(arr))
  if (n_dims != 1) {
    return(arr)
  }
  dim(arr) <- c(dim(arr), 1)
  arr
}
add_attribute <- function (object, attribute, name) {
  # Return a copy of `object` carrying `attribute` under the attribute name
  # `name`. The argument order (value before name) suits mapply() with a
  # fixed `name` supplied through MoreArgs.
  tagged <- object
  attr(tagged, which = name) <- attribute
  tagged
}
fixed_value <- function (...) {
  # Take a series of named arguments, naming greta arrays (looked up in the
  # environment that called fixed_value) and providing placeholder values for
  # them, and return a list ready for use in a with-like environment.
  #
  # Each proposed value is coerced with as_column_array() and must have the
  # same dimensions as the greta array it names. The result is a list of the
  # target greta arrays, each carrying its proposed value in a 'fixed_value'
  # attribute, with class 'fixed_value_list'. It is returned invisibly.

  # named arrays
  dots <- list(...)
  dots_names <- names(dots)

  # every value must be named after a greta array; an unnamed argument would
  # otherwise slip past the availability check and fail later with an
  # unhelpful length-mismatch error
  if (length(dots) > 0 && (is.null(dots_names) || any(dots_names == ''))) {
    stop ('all arguments to fixed_value must be named after greta arrays')
  }

  # available arrays (not called `all`, to avoid masking base::all below)
  all_arrays <- greta:::all_greta_arrays(env = parent.frame())
  all_names <- names(all_arrays)

  # check availability
  available <- dots_names %in% all_names
  if (!all(available)) {
    stop ('could not find greta arrays: ',
          paste(dots_names[!available], collapse = ', '))
  }

  # get target array list
  targets <- all_arrays[match(dots_names, all_names)]

  # coerce everything in dots to an array
  dots <- lapply(dots, as_column_array)

  # check dimensions; compare as integers so that equivalent integer and
  # double dimension vectors are not reported as mismatches
  dots_dims <- lapply(dots, dim)
  targets_dims <- lapply(targets, dim)
  good_dims <- vapply(seq_along(dots_dims),
                      function (i) identical(as.integer(dots_dims[[i]]),
                                             as.integer(targets_dims[[i]])),
                      logical(1))

  if (!all(good_dims)) {
    # format each array's dimensions separately, e.g. "100 x 1"
    format_dims <- function (d) paste(d, collapse = ' x ')
    stop ('the proposed values for the following greta arrays had the wrong dimensions:\n',
          paste(sprintf('\t%s (need %s, got %s)',
                        dots_names[!good_dims],
                        vapply(targets_dims[!good_dims], format_dims,
                               character(1)),
                        vapply(dots_dims[!good_dims], format_dims,
                               character(1))),
                collapse = '\n'))
  }

  ans <- mapply(add_attribute,
                targets,
                dots,
                MoreArgs = list(name = 'fixed_value'),
                SIMPLIFY = FALSE)
  class(ans) <- 'fixed_value_list'
  invisible(ans)
}
with.fixed_value_list <- function (data, expr, ...) {
  # S3 method for with(): evaluate the greta array `expr` while the greta
  # arrays in the fixed_value_list `data` (built by fixed_value()) are held
  # at their proposed values. Returns the traced value of `expr`, stored in an
  # array shaped like the value of `expr`'s node. `...` is ignored.
  #
  # NOTE(review): this relies on dag_class features that greta does not
  # provide yet (see the add_fixed_value_list note below).

  # evaluate result; `expr` is a promise forced in the caller's environment,
  # and eval() of the resulting (non-language) value returns it unchanged
  result <- eval(expr)

  # make sure it's a greta array
  if (!greta:::is.greta_array(result)) {
    stop ('the second argument to with must be a greta array, instead got a ',
          class(result)[1])
  }

  # create a dag with the target array; dag_class is internal to greta, so
  # reference it through ::: like the other greta internals used in this file
  dag <- greta:::dag_class$new(target_greta_arrays = list(result))

  # add the fixed value list to the tf environment
  dag$add_fixed_value_list(data)
  # (need to add this method, with something like:
  # dag$tf_environment$fixed_value_list <- data)
  # (need to change node$define_tf, so that if that node is in the fixed value
  # list in the tf environment, it doesn't make its children define themselves,
  # and it just writes itself in the tensorflow graph as a fixed value)

  # define the tf graph
  dag$define_tf()

  # return the value of the target node
  value <- dag$trace_values()

  # assign it to an array with the target node's shape
  target_array <- result$node$value()
  target_array[] <- value
  target_array
}
library(greta)

# example model: a and b are variables, c and d are operations on them
x <- seq(0, 1, length.out = 100)
a <- normal(0, 1)
b <- exponential(3)
# NOTE: `c` masks base::c as a variable name (function calls to c() still
# resolve to base::c); it is kept because the examples below refer to it
c <- log(a * x ^ b)
d <- ilogit(2 * c)

# intended usage, once greta supports fixed value lists:
# with(fixed_value(a = 2, b = 1), c)
# with(fixed_value(c = -2), d)
# This needs three small changes to greta before it can be implemented:
# 1. Add a method 'add_fixed_value_list' to dag_class which does this:
#      function (fixed_value_list)
#        self$tf_environment$fixed_value_list <- fixed_value_list
# 2. Change the dag_class 'build_dag' method to ignore the children of nodes
#    that are in the fixed value list.
# 3. Change the node_class 'define_tf' method to look in the tf environment for
#    a fixed value list and, if there is one, to look for its own name in it.
#    If its name is there, it should not make its children define themselves;
#    instead it should just write itself into the TensorFlow graph with the
#    fixed value stated in the list (converting from an array to a tensor).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment