Created July 17, 2019 22:05
Exact policy gradient in JAX, as demonstrated in Figure 2(d) of Dadashi et al. (2019)
import jax
import jax.numpy as np
from jax import grad, jit
from jax.scipy.special import logsumexp
def dadashi_fig2d():
    """Build the two-state, two-action MDP of Figure 2(d) in
    ''The Value Function Polytope in Reinforcement Learning''
    by Dadashi et al. (2019).

    Returns:
        tuple (P, R, gamma): P is the transition tensor of shape (A x S x S),
        R is the reward matrix of shape (S x A), and gamma is the scalar
        (float) discount factor.
    """
    # transitions[a, s, t]: probability of moving from state s to state t
    # when taking action a; each innermost row sums to one.
    transitions = np.array([
        [[0.7, 0.3], [0.2, 0.8]],
        [[0.99, 0.01], [0.99, 0.01]],
    ])
    # rewards[s, a]: immediate reward for taking action a in state s.
    rewards = np.array([
        [-0.45, -0.1],
        [0.5, 0.5],
    ])
    discount = 0.9
    return transitions, rewards, discount
def softmax(vals, temp=1., axis=1):
    """Batch softmax with a temperature, computed in log-space.

    Args:
        vals (np.ndarray): Logits, S x A in this file. Normalized along `axis`.
        temp (float, optional): Temperature parameter. Defaults to 1.
            Larger values produce flatter distributions.
        axis (int, optional): Axis along which to normalize. Defaults to 1,
            i.e. row-wise for an S x A matrix.

    Returns:
        np.ndarray: Array with the same shape as `vals` whose slices along
        `axis` sum to one.
    """
    # Scale once and reuse. Subtracting the log-sum-exp instead of dividing
    # by a sum of exponentials keeps large logits from overflowing.
    scaled = (1. / temp) * vals
    return np.exp(scaled - logsumexp(scaled, axis=axis, keepdims=True))
def policy_evaluation(P, R, discount, policy):
    """Policy evaluation by solving the Bellman linear system exactly.

    We denote by 'A' the number of actions and by 'S' the number of states.

    Args:
        P (numpy.ndarray): Transition function as (A x S x S) tensor
        R (numpy.ndarray): Reward function as a (S x A) tensor
        discount (float): Scalar discount factor
        policy (numpy.ndarray): Policy as a tensor of shape (S x A); row s is
            presumably the action distribution in state s (not checked here).

    Returns:
        tuple (vf, qf) where the first element is the value function as a
        vector of length S and the second element contains the Q functions
        as a matrix of shape (S x A).
    """
    nstates = P.shape[-1]
    # State-to-state transition matrix under the policy:
    # ppi[s, t] = sum_a policy[s, a] * P[a, s, t]
    ppi = np.einsum('ast,sa->st', P, policy)
    # Expected one-step reward under the policy:
    # rpi[s] = sum_a policy[s, a] * R[s, a]
    rpi = np.einsum('sa,sa->s', R, policy)
    # The value function satisfies vf = rpi + discount * ppi @ vf, i.e.
    # (I - discount * ppi) vf = rpi; solve it directly rather than iterating.
    vf = np.linalg.solve(np.eye(nstates) - discount*ppi, rpi)
    # One-step lookahead: qf[s, a] = R[s, a] + discount * sum_t P[a, s, t] * vf[t]
    qf = R + discount*np.einsum('ast,t->sa', P, vf)
    return vf, qf
def policy_performance(P, R, discount, initial_distribution, policy):
    """Return the expected discounted return of `policy` from a start-state distribution.

    Args:
        P (numpy.ndarray): Transition function as (A x S x S) array
        R (numpy.ndarray): Reward function as a (S x A) array
        discount (float): Scalar discount factor
        initial_distribution (numpy.ndarray): (S,) array of start-state weights
        policy (np.ndarray): (S x A) array

    Returns:
        Scalar performance: the inner product of the initial distribution
        with the state value function of `policy`.
    """
    # Only the state values are needed; the Q-function is discarded.
    state_values = policy_evaluation(P, R, discount, policy)[0]
    return np.dot(initial_distribution, state_values)
if __name__ == "__main__":
    mdp = dadashi_fig2d()
    # P has shape (A x S x S), so its first two dimensions are (A, S).
    nactions, nstates = mdp[0].shape[:2]
    temperature = 1.
    learning_rate = 0.5
    num_iterations = 500
    # Start states are drawn uniformly.
    initial_distribution = np.ones(nstates)/nstates

    def objective(params):
        """Expected discounted return of the softmax policy defined by `params`."""
        policy = softmax(params, temperature)
        return policy_performance(*mdp, initial_distribution, policy)

    objective = jit(objective)
    # Exact policy gradient: differentiate through the closed-form evaluation.
    gradient = jit(grad(objective))

    # Zero logits correspond to the uniform random policy.
    params = np.zeros((nstates, nactions))
    for _ in range(num_iterations):
        # Gradient *ascent*: the objective is a return to be maximized.
        params += learning_rate*gradient(params)

    # Report the result so the demonstration has observable output.
    print("Final performance:", objective(params))
    print("Final policy:\n", softmax(params, temperature))
