Last active
January 23, 2017 19:42
-
-
Save ebetica/f674a0beba32ce718281088c7d39b35b to your computer and use it in GitHub Desktop.
PyTorch reinforce function
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
class Policy(nn.Module):
    """Tiny stochastic policy network for a 2-action environment.

    Maps a 4-dimensional observation (batch, 4) to a probability
    distribution over 2 actions and returns a sampled action index.
    """

    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)   # observation -> hidden
        self.affine2 = nn.Linear(128, 2)   # hidden -> action scores

    def forward(self, x):
        """Sample an action per row of x; returns a (batch, 1) index tensor.

        Fixes vs. original: softmax gets an explicit dim=1 (the implicit
        dim is deprecated and ambiguous for batched input), and multinomial
        gets an explicit num_samples=1 (required by current PyTorch and
        identical to the old default).
        """
        x = F.relu(self.affine1(x))
        probs = F.softmax(self.affine2(x), dim=1)  # per-row action probabilities
        return probs.multinomial(1)  # draw one action index per row
# Demonstration script: sample an action from Policy, apply REINFORCE with a
# reward of 1, then re-sample the same action and reinforce with a reward of
# 2, printing gradients after each backward pass to compare their scale.
# NOTE(review): Variable and Tensor.reinforce() are legacy PyTorch (<0.4)
# APIs that were removed in later releases; this script only runs on the
# old PyTorch versions that expose stochastic Variables.
model = Policy()
input = Variable(torch.randn(1, 4))  # one random 4-feature observation
a = model(input)  # sampled action (legacy stochastic Variable)
action = a.data[0, 0]  # remember which action was drawn
print("Action: ", action)
a.reinforce(1)  # attach a reward of 1 to the sampled action
a.backward()
for param in model.parameters():
    # print the first element of each parameter's gradient as a fingerprint
    print(param.grad.data.storage()[0])
print("The following grads should be scaled by a factor of 2, but are not")
# Re-sample until the same action comes up again, so the two gradient
# printouts correspond to the same action choice and are comparable.
a = model(input)
while a.data[0, 0] != action:
    a = model(input)
print("Action: ", a.data[0, 0])
a.reinforce(2)  # reinforce with twice as much
a.backward()
# Expect grads are scaled by factor of 2
# NOTE(review): param.grad is never zeroed between the two backward() calls,
# so this second printout shows accumulated gradients (first pass + second
# pass), not the second pass alone — presumably part of what the gist is
# illustrating; verify intent.
for param in model.parameters():
    print(param.grad.data.storage()[0])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment