Skip to content

Instantly share code, notes, and snippets.

@goldingn
Last active September 13, 2018 03:16
Show Gist options
  • Save goldingn/d5c9ac14546d225bb9066225136f1bb5 to your computer and use it in GitHub Desktop.
Save goldingn/d5c9ac14546d225bb9066225136f1bb5 to your computer and use it in GitHub Desktop.
A prototype interface for efficiently defining Gaussian processes on a computational grid with greta, using the Fast Fourier Transform
# FFT approximation to a GP on a regular grid (defined by a raster)
# information representing a grid of points, defined by the x and y coordinates
# fft_grid <- function (x_coord, y_coord) {
#
# # pre calculate grid info
# dx <- x_coord[2] - x_coord[1]
# dy <- y_coord[2] - y_coord[1]
# m <- length(x_coord)
# n <- length(y_coord)
#
# # expanded grid
# M <- .ceiling2(2 * m)
# N <- .ceiling2(2 * n)
#
# # get the distances from each cell to the centre
# x_sq <- .centred_squared_distance(M, dx)
# y_sq <- .centred_squared_distance(N, dy)
# dist_squared <- outer(x_sq, y_sq, FUN = "+")
# dist <- sqrt(dist_squared)
#
# # represent the centre location in the spectral domain
# adjust <- matrix(0, nrow = M, ncol = N)
# adjust[M/2, N/2] <- 1
# fft_adjust <- (fft(adjust) * M * N)
#
# extract_coords <- expand.grid(seq_len(m), seq_len(n))
#
# # return this info as a list
# list(dim_all = c(M, N),
# dim_sub = c(m, n),
# dist = dist,
# dist_squared = dist_squared,
# fft_adjust = fft_adjust,
# extract_coords = extract_coords)
#
# }
# round n up to the nearest power of two (e.g. 5 -> 8, 8 -> 8)
.ceiling2 <- function(n) {
  exponent <- ceiling(log2(n))
  2 ^ exponent
}
# squared distance of each of n cells from the centre of the vector, where cell
# i sits at position i * diff (diff being the cell width) and the centre is at
# diff * n / 2
.centred_squared_distance <- function(n, diff) {
  positions <- seq_len(n) * diff
  centre <- diff * n / 2
  (positions - centre) ^ 2
}
# given a raster, return the information for an FFT grid, as a list with
# elements:
#   dim_all        - c(M, N): dimensions of the expanded grid, where M and N
#                    are twice the number of raster columns and rows
#                    respectively, rounded up to a power of two
#   dim_sub        - c(m, n): number of columns and rows of the trimmed raster
#   dist           - M x N matrix of distances from the centre of the expanded
#                    grid; the first dimension runs along x (raster columns)
#                    and the second along y (raster rows)
#   dist_squared   - the square of dist
#   fft_adjust     - fft of an M x N matrix that is zero except for a 1 at
#                    [M/2, N/2], multiplied by M * N
#   extract_coords - two-column matrix of (x index, y index) positions, in the
#                    expanded grid, of the non-missing cells of the raster
#   raster         - the trimmed raster
fft_grid <- function(raster) {

  # remove NA space around the edges
  raster <- trim(raster)

  # cell width in the x and y directions (named to avoid masking res())
  cell_size <- res(raster)
  m <- ncol(raster)
  n <- nrow(raster)

  # expanded grid
  M <- .ceiling2(2 * m)
  N <- .ceiling2(2 * n)

  # get the distances from each cell to the centre; x varies along the first
  # dimension of the resulting matrices and y along the second
  x_sq <- .centred_squared_distance(M, cell_size[1])
  y_sq <- .centred_squared_distance(N, cell_size[2])
  dist_squared <- outer(x_sq, y_sq, FUN = "+")
  dist <- sqrt(dist_squared)

  # represent the centre location in the spectral domain
  adjust <- matrix(0, nrow = M, ncol = N)
  adjust[M / 2, N / 2] <- 1
  fft_adjust <- fft(adjust) * M * N

  # find the locations of the non-missing cells in the original raster
  not_missing_idx <- which(!is.na(getValues(raster)))

  # find their locations in the extended grid (using the top left corner).
  # The first dimension of the grid is x, so it must be indexed by the raster
  # *column* and the second dimension by the raster *row*; indexing as
  # (row, col) would apply the wrong cell size to each axis and run out of
  # bounds for rasters with many more rows than columns
  extract_coords <- cbind(colFromCell(raster, not_missing_idx),
                          rowFromCell(raster, not_missing_idx))

  # return this info as a list
  list(dim_all = c(M, N),
       dim_sub = c(m, n),
       dist = dist,
       dist_squared = dist_squared,
       fft_adjust = fft_adjust,
       extract_coords = extract_coords,
       raster = raster)

}
# evaluate the kernel on the grid's matrix of distances to get the pointwise
# covariances, then use them to colour the variables in v_all
.fft_gp_colour <- function(grid_info, v_all, kernel) {
  distances <- grid_info$dist
  pointwise_covar <- kernel(distances)
  .spectral_colour(pointwise_covar, v_all, grid_info)
}
# tensorflow code for the "spectral_colour" operation: given tensors of
# pointwise covariances (covar) and colourless variables (v), both with a
# leading batch dimension, return the coloured variables with the same
# dimensions as covar
.tf_spectral_colour <- function(covar, v, grid_info) {

  tf <- tensorflow::tf

  # dimensions of the grid, without the leading (batch) dimension
  grid_dim <- dim(covar)[-1]

  # cast the inputs to complex64 before the FFT, and flatten them
  covar_complex <- tf$cast(covar, tf$complex64)
  covar_flat <- greta:::tf_flatten(covar_complex)
  noise_complex <- tf$cast(v, tf$complex64)
  noise_flat <- greta:::tf_flatten(noise_complex)

  # the spectral adjustment, flattened row-wise (hence the transpose), with a
  # batch dimension added
  adjust_values <- as.vector(t(grid_info$fft_adjust))
  adjust <- tf$cast(adjust_values, tf$complex64)
  adjust <- tf$expand_dims(adjust, 0L)
  adjust <- greta:::expand_to_batch(adjust, covar_complex)

  # cast covariance and colourless normals into the spectral domain
  # NOTE(review): tf$fft is applied here to the flattened tensors, whereas
  # grid_info$fft_adjust comes from fft() applied to a matrix in fft_grid() -
  # confirm that these two transforms are consistent with one another
  covar_spectrum <- tf$fft(covar_flat) / adjust
  noise_spectrum <- tf$fft(noise_flat)

  # colour the function in the spectral domain, then bring it back
  coloured_spectrum <- tf$sqrt(covar_spectrum) * noise_spectrum
  coloured <- tf$real(tf$ifft(coloured_spectrum))

  # note that tf$ifft(x) is equivalent to
  #   stats::fft(x, inverse = TRUE) / length(x)
  # so need to re-multiply to get that straight
  n_elem <- prod(grid_info$dim_all)
  coloured <- coloured * n_elem / sqrt(n_elem)

  # cast back to greta's float type and restore the grid dimensions
  coloured <- tf$cast(coloured, options()$greta_tf_float)
  tf$reshape(coloured, c(list(-1L), grid_dim))

}
# colour standard normal random variables (v) using pointwise covariances
# (covar) of matching dimensions, returning the greta array of coloured random
# variables. Errors if v is not a greta array, or if the dimensions differ
.spectral_colour <- function(covar, v, grid_info) {

  # v has to be a greta array
  if (!inherits(v, "greta_array")) {
    stop("'v' must be a greta array", call. = FALSE)
  }

  # covar and v have to agree on dimensions
  dims_match <- identical(dim(covar), dim(v))
  if (!dims_match) {
    describe_dim <- function(x) paste0(dim(x), collapse = "x")
    msg <- paste0("'covar' and 'v' must have the same dimensions, ",
                  "but 'covar' had dimensions ",
                  describe_dim(covar),
                  " and 'v' had dimensions ",
                  describe_dim(v))
    stop(msg, call. = FALSE)
  }

  # build the operation greta array for the result
  make_op <- greta::.internals$nodes$constructors$op
  result <- make_op("spectral_colour",
                    covar, v,
                    operation_args = list(grid_info = grid_info),
                    tf_operation = ".tf_spectral_colour")

  # hack: store the tensorflow function in the operation node's enclosing
  # environment, so that it can be found when running in parallel
  result_node <- greta:::get_node(result)
  result_node$.__enclos_env__$.tf_spectral_colour <- .tf_spectral_colour

  result

}
# pull out the elements of z at the positions stored in
# grid_info$extract_coords
.fft_gp_extract <- function(grid_info, z) {
  cell_index <- grid_info$extract_coords
  z[cell_index]
}
# colour the variables v over the whole grid using the kernel, then return
# only the values at the positions in grid_info$extract_coords
fft_gp <- function(v, kernel, grid_info) {
  coloured_all <- .fft_gp_colour(grid_info, v, kernel)
  .fft_gp_extract(grid_info, coloured_all)
}
# # ~~~~~~~
# # simulate GPs on a raster, using the FFT approximation
# library (greta)
# library (raster)
#
# # get a raster on which to simulate a Gaussian process
# raster <- raster(system.file("external/test.grd", package = "raster"))
#
# # covariance hyperparameters and function *of distance*
# # (scaling lengthscale parameter prior by the resolution of the raster)
# l <- lognormal(0, 1) * mean(res(raster))
# sigma <- lognormal(0, 1)
# kernel <- function (dist) {
# sigma * exp(-dist / (l ^ 2))
# }
#
# # set up computational grid from the raster
# grid_info <- fft_grid(raster)
# # create a matrix of standard univariate normals of the correct dimension (the
# # full torus)
# v <- normal(0, 1, dim = grid_info$dim_all)
# # and evaluate the gp at the non-missing cells of the raster
# f <- fft_gp(v, kernel, grid_info)
#
# # sample
# m <- model(f)
# draws <- mcmc(m, warmup = 300, n_samples = 100)
#
# # plot a random posterior sample of the GP by sticking each set of values back
# # in a raster
# raster_plot <- raster
# idx <- which(!is.na(getValues(raster)))
# i <- sample.int(nrow(draws[[1]]), 1)
# raster_plot[idx] <- draws[[1]][i, ]
# plot(raster_plot)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment