Last active
September 13, 2018 03:16
-
-
Save goldingn/d5c9ac14546d225bb9066225136f1bb5 to your computer and use it in GitHub Desktop.
a prototype interface for efficiently defining Gaussian processes on a computational grid with greta, using the Fast Fourier Transform
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# FFT approximation to a GP on a regular grid (defined by a raster)
# (earlier version, kept for reference) build the information representing a grid of points defined by its x and y coordinates
# fft_grid <- function (x_coord, y_coord) { | |
# | |
# # pre calculate grid info | |
# dx <- x_coord[2] - x_coord[1] | |
# dy <- y_coord[2] - y_coord[1] | |
# m <- length(x_coord) | |
# n <- length(y_coord) | |
# | |
# # expanded grid | |
# M <- .ceiling2(2 * m) | |
# N <- .ceiling2(2 * n) | |
# | |
# # get the distances from each cell to the centre | |
# x_sq <- .centred_squared_distance(M, dx) | |
# y_sq <- .centred_squared_distance(N, dy) | |
# dist_squared <- outer(x_sq, y_sq, FUN = "+") | |
# dist <- sqrt(dist_squared) | |
# | |
# # represent the centre location in the spectral domain | |
# adjust <- matrix(0, nrow = M, ncol = N) | |
# adjust[M/2, N/2] <- 1 | |
# fft_adjust <- (fft(adjust) * M * N) | |
# | |
# extract_coords <- expand.grid(seq_len(m), seq_len(n)) | |
# | |
# # return this info as a list | |
# list(dim_all = c(M, N), | |
# dim_sub = c(m, n), | |
# dist = dist, | |
# dist_squared = dist_squared, | |
# fft_adjust = fft_adjust, | |
# extract_coords = extract_coords) | |
# | |
# } | |
# For positive n, return the smallest power of two, 2^k with k an integer,
# that is at least n. Vectorised over n. Used to size the expanded FFT grid.
.ceiling2 <- function (n) {
  exponent <- ceiling(log2(n))
  2 ^ exponent
}
# Squared 1-D distance of each of `n` equally spaced cells of width `diff`
# from the centre of the run. Cell i is placed at i * diff and the centre at
# diff * n / 2, so for even n the cell at index n / 2 has distance zero.
# Returns a numeric vector of length n (empty when n is 0).
.centred_squared_distance <- function (n, diff) {
  positions <- seq_len(n) * diff
  centre <- diff * n / 2
  offsets <- positions - centre
  offsets ^ 2
}
# Given a raster, return the information needed to evaluate a GP on it with
# the FFT approach. The raster is trimmed of all-NA margins and embedded, from
# its top-left corner, in an expanded grid. Each side of the expanded grid is
# a power of two at least twice the corresponding raster side. Dimension 1 of
# every grid-shaped output follows raster columns (x); dimension 2 follows
# raster rows (y).
#
# Returns a list with:
#   dim_all        - c(M, N), dimensions of the expanded grid
#   dim_sub        - c(ncol, nrow) of the trimmed raster
#   dist           - M x N matrix of distances from each cell to the centre
#                    cell [M / 2, N / 2]
#   dist_squared   - the same distances, squared
#   fft_adjust     - stats::fft() of an M x N indicator of the centre cell,
#                    multiplied by M * N
#   extract_coords - two-column matrix (x index, y index) locating each
#                    non-missing raster cell in the expanded grid, in cell order
#   raster         - the trimmed raster
fft_grid <- function (raster) {

  # remove NA space around the edges
  raster <- trim(raster)
  cell_res <- res(raster)
  m <- ncol(raster)
  n <- nrow(raster)

  # expanded grid: dimension 1 follows x (columns), dimension 2 follows y (rows)
  M <- .ceiling2(2 * m)
  N <- .ceiling2(2 * n)

  # get the distances from each cell to the centre
  x_sq <- .centred_squared_distance(M, cell_res[1])
  y_sq <- .centred_squared_distance(N, cell_res[2])
  dist_squared <- outer(x_sq, y_sq, FUN = "+")
  dist <- sqrt(dist_squared)

  # represent the centre location in the spectral domain
  adjust <- matrix(0, nrow = M, ncol = N)
  adjust[M / 2, N / 2] <- 1
  fft_adjust <- fft(adjust) * M * N

  # find the locations of the non-missing cells in the original raster
  not_missing_idx <- which(!is.na(getValues(raster)))

  # find their locations in the extended grid (using the top left corner).
  # The grid's first dimension follows raster columns and its second follows
  # raster rows, so the column index must come first. Putting the row index
  # first transposes the extraction and indexes out of bounds whenever
  # nrow(raster) > M.
  extract_coords <- cbind(colFromCell(raster, not_missing_idx),
                          rowFromCell(raster, not_missing_idx))

  # return this info as a list
  list(dim_all = c(M, N),
       dim_sub = c(m, n),
       dist = dist,
       dist_squared = dist_squared,
       fft_adjust = fft_adjust,
       extract_coords = extract_coords,
       raster = raster)
}
# Colour the standard normal greta array `v_all` on the full expanded grid.
# `kernel` is evaluated on the grid's centre-distance matrix
# (grid_info$dist), and the resulting pointwise covariances are handed to
# .spectral_colour() together with `v_all` and `grid_info`.
.fft_gp_colour <- function (grid_info, v_all, kernel) {
  distances <- grid_info$dist
  # evaluated here, before .spectral_colour() runs its checks
  pointwise_covar <- kernel(distances)
  .spectral_colour(covar = pointwise_covar, v = v_all, grid_info = grid_info)
}
# TensorFlow implementation of the "spectral_colour" greta operation. The
# operation is registered by name in .spectral_colour(). It colours the
# standard normal draws `v` with the covariance values `covar` in the
# spectral domain and returns the coloured field as a tensor.
#
# covar, v  - tensors with matching shapes. The leading dimension is dropped
#             from `original_dim` and restored as -1 when reshaping, so it is
#             presumably greta's batch dimension (TODO confirm).
# grid_info - list produced by fft_grid(); only `fft_adjust` and `dim_all`
#             are read here.
#
# NOTE(review): tf$fft computes a 1-D transform over the innermost dimension
# of the flattened tensors. By contrast, grid_info$fft_adjust was computed
# with stats::fft() on a matrix, which is a 2-D (multivariate) transform.
# Dividing one by the other mixes the two. Confirm this is intended; a 2-D
# transform on the unflattened tensors may have been meant.
# NOTE(review): tf$fft, tf$ifft and tf$real are TensorFlow 1.x endpoints.
.tf_spectral_colour <- function (covar, v, grid_info) {
  tf <- tensorflow::tf
  # shape of one draw, without the leading (presumably batch) dimension
  original_dim <- dim(covar)[-1]
  # cast both inputs to complex64 (the FFT ops require a complex input)
  covar <- tf$cast(covar, tf$complex64)
  covar_vec <- greta:::tf_flatten(covar)
  v <- tf$cast(v, tf$complex64)
  v_vec <- greta:::tf_flatten(v)
  # cast covariance and colourless normals into spectral domain
  # (t() before as.vector() flattens fft_adjust in row-major order, presumably
  # to match the element order produced by tf_flatten - TODO confirm)
  fft_adjust_vec <- tf$cast(as.vector(t(grid_info$fft_adjust)),
                            tf$complex64)
  # add batch dimension
  fft_adjust_vec <- greta:::expand_to_batch(tf$expand_dims(fft_adjust_vec, 0L), covar)
  # dividing by the transform of the centre-cell indicator is presumably meant
  # to undo the centring of the distance matrix (a circular shift) - see the
  # 1-D vs 2-D note above
  covar_fft <- tf$fft(covar_vec) / fft_adjust_vec
  v_fft <- tf$fft(v_vec)
  # get coloured function in spectral domain: scale each frequency of the
  # white noise by the square root of the covariance spectrum
  f_fft <- tf$sqrt(covar_fft) * v_fft
  # bring coloured function back from the spectral domain, keeping only the
  # real part
  f <- tf$real(tf$ifft(f_fft))
  # note that tf$ifft(x) is equivalent to
  # stats::fft(x, inverse = TRUE) / length(x)
  # so need to re-multiply to get that straight; the net factor applied below
  # is n_elem / sqrt(n_elem) = sqrt(n_elem)
  n_elem <- prod(grid_info$dim_all)
  f <- f * n_elem / sqrt(n_elem)
  # cast the float type back (greta's float type is read from the
  # `greta_tf_float` option), reshape and return
  f <- tf$cast(f, options()$greta_tf_float)
  f <- tf$reshape(f, c(list(-1L), original_dim))
  f
}
# Build a greta operation that colours the standard normal greta array `v`
# using the pointwise covariance values `covar`. The two must have identical
# dimensions. The returned greta array is computed by .tf_spectral_colour(),
# which receives `grid_info` as an operation argument.
# Errors if `v` is not a greta array or the dimensions differ.
.spectral_colour <- function (covar, v, grid_info) {

  # v has to be a greta array before the dimensions are compared
  if (!inherits(v, "greta_array")) {
    stop("'v' must be a greta array", call. = FALSE)
  }

  covar_dim <- dim(covar)
  v_dim <- dim(v)
  dims_match <- identical(covar_dim, v_dim)
  if (!dims_match) {
    stop("'covar' and 'v' must have the same dimensions, ",
         "but 'covar' had dimensions ",
         paste0(covar_dim, collapse = "x"),
         " and 'v' had dimensions ",
         paste0(v_dim, collapse = "x"),
         call. = FALSE)
  }

  # greta's internal constructor for operation nodes
  op <- greta::.internals$nodes$constructors$op
  coloured <- op("spectral_colour",
                 covar, v,
                 operation_args = list(grid_info = grid_info),
                 tf_operation = ".tf_spectral_colour")

  # hack: copy the TF function into the operation node's environment, so this
  # can be run in parallel
  coloured_node <- greta:::get_node(coloured)
  coloured_node$.__enclos_env__$.tf_spectral_colour <- .tf_spectral_colour

  coloured
}
# Pick out the elements of the full-grid array `z` at the positions listed in
# the two-column index matrix grid_info$extract_coords. The result is a
# vector in the row order of that matrix.
.fft_gp_extract <- function (grid_info, z) {
  index_matrix <- grid_info$extract_coords
  z[index_matrix]
}
# Evaluate an FFT-approximated GP at the non-missing cells of the grid's
# raster. `v` holds standard normals over the full expanded grid; `kernel`
# maps distances to covariances; `grid_info` comes from fft_grid(). The field
# is coloured on the whole grid first, then the wanted cells are extracted.
fft_gp <- function (v, kernel, grid_info) {
  .fft_gp_extract(grid_info,
                  z = .fft_gp_colour(grid_info, v, kernel))
}
# # ~~~~~~~ | |
# # simulate GPs on a raster, using the FFT approximation | |
# library (greta) | |
# library (raster) | |
# | |
# # get a raster on which to simulate a Gaussian process | |
# raster <- raster(system.file("external/test.grd", package = "raster")) | |
# | |
# # covariance hyperparameters and function *of distance* | |
# # (scaling lengthscale parameter prior by the resolution of the raster) | |
# l <- lognormal(0, 1) * mean(res(raster)) | |
# sigma <- lognormal(0, 1) | |
# kernel <- function (dist) { | |
# sigma * exp(-dist / (l ^ 2)) | |
# } | |
# | |
# # set up computational grid from the raster | |
# grid_info <- fft_grid(raster) | |
# # create a matrix of standard univariate normals of the correct dimension (the | |
# # full torus) | |
# v <- normal(0, 1, dim = grid_info$dim_all) | |
# # and evaluate the gp at the non-missing cells of the raster | |
# f <- fft_gp(v, kernel, grid_info) | |
# | |
# # sample | |
# m <- model(f) | |
# draws <- mcmc(m, warmup = 300, n_samples = 100) | |
# | |
# # plot a random posterior sample of the GP by sticking each set of values back | |
# # in a raster | |
# raster_plot <- raster | |
# idx <- which(!is.na(getValues(raster))) | |
# i <- sample.int(nrow(draws[[1]]), 1) | |
# raster_plot[idx] <- draws[[1]][i, ] | |
# plot(raster_plot) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment