Created
October 19, 2012 13:01
-
-
Save glesica/3918128 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# value_iteration.r
# George Lesica
# CSCI 555 - FA 2012
# Homework 5
# Solution to problem 3
# Move probabilities: the agent travels in the chosen direction with
# probability INTENDED and veers to the left or right of it with
# probability LEFT or RIGHT respectively.
INTENDED <- 0.8
LEFT <- 0.1
RIGHT <- 0.1

# Action labels. The order is significant: maxaction() builds its vector of
# per-action values in this same order and uses the position of the best
# entry to look the label up here.
DIRECTIONS <- c("NORTH", "SOUTH", "WEST", "EAST")
maxaction <- function(V, r, c) {
  # Finds the best action available from cell V[r, c].
  #
  # Args:
  #   V: Matrix of current cell values; NA marks a dead (unusable) cell.
  #   r: Row index of the cell being evaluated.
  #   c: Column index of the cell being evaluated.
  #
  # Returns:
  #   A length-2 character vector: the best expected value (coerced to a
  #   string because it is combined with the label) followed by the name of
  #   the action, taken from DIRECTIONS, that achieves it. Ties go to the
  #   action that comes first in DIRECTIONS.
  here <- V[r, c]
  n_rows <- dim(V)[1]
  n_cols <- dim(V)[2]

  # A move off the edge of the grid leaves the agent where it is, so the
  # value seen in that direction is the current cell's own value.
  north <- if (r > 1) V[r - 1, c] else here
  south <- if (r < n_rows) V[r + 1, c] else here
  west <- if (c > 1) V[r, c - 1] else here
  east <- if (c < n_cols) V[r, c + 1] else here

  # Bumping into a dead cell behaves the same way as bumping into a wall.
  neighbours <- c(north, south, west, east)
  neighbours[is.na(neighbours)] <- here
  north <- neighbours[1]
  south <- neighbours[2]
  west <- neighbours[3]
  east <- neighbours[4]

  # Expected value of each action, listed in DIRECTIONS order. Each line is
  # intended move + slip to the left of it + slip to the right of it.
  expected <- c(
    INTENDED * north + LEFT * west + RIGHT * east,
    INTENDED * south + LEFT * east + RIGHT * west,
    INTENDED * west + LEFT * south + RIGHT * north,
    INTENDED * east + LEFT * north + RIGHT * south
  )

  best <- which.max(expected)
  c(max(expected), DIRECTIONS[best])
}
viter <- function(V, R, G, gamma, get.policy = FALSE) {
  # Performs a single value-iteration sweep over the given world.
  #
  # Args:
  #   V: Matrix of current cell values; NA marks a dead cell.
  #   R: Matrix of intrinsic values for cells, same shape as V.
  #   G: Logical matrix, TRUE for terminal states, same shape as V.
  #   gamma: Discount factor applied to the best action value.
  #   get.policy: If TRUE, return the action matrix instead of the values.
  #
  # Returns:
  #   When get.policy is FALSE, a matrix of updated cell values (dead cells
  #   stay NA, terminal cells keep their value from V). When get.policy is
  #   TRUE, a matrix of action names chosen from the values in V, with NA
  #   for dead and terminal cells.
  n_rows <- nrow(V)
  n_cols <- ncol(V)
  newV <- matrix(nrow = n_rows, ncol = n_cols)
  policy <- matrix(nrow = n_rows, ncol = n_cols)

  # seq_len() yields an empty sequence for a zero-extent grid, whereas 1:0
  # would count down and index out of bounds.
  for (row_i in seq_len(n_rows)) {
    for (col_j in seq_len(n_cols)) {
      # Dead cells are left as NA in both outputs.
      if (is.na(V[row_i, col_j])) {
        next
      }
      # Terminal states keep their value and get no action.
      if (isTRUE(G[row_i, col_j])) {
        newV[row_i, col_j] <- V[row_i, col_j]
        next
      }
      # maxaction() returns the value and the action name together as a
      # character vector, so the value has to be converted back to a number.
      m <- maxaction(V, row_i, col_j)
      newV[row_i, col_j] <- gamma * as.numeric(m[1]) + R[row_i, col_j]
      policy[row_i, col_j] <- m[2]
    }
  }

  if (get.policy) {
    policy
  } else {
    newV
  }
}
converge <- function(V, R, G, gamma, epsilon, get.policy = FALSE) {
  # Performs value iterations until the value estimates converge
  # to within epsilon.
  #
  # Args:
  #   V: Matrix of cell values thus far computed.
  #   R: Matrix of intrinsic values for cells.
  #   G: Boolean matrix, TRUE for terminal states.
  #   gamma: Discount factor to use.
  #   epsilon: Maximum per-cell absolute difference between two successive
  #     iterations to accept. Must be a single positive number.
  #   get.policy: If TRUE, return the policy matrix instead of the values.
  #
  # Returns:
  #   The converged matrix of cell values, or, when get.policy is TRUE, the
  #   policy matrix produced by one further viter() pass over those values.

  # The stopping test is a strict "< epsilon", which a non-positive epsilon
  # can never satisfy; reject it up front instead of looping forever.
  stopifnot(
    is.numeric(epsilon),
    length(epsilon) == 1,
    !is.na(epsilon),
    epsilon > 0
  )

  Vcurrent <- V
  repeat {
    Vprime <- viter(Vcurrent, R, G, gamma)
    # Dead cells are NA in both matrices, so they are ignored in the test.
    converged <- all(abs(Vprime - Vcurrent) < epsilon, na.rm = TRUE)
    Vcurrent <- Vprime
    if (converged) {
      break
    }
  }

  if (get.policy) {
    viter(Vcurrent, R, G, gamma, get.policy = TRUE)
  } else {
    Vcurrent
  }
}
policy <- function(V, R, G, gamma, epsilon) {
  # Computes the policy for a given world.
  #
  # Args:
  #   V: Matrix of initial cell values.
  #   R: Matrix of intrinsic values for cells.
  #   G: Boolean matrix, TRUE for terminal states.
  #   gamma: Discount factor to use.
  #   epsilon: Convergence threshold passed on to converge().
  #
  # Returns:
  #   Whatever converge() returns when asked for the policy rather than the
  #   values.
  converge(V, R, G, gamma, epsilon, get.policy = TRUE)
}
# Build the 3 x 4 world: one dead cell at (2, 2) and two terminal cells in
# the last column, at (1, 4) and (2, 4).

