library(MDPtoolbox)
library(Matrix)
library(ggplot2)
library(grid)
library(gridExtra)
library(doMC)

# Register roughly three quarters of the available cores for %dopar%.
cores <- detectCores() - (detectCores() / 2) / 2
registerDoMC(cores=cores)
set.seed(1234)
mdp_value_iteration <- function (P, R, discount, epsilon, max_iter, V0)
{
    # Modified from the MDPtoolbox package: reaching the epsilon threshold
    # no longer stops the loop, so every run performs exactly max_iter
    # iterations and `converged` records when the threshold was met
    # (-1 if it never was).
    start <- as.POSIXlt(Sys.time())
    if (discount <= 0 | discount > 1) {
        print("--------------------------------------------------------")
        print("MDP Toolbox ERROR: Discount rate must be in (0, 1]")
        print("--------------------------------------------------------")
    }
    else if (nargs() > 3 & ifelse(!missing(epsilon), ifelse(epsilon <
        0, T, F), F)) {
        print("--------------------------------------------------------")
        print("MDP Toolbox ERROR: epsilon must be greater than 0")
        print("--------------------------------------------------------")
    }
    else if (nargs() > 4 & ifelse(!missing(max_iter), ifelse(max_iter <=
        0, T, F), F)) {
        print("--------------------------------------------------------")
        print("MDP Toolbox ERROR: The maximum number of iterations must be greater than 0")
        print("--------------------------------------------------------")
    }
    else if (is.list(P) & nargs() > 5 & ifelse(!missing(V0),
        ifelse(length(V0) != dim(P[[1]])[1], T, F), F)) {
        print("--------------------------------------------------------")
        print("MDP Toolbox ERROR: V0 must have the same dimension as P")
        print("--------------------------------------------------------")
    }
    else if (!is.list(P) & nargs() > 5 & ifelse(!missing(V0),
        ifelse(length(V0) != dim(P)[1], T, F), F)) {
        print("--------------------------------------------------------")
        print("MDP Toolbox ERROR: V0 must have the same dimension as P")
        print("--------------------------------------------------------")
    }
    else {
        if (discount == 1) {
            print("--------------------------------------------------------")
            print("MDP Toolbox WARNING: check conditions of convergence.")
            print("With no discount, convergence is not always assumed.")
            print("--------------------------------------------------------")
        }
        if (is.list(P)) {
            S <- dim(P[[1]])[1]
            A <- length(P)
        }
        else {
            S <- dim(P)[1]
            A <- dim(P)[3]
        }
        PR <- mdp_computePR(P, R)
        if (nargs() < 6) {
            V0 <- numeric(S)
        }
        if (nargs() < 4) {
            epsilon <- 0.01
        }
        if (discount != 1)
            computed_max_iter <- 5000
        if (nargs() < 5) {
            if (discount != 1) {
                max_iter <- computed_max_iter
            }
            else {
                max_iter <- 5000
            }
        }
        else {
            if (discount != 1 & max_iter > computed_max_iter) {
                print(paste("MDP Toolbox WARNING: max_iter is bounded by ",
                    computed_max_iter))
                max_iter <- computed_max_iter
            }
        }
        if (discount != 1) {
            thresh <- epsilon * (1 - discount)/discount
        }
        else {
            thresh <- epsilon
        }
        iter <- 0
        V <- V0
        is_done <- F
        converged <- -1
        while (!is_done) {
            iter <- iter + 1
            Vprev <- V
            bellman <- mdp_bellman_operator(P, PR, discount, V)
            V <- bellman[[1]]
            policy <- bellman[[2]]
            variation <- mdp_span(V - Vprev)
            if (variation < thresh) {
                # Intentionally do not set is_done here: note that the
                # epsilon threshold has been met, but keep iterating until
                # max_iter is reached.
                converged <- iter
            }
            if (iter == max_iter) {
                is_done <- T
            }
        }
        end <- as.POSIXlt(Sys.time())
        list(V = V,
             policy = policy,
             iter = iter,
             time = as.numeric(end - start, units = "secs"),
             epsilon = epsilon,
             discount = discount,
             converged = converged)
    }
}
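# Example (not run): value iteration on the forest MDP defined further down.
# A minimal sketch; the discount, epsilon, and iteration cap are
# illustrative values, not tuned settings.
if (FALSE) {
    f <- forest()
    vi <- mdp_value_iteration(f$P, f$R, discount=0.9, epsilon=0.01, max_iter=100)
    vi$policy     # one action per state: 1=Wait, 2=Cut, 3=Cultivate
    vi$converged  # iteration at which the variation fell below the threshold
}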
mdp_Q_learning <- function (P, R, discount, N, max.time=1800)
{
    # Adapted from the MDPtoolbox package, with a wall-clock limit added:
    # the loop aborts once max.time seconds have elapsed.
    start <- as.POSIXlt(Sys.time())
    if (discount <= 0 | discount > 1) {
        print("--------------------------------------------------------")
        print("MDP Toolbox ERROR: Discount rate must be in (0, 1]")
        print("--------------------------------------------------------")
    }
    else if (nargs() >= 4 & ifelse(!missing(N), N <= 0, F)) {
        print("--------------------------------------------------------")
        print("MDP Toolbox ERROR: N must be a positive integer")
        print("--------------------------------------------------------")
    }
    else {
        if (nargs() < 4) {
            N <- 10000
        }
        if (is.list(P)) {
            S <- dim(P[[1]])[1]
            A <- length(P)
        }
        else {
            S <- dim(P)[1]
            A <- dim(P)[3]
        }
        Q <- matrix(0, S, A)
        dQ <- matrix(0, S, A)
        mean_discrepancy <- NULL
        discrepancy <- NULL
        max.time.exceeded <- F
        state <- sample(1:S, 1, replace = T)
        iters <- 1
        for (n in 1:N) {
            t <- as.numeric(as.POSIXlt(Sys.time()) - start, units="secs")
            if (t > max.time) {
                message(sprintf("Q Learning: Max time of %d exceeded", max.time))
                max.time.exceeded <- T
                break
            } else {
                delta <- max.time - t
                message(sprintf("Iteration %d: %f seconds till max.time", iters, delta))
            }
            # Restart from a random state every 100 iterations.
            if (n%%100 == 0) {
                state <- sample(1:S, 1, replace = T)
            }
            # Epsilon-greedy action selection with exploration rate
            # 1/log(n + 2), decaying slowly as n grows.
            pn <- runif(1)
            if (pn < (1 - (1/log(n + 2)))) {
                a <- which.max(Q[state, ])
            }
            else {
                a <- sample(1:A, 1, replace = T)
            }
            # Sample the successor state from the transition distribution
            # by inverting the cumulative probabilities.
            p_s_new <- runif(1)
            p <- 0
            s_new <- 0
            while ((p < p_s_new) & (s_new < S)) {
                s_new <- s_new + 1
                if (is.list(P)) {
                    p <- p + P[[a]][state, s_new]
                }
                else {
                    p <- p + P[state, s_new, a]
                }
            }
            if (is.list(R)) {
                r <- R[[a]][state, s_new]
            }
            else {
                if (length(dim(R)) == 3) {
                    r <- R[state, s_new, a]
                }
                else {
                    r <- R[state, a]
                }
            }
            # Temporal-difference update with learning rate 1/sqrt(n + 2).
            delta <- r + discount * max(Q[s_new, ]) - Q[state, a]
            dQ <- (1/sqrt(n + 2)) * delta
            Q[state, a] <- Q[state, a] + dQ
            state <- s_new
            discrepancy[(n%%100) + 1] <- abs(dQ)
            if (length(discrepancy) == 100) {
                mean_discrepancy <- c(mean_discrepancy, mean(discrepancy))
                discrepancy <- NULL
            }
            iters <- n
        }
        V <- apply(Q, 1, max)
        policy <- apply(Q, 1, which.max)
        end <- as.POSIXlt(Sys.time())
        return(list(
            Q = Q,
            V = V,
            policy = policy,
            mean_discrepancy = mean_discrepancy,
            discount = discount,
            iter = iters,
            max.iter = N,
            time = as.numeric(end - start, units = "secs"),
            max.time = max.time.exceeded
        ))
    }
}
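# Example (not run): a short Q-learning run on the forest MDP; N and
# max.time are illustrative, not tuned values.
if (FALSE) {
    f <- forest()
    ql <- mdp_Q_learning(f$P, f$R, discount=0.9, N=1000, max.time=60)
    ql$policy            # greedy policy extracted from the learned Q table
    ql$mean_discrepancy  # mean |dQ| per 100-iteration window
}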
mdp_policy_iteration <- function (P, R, discount, max_iter, policy0, eval_type)
{
    # Modified from the MDPtoolbox package: the policy-stability stopping
    # test is disabled, so the loop always runs for exactly max_iter
    # iterations.
    start <- as.POSIXlt(Sys.time())
    if (discount <= 0 | discount > 1) {
        print("--------------------------------------------------------")
        print("MDP Toolbox ERROR: Discount rate must be in (0, 1]")
        print("--------------------------------------------------------")
    }
    else if (nargs() > 3 & is.list(P) & ifelse(!missing(policy0),
        length(policy0) != dim(P[[1]])[1], F)) {
        print("--------------------------------------------------------")
        print("MDP Toolbox ERROR: policy must have the same dimension as P")
        print("--------------------------------------------------------")
    }
    else if (nargs() > 3 & !is.list(P) & ifelse(!missing(policy0),
        length(policy0) != dim(P)[1], F)) {
        print("--------------------------------------------------------")
        print("MDP Toolbox ERROR: policy must have the same dimension as P")
        print("--------------------------------------------------------")
    }
    else if (nargs() > 4 & ifelse(!missing(max_iter), max_iter <= 0, F)) {
        print("--------------------------------------------------------")
        print("MDP Toolbox ERROR: The maximum number of iterations must be greater than 0")
        print("--------------------------------------------------------")
    }
    else {
        if (is.list(P)) {
            S <- dim(P[[1]])[1]
            A <- length(P)
        }
        else {
            S <- dim(P)[1]
            A <- dim(P)[3]
        }
        PR <- mdp_computePR(P, R)
        if (nargs() < 6) {
            eval_type <- 0
        }
        if (nargs() < 5) {
            # Default initial policy: greedy with respect to V = 0.
            bellman <- mdp_bellman_operator(P, PR, discount, numeric(S))
            policy0 <- bellman[[2]]
        }
        if (nargs() < 4) {
            max_iter <- 1000
        }
        iter <- 0
        policy <- policy0
        is_done <- F
        while (!is_done) {
            iter <- iter + 1
            # Policy evaluation (matrix inversion or iterative), followed
            # by a greedy policy-improvement step.
            if (eval_type == 0) {
                V <- mdp_eval_policy_matrix(P, PR, discount, policy)
            }
            else {
                V <- mdp_eval_policy_iterative(P, PR, discount, policy)
            }
            bellman <- mdp_bellman_operator(P, PR, discount, V)
            policy_next <- bellman[[2]]
            # The usual stability test (policy_next == policy) is disabled;
            # stop only when max_iter is reached.
            if (iter == max_iter) {
                is_done <- T
            }
            else {
                policy <- policy_next
            }
        }
        end <- as.POSIXlt(Sys.time())
        return(list(V = V,
                    policy = policy,
                    iter = iter,
                    time = as.numeric(end - start, units = "secs"),
                    discount = discount))
    }
}
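# Example (not run): policy iteration on the forest MDP with the default
# matrix-based policy evaluation (eval_type = 0). Parameter values are
# illustrative.
if (FALSE) {
    f <- forest()
    pi.fit <- mdp_policy_iteration(f$P, f$R, discount=0.9, max_iter=50)
    pi.fit$policy
}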
forest <- function() {
    # Forest-management MDP: four states (Cleared, Young Forest, Old
    # Forest, Farm) and three actions (Wait, Cut, Cultivate). Each matrix
    # gives P(next state | state) for one action.
    # Wait
    m1 <- matrix(c(
        # Cleared, Young Forest, Old Forest, Farm
        0.1, 0.9, 0.0, 0.0, # Cleared
        0.1, 0.0, 0.9, 0.0, # Young Forest
        0.1, 0.0, 0.9, 0.0, # Old Forest
        0.1, 0.0, 0.0, 0.9  # Farm
    ), 4, 4, byrow=T)
    # Cut
    m2 <- matrix(c(
        # Cleared, Young Forest, Old Forest, Farm
        1, 0, 0, 0, # Cleared
        1, 0, 0, 0, # Young Forest
        1, 0, 0, 0, # Old Forest
        1, 0, 0, 0  # Farm
    ), 4, 4, byrow=T)
    # Cultivate
    m3 <- matrix(c(
        # Cleared, Young Forest, Old Forest, Farm
        0.1, 0, 0, 0.9, # Cleared
        1,   0, 0, 0,   # Young Forest
        1,   0, 0, 0,   # Old Forest
        0.1, 0, 0, 0.9  # Farm
    ), 4, 4, byrow=T)
    P <- array(0, dim=c(4,4,3))
    P[,,1] <- m1
    P[,,2] <- m2
    P[,,3] <- m3
    # Rewards: one column per action.
    R <- matrix(c(
        # Waiting, Cutting, Cultivating
        0.0, 0.0, 0.0, # Cleared
        2.0, 1.0, 1.0, # Young Forest
        4.0, 2.0, 2.0, # Old Forest
        2.0, 0.0, 4.0  # Farm
    ), 4, 3, byrow=T)
    colnames(R) <- c('Wait', 'Cut', 'Cultivate')
    colnames(P) <- c('Cleared', 'Young Forest', 'Old Forest', 'Farm')
    list(P=P, R=R)
}
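# Sanity check (not run): each action's transition matrix should be
# row-stochastic, i.e. every row sums to 1.
if (FALSE) {
    f <- forest()
    stopifnot(all(abs(apply(f$P, c(1, 3), sum) - 1) < 1e-12))
}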
forest.calc <- function() {
    message("Collecting forest management calculations...")
    f.data <- forest()
    values <- value.iteration(f.data)
    save(values, file="forest_values.RData")
    policies <- policy.iteration(f.data)
    save(policies, file="forest_policies.RData")
    qlearning <- q.learning(f.data)
    save(qlearning, file="forest_qlearning.RData")
    results <- list(Value.Iteration=values,
                    Policy.Iteration=policies,
                    QLearning=qlearning
                    )
    rewards <- rewards(f.data, results, 1, max.plays=100, reps=100)
    list(results=results, rewards=rewards)
}
tictactoe <- function() {
    # Assumes these files exist and define objects named `r` and `p`.
    load("data/tictactoe/R.RData")
    load("data/tictactoe/P.RData")
    list(P=p, R=r)
}
value.iteration <- function(mat) {
    message("Calculating value iteration...")
    # For each discount factor, run with an increasing iteration cap until
    # the realized iteration count stops changing.
    results <- foreach (i=seq(0.1, 0.9, by=.1)) %dopar% {
        results <- list()
        iter <- -1
        for (j in 1:10) {
            model <- mdp_value_iteration(mat$P, mat$R, discount=i, epsilon=0.01, max_iter=j)
            if (model$iter != iter) {
                iter <- model$iter
                results <- append(results, list(model))
            } else {
                break
            }
        }
        results
    }
    # Flatten the per-discount lists into a single list of models.
    collected <- list()
    for (chunk in results) {
        collected <- append(collected, chunk)
    }
    collected
}
policy.iteration <- function(mat) {
    message("Calculating policy iteration...")
    # Same scheme as value.iteration: increase the iteration cap until the
    # realized iteration count stops changing.
    results <- foreach (i=seq(0.1, 0.9, by=.1)) %dopar% {
        results <- list()
        iter <- -1
        for (j in 1:10) {
            model <- mdp_policy_iteration(mat$P, mat$R, discount=i, max_iter=j)
            if (model$iter != iter) {
                iter <- model$iter
                results <- append(results, list(model))
            } else {
                break
            }
        }
        results
    }
    collected <- list()
    for (chunk in results) {
        collected <- append(collected, chunk)
    }
    collected
}
q.learning <- function(mat, max.time=1800) {
    message("Calculating Q Learning...")
    results <- foreach (i=c(.1, .5, .9)) %dopar% {
        results <- list()
        for (j in seq(1, 3001, by=300)) {
            message(sprintf("Running discount %f for %d iterations", i, j))
            model <- mdp_Q_learning(mat$P,
                                    mat$R,
                                    discount=i,
                                    N=j,
                                    max.time=max.time
                                    )
            results <- append(results, list(model))
        }
        results
    }
    collected <- list()
    for (chunk in results) {
        collected <- append(collected, chunk)
    }
    collected
}
next.state <- function(P, state) {
    # Sample the next state from row `state` of transition matrix P.
    probs <- as.vector(P[state, ])
    sample(seq_along(probs), 1, prob=probs)
}
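# Example (not run): sample a successor of state 1 (Cleared) under the
# Wait action of the forest model.
if (FALSE) {
    f <- forest()
    next.state(f$P[,,1], 1)  # returns 1 (Cleared) or 2 (Young Forest)
}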
simulate <- function(mat, state, policy, rewards=NULL, max.plays=10) {
    # Recursively follow `policy` from `state`, collecting one reward per
    # play until max.plays rewards have been accumulated.
    if (length(rewards) == max.plays) {
        rewards
    } else {
        action <- policy[state]
        r <- as.numeric(mat$R[state, action])
        if (is.null(rewards)) {
            rewards <- array(r)
        } else {
            rewards <- append(rewards, r)
        }
        if (!is.list(mat$P)) {
            P <- mat$P[,,action]
        } else {
            P <- mat$P[[action]]
        }
        new.state <- next.state(P, state)
        simulate(mat, new.state, policy, rewards, max.plays)
    }
}
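# Example (not run): total reward from 10 plays of an always-Wait policy
# starting in the Cleared state.
if (FALSE) {
    f <- forest()
    sum(simulate(f, state=1, policy=rep(1, 4), max.plays=10))
}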
num.iters <- function(results) {
    unlist(unique(lapply(results, function(x) x$iter)))
}
num.discounts <- function(results) {
    unlist(unique(lapply(results, function(x) x$discount)))
}
rewards <- function(world, results, state, max.plays=10, reps=1000) {
    message("Calculating rewards...")
    df <- NULL
    models <- names(results)
    for (model in models) {
        for (v in results[[model]]) {
            # Estimate the mean total reward of this model's policy by
            # repeated simulation from the given start state.
            sums <- replicate(reps, {
                sum(simulate(world, state, v$policy, max.plays=max.plays))
            })
            row <- data.frame(iter=v$iter,
                              discount=v$discount,
                              reward=mean(sums),
                              model=model,
                              time.secs=v$time[[1]])
            df <- if (is.null(df)) row else rbind(df, row)
        }
    }
    df
}
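# Example (not run): mean rollout reward for a set of solved models,
# starting from state 1; reps and max.plays here are illustrative.
if (FALSE) {
    f <- forest()
    res <- list(Value.Iteration=value.iteration(f))
    rewards(f, res, 1, max.plays=20, reps=50)
}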
tictactoe.calc <- function(max.time=60) {
    message("Collecting tic-tac-toe calculations...")
    t.data <- tictactoe()
    values <- value.iteration(t.data)
    save(values, file="tictactoe_values.RData")
    policies <- policy.iteration(t.data)
    save(policies, file="tictactoe_policies.RData")
    qlearning <- q.learning(t.data, max.time)
    save(qlearning, file="tictactoe_qlearning.RData")
    results <- list(Value.Iteration=values,
                    Policy.Iteration=policies,
                    QLearning=qlearning
                    )
    rewards <- rewards(t.data, results, 1, max.plays=100, reps=100)
    list(results=results, rewards=rewards)
}
my.plot <- function(title, data, outdir, multi=T, individual=F) {
    p1 <- ggplot(data, aes(x=iter, y=reward, colour=discount.factor)) +
        geom_point() +
        geom_line() +
        labs(title="Reward Per Iteration",
             x = "Iterations",
             y = "Mean Reward",
             colour = "Discount")
    p2 <- ggplot(data, aes(x=iter, y=time.secs, colour=discount.factor)) +
        geom_point() +
        geom_line() +
        labs(title="Time Per Iteration",
             x = "Iterations",
             y = "Time in Seconds",
             colour = "Discount")
    fname <- tolower(gsub(" ", "_", title))
    if (individual) {
        ggsave(filename=sprintf("%s/%s_reward.png", outdir, fname), plot=p1)
        ggsave(filename=sprintf("%s/%s_time.png", outdir, fname), plot=p2)
    }
    if (multi) {
        png(sprintf("%s/%s.png", outdir, fname))
        grid.arrange(p1, p2, ncol = 1, top = textGrob(title))
        dev.off()
    }
    list(reward=p1, time=p2)
}
plots.preproc <- function(data) {
    # Add a factor version of the discount for ggplot grouping and keep
    # only the three highlighted discount factors.
    d2 <- data
    d2$discount.factor <- as.factor(d2$discount)
    d2[d2$discount %in% c(0.1, 0.5, 0.9), ]
}
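# Example (not run): preprocess a rewards data frame and plot one model's
# curves; `f.r` here is assumed to come from forest.calc().
if (FALSE) {
    df <- plots.preproc(f.r$rewards)
    my.plot("Value Iteration", df[df$model == "Value.Iteration",], "doc/graphs/forest")
}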
plot.combined <- function(data, outdir) {
    p1 <- ggplot(data,
                 aes(x=iter,
                     y=reward,
                     colour=model,
                     shape=discount.factor,
                     group=interaction(model, discount.factor))) +
        geom_point() +
        geom_line() +
        labs(title="Mean Reward Per Iteration",
             x = "Iterations",
             y = "Mean Reward",
             colour = "Model",
             shape = "Discount")
    p2 <- ggplot(data,
                 aes(x=iter,
                     y=time.secs,
                     colour=model,
                     shape=discount.factor,
                     group=interaction(model, discount.factor))) +
        geom_point() +
        geom_line() +
        labs(title="Time Per Iteration",
             x = "Iterations",
             y = "Time in Seconds",
             colour = "Model",
             shape = "Discount")
    png(sprintf("%s/value_policy_combined.png", outdir))
    grid.arrange(p1, p2, ncol = 1, top = textGrob("Value vs Policy Iteration"))
    dev.off()
}
main <- function() {
    d <- "doc/graphs/forest/"
    dir.create(d, recursive=T)
    dir.create("results")
    f.r <- forest.calc()
    rewards <- plots.preproc(f.r$rewards)
    message("Saving forest planning results and rewards...")
    forest <- list(data=f.r, rewards=rewards)
    save(forest, file="results/Forest.RData")
    message("Plotting Forest Results...")
    my.plot("Value Iteration", rewards[rewards$model == "Value.Iteration",], d)
    my.plot("Policy Iteration", rewards[rewards$model == "Policy.Iteration",], d)
    my.plot("Q Learning", rewards[rewards$model == "QLearning",], d)
    plot.combined(rewards[rewards$model != "QLearning",], d)
    d <- "doc/graphs/tictactoe/"
    dir.create(d, recursive=T)
    t.r <- tictactoe.calc(max.time=60)
    rewards <- plots.preproc(t.r$rewards)
    message("Saving Tic Tac Toe planning results and rewards...")
    tictactoe <- list(data=t.r, rewards=rewards)
    save(tictactoe, file="results/TicTacToe.RData")
    message("Plotting Tic Tac Toe Results...")
    my.plot("Value Iteration", rewards[rewards$model == "Value.Iteration",], d)
    my.plot("Policy Iteration", rewards[rewards$model == "Policy.Iteration",], d)
    my.plot("Q Learning", rewards[rewards$model == "QLearning",], d)
    plot.combined(rewards[rewards$model != "QLearning",], d)
}
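# To reproduce everything end-to-end (assumes the data/tictactoe/*.RData
# files loaded by tictactoe() are present), run:
# main()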
Where are the data/tictactoe/R.RData & data/tictactoe/P.RData files?