# Initial value estimates.
V <- matrix(0, nrow = 3, ncol = 4)
V[2, 2] <- NA
V[1, 4] <- 100
V[2, 4] <- -100

# Intrinsic cell values.
R <- matrix(0, nrow = 3, ncol = 4)
R[2, 2] <- NA
R[1, 4] <- 100
R[2, 4] <- -100

# Terminal-state flags.
G <- matrix(FALSE, nrow = 3, ncol = 4)
G[1, 4] <- TRUE
G[2, 4] <- TRUE

# Find the utilities table
values <- converge(V, R, G, 0.9, 0.0001)
print(values)

# Find the associated policy.
# NOTE: this assignment replaces the policy() function defined earlier in
# the file with a matrix of the same name.
policy <- converge(V, R, G, 0.9, 0.0001, get.policy = TRUE)
print(policy)

# Find gamma such that lower right goes south. Gamma is lowered in steps of
# 0.001; the search gives up once gamma is no longer positive rather than
# looping without bound.
gamma <- 0.9
found <- FALSE
while (!found && gamma > 0) {
  policy <- converge(V, R, G, gamma, 0.0001, get.policy = TRUE)
  if (identical(policy[3, 4], "SOUTH")) {
    found <- TRUE
  } else {
    gamma <- gamma - 0.001
  }
}
if (!found) {
  stop("no positive gamma makes the lower-right cell go SOUTH", call. = FALSE)
}
print(gamma)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment